Skip to content

Commit 7976657

Browse files
Fix bug and remove simulator (#7)
1 parent 17d324e commit 7976657

File tree

11 files changed

+69
-137
lines changed

11 files changed

+69
-137
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,6 @@ dmypy.json
130130

131131
# Pyre type checker
132132
.pyre/
133+
134+
# Simulator
135+
aco_routing/simulator.py

aco_routing/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from aco_routing.graph import *
2+
from aco_routing.dijkstra import *
3+
from aco_routing.aco import *

aco_routing/aco.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import random
33
from typing import List, Tuple
44

5-
from aco_routing.utils.graph import Graph
6-
from aco_routing.utils.ant import Ant
5+
from aco_routing.graph import Graph
6+
from aco_routing.ant import Ant
77

88

99
@dataclass
@@ -15,11 +15,11 @@ def _forward_ants(self, ants: List[Ant], max_iterations: int) -> None:
1515
1616
Args:
1717
ants (List[Ant]): A List of Ants.
18-
max_iterations (int, optional): The maximum number of steps an ant is allowed is to take in order to reach the destination.
19-
If it fails to find a path, it is tagged as unfit. Defaults to 50.
18+
max_iterations (int): The maximum number of steps an ant is allowed is to take in order to reach the destination.
19+
If it fails to find a path, it is tagged as unfit.
2020
"""
21-
for idx, ant in enumerate(ants):
22-
for i in range(max_iterations):
21+
for _, ant in enumerate(ants):
22+
for _ in range(max_iterations):
2323
if ant.reached_destination():
2424
ant.is_fit = True
2525
break
@@ -31,7 +31,7 @@ def _backward_ants(self, ants: List[Ant]) -> None:
3131
Args:
3232
ants (List[Ant]): A List of Ants.
3333
"""
34-
for idx, ant in enumerate(ants):
34+
for _, ant in enumerate(ants):
3535
if ant.is_fit:
3636
self.graph.deposit_pheromones_along_path(ant.path)
3737

@@ -40,8 +40,8 @@ def _deploy_search_ants(
4040
source: str,
4141
destination: str,
4242
num_ants: int,
43-
random_spawns: bool = False,
44-
cycles: int = 100,
43+
cycles: int,
44+
random_spawns: bool,
4545
max_iterations: int = 50,
4646
) -> None:
4747
"""Deploys search ants which traverse the graph to find the shortest path.
@@ -50,12 +50,12 @@ def _deploy_search_ants(
5050
source (str): The source node in the graph.
5151
destination (str): The destination node in the graph.
5252
num_ants (int): The number of ants to be spawned.
53-
random_spawns (bool): A flag to determine if the ants should be spawned at random nodes or always at the source node.
54-
cycles (int, optional): The number of cycles of generating and deploying ants (forward and backward). Defaults to 100.
55-
max_iterations (int, optional): The maximum number of steps an ant is allowed is to take in order to reach the destination.
56-
If it fails to find a path, it is tagged as unfit. Defaults to 50.
53+
cycles (int): The number of cycles of generating and deploying ants (forward and backward).
54+
random_spawns (bool): Indicates if the search ants should spawn at random nodes in the graph.
55+
max_iterations (int, optional): The maximum number of steps an ant is allowed is to take in order to reach the destination,
56+
after which it is tagged as unfit. Defaults to 50.
5757
"""
58-
for cycle in range(cycles):
58+
for _ in range(cycles):
5959
ants: List[Ant] = []
6060
for _ in range(num_ants):
6161
spawn_point = (
@@ -79,24 +79,42 @@ def _deploy_solution_ant(self, source: str, destination: str) -> List[str]:
7979
List[str]: The shortest path found by the ants (A list of node IDs).
8080
"""
8181
# Spawn an ant which favors pheromone values over edge costs.
82-
ant = Ant(self.graph, source, destination, alpha=0.99, beta=0.01)
82+
ant = Ant(self.graph, source, destination, is_solution_ant=True)
8383
while not ant.reached_destination():
8484
ant.take_step()
8585
return ant.path
8686

8787
def find_shortest_path(
88-
self, source: str, destination: str
88+
self,
89+
source: str,
90+
destination: str,
91+
num_ants: int,
92+
max_iterations: int,
93+
cycles: int,
94+
random_spawn: bool = True,
8995
) -> Tuple[List[str], float]:
9096
"""Finds the shortest path from the source to the destination in the graph using the traditional Ant Colony Optimization technique.
9197
9298
Args:
9399
source (str): The source node in the graph.
94100
destination (str): The destination node in the graph.
101+
num_ants (int): The number of search ants to be deployed.
102+
max_iterations (int): The maximum number of steps an ant is allowed is to take in order to reach the destination,
103+
after which it is tagged as unfit. Defaults to 50.
104+
cycles (int): The number of cycles/waves of search ants to be deployed.
105+
random_spawn (bool, optional): Indicates if the search ants should spawn at random nodes in the graph. Defaults to True.
95106
96107
Returns:
97-
List[str]: The shortest path found by the ants (A list of node IDs).
108+
List[str]: The shortest path found by the ants (a list of node IDs).
98109
float: The total travel time of the shortest path.
99110
"""
100-
self._deploy_search_ants(source, destination, num_ants=20, random_spawns=True)
111+
self._deploy_search_ants(
112+
source,
113+
destination,
114+
num_ants=num_ants,
115+
max_iterations=max_iterations,
116+
cycles=cycles,
117+
random_spawns=random_spawn,
118+
)
101119
shortest_path = self._deploy_solution_ant(source, destination)
102120
return shortest_path, self.graph.compute_path_travel_time(shortest_path)

aco_routing/utils/ant.py renamed to aco_routing/ant.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import random
33
from typing import Dict, List, Set
44

5-
from aco_routing.utils.graph import Edge, Graph
5+
from aco_routing.graph import Edge, Graph
66

77

88
@dataclass
@@ -13,11 +13,12 @@ class Ant:
1313
graph (Graph): The Graph object.
1414
source (str): The source node of the ant.
1515
destination (str): The destination node of the ant.
16-
alpha (float): The amount of importance given to the pheromone by the ant. Defaults to 0.9.
17-
beta (float): The amount of importance given to the travel time value by the ant. Defaults to 0.1.
18-
visited_nodes (Set): A set of nodes that have been visited by the ant.
19-
path (List[str]): A List of node IDs of the path taken by the ant so far.
20-
is_fit (bool): A flag which indicates if the ant has reached the destination (fit) or not (unfit). Defaults to False.
16+
alpha (float, optional): The amount of importance given to the pheromone by the ant. Defaults to 0.9.
17+
beta (float, optional): The amount of importance given to the travel time value by the ant. Defaults to 0.1.
18+
visited_nodes (Set, optional): A set of nodes that have been visited by the ant.
19+
path (List[str], optional): A List of node IDs of the path taken by the ant so far.
20+
is_fit (bool, optional): Indicates if the ant has reached the destination (fit) or not (unfit). Defaults to False.
21+
is_solution_ant (bool, optional): Indicates if the ant is the final/solution ant. Defaults to False.
2122
"""
2223

2324
graph: Graph
@@ -28,6 +29,7 @@ class Ant:
2829
visited_nodes: Set = field(default_factory=set)
2930
path: List[str] = field(default_factory=list)
3031
is_fit: bool = False
32+
is_solution_ant: bool = False
3133

3234
def __post_init__(self) -> None:
3335
self.current_node = self.source
@@ -149,11 +151,19 @@ def _pick_next_node(
149151
Returns:
150152
str: The ID of the next node to be visited by the ant.
151153
"""
154+
if self.is_solution_ant:
155+
# The final/solution ant greedily chooses the next node with the highest pheromone value.
156+
return max(
157+
unvisited_neighbors, key=lambda k: unvisited_neighbors[k].pheromones
158+
)
152159
edges_total = self._calculate_edges_total(unvisited_neighbors, alpha, beta)
160+
153161
probabilities = self._calculate_edge_probabilites(
154162
unvisited_neighbors, edges_total, alpha, beta
155163
)
156164
sorted_probabilities = self._sort_edge_probabilites(probabilities)
165+
166+
# Pick the next node based on the Roulette Wheel selection technique.
157167
return self._choose_neighbor_using_roulette_wheel(sorted_probabilities)
158168

159169
def take_step(self) -> None:
@@ -171,8 +181,11 @@ def take_step(self) -> None:
171181
# Find unvisited neighbors of the current node.
172182
unvisited_neighbors = self._get_unvisited_neighbors(all_neighbors)
173183

174-
# Pick the next node based on the Roulette Wheel selection technique.
184+
# Pick the next node of the ant.
175185
next_node = self._pick_next_node(unvisited_neighbors, self.alpha, self.beta)
176186

187+
if not next_node:
188+
return
189+
177190
self.path.append(next_node)
178191
self.current_node = next_node

aco_routing/dijkstra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from dataclasses import dataclass
22
from typing import Dict, List, Tuple
33

4-
from aco_routing.utils.graph import Graph
4+
from aco_routing.graph import Graph
55

66

77
@dataclass
88
class Dijkstra:
9-
"""The basline Dijkstra's Algorithm to find the shortest path between 2 nodes in the graph.
9+
"""The basline Dijkstra's Algorithm to find the shortest path between two nodes in the graph.
1010
Reference: https://stackoverflow.com/a/61078380
1111
"""
1212

File renamed without changes.

aco_routing/utils/__init__.py

Whitespace-only changes.

aco_routing/utils/simulator.py

Lines changed: 0 additions & 100 deletions
This file was deleted.

aco_routing/example.py renamed to example.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from aco_routing.utils.graph import Graph
2-
from aco_routing.dijkstra import Dijkstra
3-
from aco_routing.utils.simulator import Simulator
4-
from aco_routing.aco import ACO
1+
from aco_routing import Graph, Dijkstra, ACO
52

63
graph = Graph()
74

@@ -23,10 +20,10 @@
2320
dijkstra = Dijkstra(graph)
2421
aco = ACO(graph)
2522

23+
aco_path, aco_cost = aco.find_shortest_path(
24+
source, destination, num_ants=100, max_iterations=50, cycles=100
25+
)
2626
dijkstra_path, dijkstra_cost = dijkstra.find_shortest_path(source, destination)
27-
aco_path, aco_cost = aco.find_shortest_path(source, destination)
2827

2928
print(f"ACO - path: {aco_path}, cost: {aco_cost}")
3029
print(f"Dijkstra - path: {dijkstra_path}, cost: {dijkstra_cost}")
31-
32-
Simulator(graph).simulate(source, destination, num_episodes=100, plot=True)

requirements.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
matplotlib
2-
tqdm

0 commit comments

Comments
 (0)