@@ -17,19 +17,20 @@ def __init__(self, base_dir, csv_file_name="eval.csv"):
1717 self .csv_file_name = csv_file_name
1818 # Catppuccin colors
1919 self .colors = [
20- 'rgba(239, 159, 118, 1.0)' , # Peach
21- 'rgba(166, 209, 137, 1.0)' , # Green
22- 'rgba(202, 158, 230, 1.0)' , # Mauve
23- 'rgba(133, 193, 220, 1.0)' , # Sapphire
24- 'rgba(231, 130, 132, 1.0)' , # Red
25- 'rgba(129, 200, 190, 1.0)' , # Teal
26- 'rgba(242, 213, 207, 1.0)' , # Rosewater
27- 'rgba(229, 200, 144, 1.0)' , # Yellow
28- 'rgba(108, 111, 133, 1.0)' , # subtext0
29- ]
20+ 'rgba(239, 159, 118, 1.0)' , # Peach
21+ 'rgba(166, 209, 137, 1.0)' , # Green
22+ 'rgba(202, 158, 230, 1.0)' , # Mauve
23+ 'rgba(133, 193, 220, 1.0)' , # Sapphire
24+ 'rgba(231, 130, 132, 1.0)' , # Red
25+ 'rgba(129, 200, 190, 1.0)' , # Teal
26+ 'rgba(242, 213, 207, 1.0)' , # Rosewater
27+ 'rgba(229, 200, 144, 1.0)' , # Yellow
28+ 'rgba(108, 111, 133, 1.0)' , # subtext0
29+ ]
3030 self .num_controllers = len (self .controller_dirs )
3131 self .num_envs = 0
3232 self .num_steps = 0
33+ self .time_steps = None
3334 self .all_costs = None
3435 self .best_envs = None
3536
@@ -56,6 +57,7 @@ def load_and_normalize_costs(self):
5657 costs_dict [controller_dir ] = costs
5758 self .num_envs = costs_dict [self .controller_dirs [0 ]].shape [0 ]
5859 self .num_steps = costs_dict [self .controller_dirs [0 ]].shape [1 ]
60+ self .time_steps = np .arange (self .num_steps )
5961 self .all_costs = np .zeros ((self .num_controllers , self .num_envs , self .num_steps ))
6062 for idx , controller_dir in enumerate (self .controller_dirs ):
6163 self .all_costs [idx ] = costs_dict [controller_dir ]
@@ -75,75 +77,141 @@ def compute_best_num_envs(self, costs_dict):
7577
7678 def plot_costs (self , costs_dict ):
7779 """Plot the normalized costs over time for each controller."""
78- fig = make_subplots ( rows = 3 , cols = 1 , vertical_spacing = 0.05 , shared_xaxes = True , specs = [[{ 'rowspan' : 2 }], [{}], [{}]] )
80+ fig = go . Figure ( )
7981 for idx , controller_dir in enumerate (self .controller_dirs ):
8082 costs = costs_dict [controller_dir ]
8183 mean_cost = np .mean (costs , axis = 0 )
8284 std_cost = np .std (costs , axis = 0 )
83- time_steps = np .arange (costs .shape [1 ])
8485 color = self .colors [idx % len (self .colors )] # Cycle through colors
85-
86+
8687 # Shaded area for standard deviation
8788 fig .add_trace (go .Scatter (
88- x = np .concatenate ([time_steps , time_steps [::- 1 ]]),
89+ x = np .concatenate ([self . time_steps , self . time_steps [::- 1 ]]),
8990 y = np .concatenate ([mean_cost + std_cost , (mean_cost - std_cost )[::- 1 ]]),
9091 fill = "toself" ,
9192 fillcolor = color .replace ('1.0' , '0.2' ),
9293 line = dict (color = 'rgba(255,255,255,0)' ),
9394 legendgroup = controller_dir ,
9495 showlegend = False ,
95- ) ,
96- row = 1 , col = 1 )
96+ visible = True ,
97+ ) )
9798
9899 for idx , controller_dir in enumerate (self .controller_dirs ):
99100 costs = costs_dict [controller_dir ]
100101 mean_cost = np .mean (costs , axis = 0 )
101- std_cost = np .std (costs , axis = 0 )
102- time_steps = np .arange (costs .shape [1 ])
103102 color = self .colors [idx % len (self .colors )] # Cycle through colors
104103
105- best_envs = self .best_envs [idx ]
106-
107104 # Mean cost line
108105 fig .add_trace (go .Scatter (
109- x = time_steps ,
106+ x = self . time_steps ,
110107 y = mean_cost ,
111108 mode = "lines" ,
112- name = "" ,
109+ name = controller_dir ,
113110 line = dict (color = color ),
114111 legendgroup = controller_dir ,
115- legendgrouptitle_text = controller_dir ,
116- ),
117- row = 1 , col = 1 )
112+ visible = True ,
113+ ))
118114
115+ for idx , controller_dir in enumerate (self .controller_dirs ):
116+ best_envs = self .best_envs [idx ]
117+ color = self .colors [idx % len (self .colors )] # Cycle through colors
119118 fig .add_trace (go .Scatter (
120- x = time_steps ,
119+ x = self . time_steps ,
121120 y = best_envs ,
122121 mode = "lines" ,
123- showlegend = False ,
122+ showlegend = True ,
123+ name = controller_dir ,
124124 line = dict (color = color ),
125125 legendgroup = controller_dir ,
126- ),
127- row = 3 , col = 1 )
128-
126+ visible = False ,
127+ ))
128+
129+ for idx , controller_dir in enumerate (self .controller_dirs ):
130+ final_costs = costs_dict [controller_dir ][:, - 1 ]
131+ color = self .colors [idx % len (self .colors )]
132+ fig .add_trace (
133+ go .Violin (
134+ y = final_costs ,
135+ name = controller_dir ,
136+ line_color = color ,
137+ box_visible = True ,
138+ meanline_visible = True ,
139+ showlegend = False ,
140+ points = "all" ,visible = False ,),
141+ )
142+
143+
144+ costs_button = dict (label = "Costs" ,
145+ method = "update" ,
146+ args = [{"visible" : self .visibility_masking ([True , True , False , False ])},
147+ {"xaxis.title.text" : "Time Steps" ,
148+ "yaxis.title.text" : "Normalized Cost" ,
149+ "xaxis.type" : "linear" }])
150+
151+ costs_button_wo_std = dict (label = "Costs (mean only)" ,
152+ method = "update" ,
153+ args = [{"visible" : self .visibility_masking ([False , True , False , False ])},
154+ {"xaxis.title.text" : "Time Steps" ,
155+ "yaxis.title.text" : "Normalized Cost" ,
156+ "xaxis.type" : "linear" }])
157+
158+ best_envs_button = dict (label = "Best Environments" ,
159+ method = "update" ,
160+ args = [{"visible" : self .visibility_masking ([False , False , True , False ])},
161+ {"xaxis.title.text" : "Time Steps" ,
162+ "yaxis.title.text" : "Number of Best Environments" ,
163+ "xaxis.type" : "linear" }])
164+
165+ viols_button = dict (label = "Violin Plots" ,
166+ method = "update" ,
167+ args = [{"visible" : self .visibility_masking ([False , False , False , True ])},
168+ {"xaxis.title.text" : "" ,
169+ "yaxis.title.text" : "Final Normalized Cost" ,
170+ "xaxis.type" : "category" }])
129171
130- # Update plot layout
131172 fig .update_layout (
132- yaxis_title = "Normalized cost" ,
133- legend = dict (
134- # orientation="h",
135- # xanchor="right",
136- x = 1 ,
137- y = 1 ,
138- bgcolor = "rgba(255, 255, 255, 0.8)"
139-
140- ),
141- yaxis3_title = "Number of Best Environments" ,
142- xaxis3_title = "Time Step" ,
143- )
173+ updatemenus = [
174+ dict (
175+ type = "buttons" ,
176+ direction = "left" ,
177+ buttons = [costs_button , costs_button_wo_std , best_envs_button , viols_button ],
178+ showactive = True ,
179+ x = 0.01 , y = 1.05 ,
180+ xanchor = 'left' ,
181+ yanchor = 'top'
182+ )
183+ ],
184+ legend = dict (x = 1.01 , y = 1 ),
185+ autosize = True ,
186+ template = "plotly_white" ,
187+ xaxis = dict (
188+ title = "Time Steps" ,
189+ showgrid = True , # Show vertical grid lines
190+ gridcolor = 'lightgray' , # Color of grid lines
191+ gridwidth = 1 ,
192+ mirror = True ,
193+ ticks = 'outside' ,
194+ showline = True ,
195+ # make axis lines thicker
196+ linecolor = 'black' ,
197+ ),
198+ yaxis = dict (
199+ title = "Normalized Cost" ,
200+ showgrid = True ,
201+ gridcolor = 'lightgray' ,
202+ gridwidth = 1 ,
203+ mirror = True ,
204+ ticks = 'outside' ,
205+ showline = True ,
206+ linecolor = 'black' ,
207+ ),)
144208
145209 return fig
146210
211+ def visibility_masking (self , group_masks ):
212+ return [val for group_visible in group_masks for val in [group_visible ] * self .num_controllers ]
213+
214+
147215 def run_analysis (self ):
148216 """Load data, generate plots, and output results."""
149217 costs_dict = self .load_and_normalize_costs ()
0 commit comments