Skip to content

Commit 2bc0925

Browse files
author
Saurav Agarwal
committed
update plots
1 parent 0c570e3 commit 2bc0925

File tree

3 files changed

+111
-47
lines changed

3 files changed

+111
-47
lines changed

python/evaluators/eval.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def __init__(self, in_config):
5252
self.num_features = self.cc_params.pNumGaussianFeatures
5353
self.num_envs = self.config["NumEnvironments"]
5454
self.num_steps = self.config["NumSteps"]
55-
os.makedirs(self.env_dir + "/init_maps", exist_ok=True)
5655

5756
self.columns = [
5857
BarColumn(bar_width=None),
@@ -92,7 +91,6 @@ def evaluate(self, save=True):
9291
env_main.WriteEnvironment(pos_file, env_file)
9392
world_idf = env_main.GetWorldIDFObject()
9493

95-
# env_main.PlotInitMap(self.env_dir + "/init_maps", f"{env_count}")
9694
robot_init_pos = env_main.GetRobotPositions(force_no_noise=True)
9795

9896
for controller_id in range(self.num_controllers):

python/evaluators/eval_area_coverage.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def __init__(self, in_config):
5252
self.num_features = self.cc_params.pNumGaussianFeatures
5353
self.num_envs = self.config["NumEnvironments"]
5454
self.num_steps = self.config["NumSteps"]
55-
os.makedirs(self.env_dir + "/init_maps", exist_ok=True)
5655

5756
self.columns = [
5857
BarColumn(bar_width=None),
@@ -93,7 +92,6 @@ def evaluate(self, save=True):
9392
env_main.WriteEnvironment(pos_file, env_file)
9493
world_idf = env_main.GetWorldIDFObject()
9594

96-
# env_main.PlotInitMap(self.env_dir + "/init_maps", f"{env_count}")
9795
robot_init_pos = env_main.GetRobotPositions(force_no_noise=True)
9896

9997
for controller_id in range(self.num_controllers):

utils/scripts/plot_costs_time.py

Lines changed: 111 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,20 @@ def __init__(self, base_dir, csv_file_name="eval.csv"):
1717
self.csv_file_name = csv_file_name
1818
# Catppuccin colors
1919
self.colors = [
20-
'rgba(239, 159, 118, 1.0)', # Peach
21-
'rgba(166, 209, 137, 1.0)', # Green
22-
'rgba(202, 158, 230, 1.0)', # Mauve
23-
'rgba(133, 193, 220, 1.0)', # Sapphire
24-
'rgba(231, 130, 132, 1.0)', # Red
25-
'rgba(129, 200, 190, 1.0)', # Teal
26-
'rgba(242, 213, 207, 1.0)', # Rosewater
27-
'rgba(229, 200, 144, 1.0)', # Yellow
28-
'rgba(108, 111, 133, 1.0)', # subtext0
29-
]
20+
'rgba(239, 159, 118, 1.0)', # Peach
21+
'rgba(166, 209, 137, 1.0)', # Green
22+
'rgba(202, 158, 230, 1.0)', # Mauve
23+
'rgba(133, 193, 220, 1.0)', # Sapphire
24+
'rgba(231, 130, 132, 1.0)', # Red
25+
'rgba(129, 200, 190, 1.0)', # Teal
26+
'rgba(242, 213, 207, 1.0)', # Rosewater
27+
'rgba(229, 200, 144, 1.0)', # Yellow
28+
'rgba(108, 111, 133, 1.0)', # subtext0
29+
]
3030
self.num_controllers = len(self.controller_dirs)
3131
self.num_envs = 0
3232
self.num_steps = 0
33+
self.time_steps = None
3334
self.all_costs = None
3435
self.best_envs = None
3536

@@ -56,6 +57,7 @@ def load_and_normalize_costs(self):
5657
costs_dict[controller_dir] = costs
5758
self.num_envs = costs_dict[self.controller_dirs[0]].shape[0]
5859
self.num_steps = costs_dict[self.controller_dirs[0]].shape[1]
60+
self.time_steps = np.arange(self.num_steps)
5961
self.all_costs = np.zeros((self.num_controllers, self.num_envs, self.num_steps))
6062
for idx, controller_dir in enumerate(self.controller_dirs):
6163
self.all_costs[idx] = costs_dict[controller_dir]
@@ -75,75 +77,141 @@ def compute_best_num_envs(self, costs_dict):
7577

7678
def plot_costs(self, costs_dict):
7779
"""Plot the normalized costs over time for each controller."""
78-
fig = make_subplots(rows=3, cols=1, vertical_spacing=0.05, shared_xaxes=True, specs=[[{'rowspan': 2}], [{}], [{}]])
80+
fig = go.Figure()
7981
for idx, controller_dir in enumerate(self.controller_dirs):
8082
costs = costs_dict[controller_dir]
8183
mean_cost = np.mean(costs, axis=0)
8284
std_cost = np.std(costs, axis=0)
83-
time_steps = np.arange(costs.shape[1])
8485
color = self.colors[idx % len(self.colors)] # Cycle through colors
85-
86+
8687
# Shaded area for standard deviation
8788
fig.add_trace(go.Scatter(
88-
x=np.concatenate([time_steps, time_steps[::-1]]),
89+
x=np.concatenate([self.time_steps, self.time_steps[::-1]]),
8990
y=np.concatenate([mean_cost + std_cost, (mean_cost - std_cost)[::-1]]),
9091
fill="toself",
9192
fillcolor=color.replace('1.0', '0.2'),
9293
line=dict(color='rgba(255,255,255,0)'),
9394
legendgroup=controller_dir,
9495
showlegend=False,
95-
),
96-
row=1, col=1)
96+
visible=True,
97+
))
9798

9899
for idx, controller_dir in enumerate(self.controller_dirs):
99100
costs = costs_dict[controller_dir]
100101
mean_cost = np.mean(costs, axis=0)
101-
std_cost = np.std(costs, axis=0)
102-
time_steps = np.arange(costs.shape[1])
103102
color = self.colors[idx % len(self.colors)] # Cycle through colors
104103

105-
best_envs = self.best_envs[idx]
106-
107104
# Mean cost line
108105
fig.add_trace(go.Scatter(
109-
x=time_steps,
106+
x=self.time_steps,
110107
y=mean_cost,
111108
mode="lines",
112-
name="",
109+
name=controller_dir,
113110
line=dict(color=color),
114111
legendgroup=controller_dir,
115-
legendgrouptitle_text=controller_dir,
116-
),
117-
row=1, col=1)
112+
visible=True,
113+
))
118114

115+
for idx, controller_dir in enumerate(self.controller_dirs):
116+
best_envs = self.best_envs[idx]
117+
color = self.colors[idx % len(self.colors)] # Cycle through colors
119118
fig.add_trace(go.Scatter(
120-
x=time_steps,
119+
x=self.time_steps,
121120
y=best_envs,
122121
mode="lines",
123-
showlegend=False,
122+
showlegend=True,
123+
name=controller_dir,
124124
line=dict(color=color),
125125
legendgroup=controller_dir,
126-
),
127-
row=3, col=1)
128-
126+
visible=False,
127+
))
128+
129+
for idx, controller_dir in enumerate(self.controller_dirs):
130+
final_costs = costs_dict[controller_dir][:, -1]
131+
color = self.colors[idx % len(self.colors)]
132+
fig.add_trace(
133+
go.Violin(
134+
y=final_costs,
135+
name=controller_dir,
136+
line_color=color,
137+
box_visible=True,
138+
meanline_visible=True,
139+
showlegend=False,
140+
points="all",visible=False,),
141+
)
142+
143+
144+
costs_button = dict(label="Costs",
145+
method="update",
146+
args=[{"visible": self.visibility_masking([True, True, False, False])},
147+
{"xaxis.title.text": "Time Steps",
148+
"yaxis.title.text": "Normalized Cost",
149+
"xaxis.type": "linear"}])
150+
151+
costs_button_wo_std = dict(label="Costs (mean only)",
152+
method="update",
153+
args=[{"visible": self.visibility_masking([False, True, False, False])},
154+
{"xaxis.title.text": "Time Steps",
155+
"yaxis.title.text": "Normalized Cost",
156+
"xaxis.type": "linear"}])
157+
158+
best_envs_button = dict(label="Best Environments",
159+
method="update",
160+
args=[{"visible": self.visibility_masking([False, False, True, False])},
161+
{"xaxis.title.text": "Time Steps",
162+
"yaxis.title.text": "Number of Best Environments",
163+
"xaxis.type": "linear"}])
164+
165+
viols_button = dict(label="Violin Plots",
166+
method="update",
167+
args=[{"visible": self.visibility_masking([False, False, False, True])},
168+
{"xaxis.title.text": "",
169+
"yaxis.title.text": "Final Normalized Cost",
170+
"xaxis.type": "category"}])
129171

130-
# Update plot layout
131172
fig.update_layout(
132-
yaxis_title="Normalized cost",
133-
legend=dict(
134-
# orientation="h",
135-
# xanchor="right",
136-
x=1,
137-
y=1,
138-
bgcolor="rgba(255, 255, 255, 0.8)"
139-
140-
),
141-
yaxis3_title="Number of Best Environments",
142-
xaxis3_title="Time Step",
143-
)
173+
updatemenus=[
174+
dict(
175+
type="buttons",
176+
direction="left",
177+
buttons=[costs_button, costs_button_wo_std, best_envs_button, viols_button],
178+
showactive=True,
179+
x=0.01, y=1.05,
180+
xanchor='left',
181+
yanchor='top'
182+
)
183+
],
184+
legend=dict(x=1.01, y=1),
185+
autosize=True,
186+
template="plotly_white",
187+
xaxis=dict(
188+
title="Time Steps",
189+
showgrid=True, # Show vertical grid lines
190+
gridcolor='lightgray', # Color of grid lines
191+
gridwidth=1,
192+
mirror=True,
193+
ticks='outside',
194+
showline=True,
195+
# make axis lines thicker
196+
linecolor='black',
197+
),
198+
yaxis=dict(
199+
title="Normalized Cost",
200+
showgrid=True,
201+
gridcolor='lightgray',
202+
gridwidth=1,
203+
mirror=True,
204+
ticks='outside',
205+
showline=True,
206+
linecolor='black',
207+
),)
144208

145209
return fig
146210

211+
def visibility_masking(self, group_masks):
212+
return [val for group_visible in group_masks for val in [group_visible] * self.num_controllers]
213+
214+
147215
def run_analysis(self):
148216
"""Load data, generate plots, and output results."""
149217
costs_dict = self.load_and_normalize_costs()

0 commit comments

Comments
 (0)