import numpy as np
import random
import matplotlib.pyplot as plt

# -----------------------------
# City Coordinates
# -----------------------------
# (x, y) coordinates of each city; a city's index is its row number.
cities = np.array(((0, 0), (1, 5), (5, 2), (6, 6), (8, 3), (2, 7)))

num_cities = len(cities)

# -----------------------------
# Compute Distance Matrix
# -----------------------------
# Symmetric matrix of pairwise Euclidean distances; the diagonal stays 0.
distance_matrix = np.zeros((num_cities, num_cities))

for src in range(num_cities):
    for dst in range(src + 1, num_cities):
        # The norm of (a - b) equals the norm of (b - a), so each pair is
        # computed once and mirrored across the diagonal.
        leg = np.linalg.norm(cities[src] - cities[dst])
        distance_matrix[src, dst] = leg
        distance_matrix[dst, src] = leg

# -----------------------------
# ACO Parameters
# -----------------------------
num_ants = 6          # ants that build a tour in every iteration
num_iterations = 100  # number of colony iterations

alpha = 1             # exponent applied to pheromone strength (tau)
beta = 1              # exponent applied to inverse distance (eta)
evaporation = 0.5     # fraction of pheromone removed each iteration
Q = 1                 # deposit constant: an ant lays Q / tour_length per edge

# -----------------------------
# Initialize Pheromone Matrix
# -----------------------------
# Every edge starts with the same pheromone level.
pheromone = np.full((num_cities, num_cities), 1.0)

# Best tour found so far and its length (inf until the first tour exists).
best_path, best_distance = None, float("inf")

# -----------------------------
# Helper Functions
# -----------------------------
def path_distance(path):
    """Return the length of the closed tour that visits ``path`` in order.

    The tour is closed: after the last city it travels back to ``path[0]``.
    Edge lengths are read from the module-level ``distance_matrix``.
    """
    # Pair every city with its successor; the final city wraps to the start,
    # which accounts for the return leg.
    successors = [*path[1:], path[0]]

    total = 0
    for here, there in zip(path, successors):
        total += distance_matrix[here][there]
    return total


def choose_next_city(current_city, unvisited):
    """Sample the next city for an ant standing at ``current_city``.

    Each candidate ``c`` in ``unvisited`` gets the weight
    ``pheromone[current_city][c] ** alpha * (1 / distance) ** beta``.
    One candidate is drawn with probability proportional to its weight.

    Args:
        current_city: index of the city the ant is currently in.
        unvisited: non-empty collection of city indices not yet visited.

    Returns:
        The chosen city index as a plain Python ``int``.

    Raises:
        ValueError: if ``unvisited`` is empty (raised by ``np.random.choice``).
    """
    candidates = list(unvisited)

    tau = pheromone[current_city, candidates] ** alpha
    # Clamp distances away from zero so coincident cities get a very large
    # but finite attractiveness instead of inf weights (which would turn
    # into NaN probabilities and make np.random.choice raise).
    dist = np.maximum(distance_matrix[current_city, candidates], 1e-12)
    eta = (1.0 / dist) ** beta

    weights = tau * eta
    total = weights.sum()
    if np.isfinite(total) and total > 0:
        probabilities = weights / total
    else:
        # All weights underflowed to zero (e.g. pheromone evaporated away) or
        # overflowed: fall back to a uniform pick instead of dividing by zero.
        probabilities = None

    # Draw an index and look it up in the Python list, so the result is a
    # plain int rather than an np.int64 scalar.
    return candidates[np.random.choice(len(candidates), p=probabilities)]

# -----------------------------
# Main ACO Loop
# -----------------------------
def _construct_tour():
    """Let a single ant build one complete tour.

    The ant starts at a random city (``random.randint``). It then moves to a
    city chosen by ``choose_next_city`` until no unvisited city remains.

    Returns:
        (tour, length): the visiting order as a list and its closed-tour
        length as computed by ``path_distance``.
    """
    origin = random.randint(0, num_cities - 1)

    tour = [origin]
    remaining = set(range(num_cities))
    remaining.remove(origin)

    here = origin
    while remaining:
        here = choose_next_city(here, remaining)
        tour.append(here)
        remaining.remove(here)

    return tour, path_distance(tour)


def _lay_pheromone(tour, length):
    """Add ``Q / length`` to every edge of the closed ``tour``, both directions."""
    amount = Q / length
    # Pair each city with its successor; wrapping the last city back to the
    # first also covers the closing edge of the tour.
    for u, v in zip(tour, [*tour[1:], tour[0]]):
        pheromone[u][v] += amount
        pheromone[v][u] += amount


for generation in range(num_iterations):
    # Every ant builds its tour against the pheromone left by earlier iterations.
    colony = [_construct_tour() for _ in range(num_ants)]

    # Remember the shortest tour seen so far (the earliest one wins ties).
    for tour, length in colony:
        if length < best_distance:
            best_distance = length
            best_path = tour

    # Evaporation: scale every edge down before this iteration's deposits.
    pheromone *= (1 - evaporation)

    # Deposit: shorter tours lay more pheromone per edge.
    for tour, length in colony:
        _lay_pheromone(tour, length)

    print(f"Iteration {generation+1}: Best Distance = {best_distance:.2f}")

# -----------------------------
# Print Final Result
# -----------------------------
if best_path is None:
    # best_path stays None when no tour was ever recorded (e.g. num_ants or
    # num_iterations is 0); indexing it below would raise TypeError.
    print("\nNo tour was found.")
else:
    # Cast to plain ints so the list prints as [0, 2, ...] rather than
    # [0, np.int64(2), ...] (NumPy 2 scalar repr) when the tour holds
    # NumPy integers.
    best_tour = [int(city) for city in best_path]

    print("\nBest Path:")
    print(best_tour)

    print("Best Distance:")
    print(best_distance)

    # -----------------------------
    # Plot Best Tour
    # -----------------------------
    # Append the start city so the plotted line closes the loop.
    best_cycle = best_tour + [best_tour[0]]

    x = [cities[i][0] for i in best_cycle]
    y = [cities[i][1] for i in best_cycle]

    plt.figure(figsize=(8, 6))
    plt.plot(x, y, 'o-')

    # Label each city with its index, slightly offset from its marker.
    for i, (cx, cy) in enumerate(cities):
        plt.text(cx + 0.1, cy + 0.1, str(i))

    plt.title("Best TSP Route using ACO")
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.grid(True)
    plt.show()
