Skip to content

Commit 8ee545b

Browse files
author
Saurav Agarwal
committed
Clean up
1 parent d0d1402 commit 8ee545b

File tree

1 file changed

+14
-29
lines changed

1 file changed

+14
-29
lines changed

python/utils/plot_constrained_learning_results.py

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,46 +7,35 @@
77
plt.rc('text', usetex=True)
88
plt.rc('font', family='serif')
99

10-
# Define the parameters
1110
# eta_duals = [0.1, 1.0, 10.0, 100.0]
1211
# T_0s = [25, 50, 75, 100]
1312
eta_duals = [1]
1413
T_0s = [25]
1514
envs = list(range(100)) # Assuming this corresponds to the environment IDs
1615
eval_dir = sys.argv[1] # Path to the directory containing the evaluation results
1716

18-
# Set fixed axis ranges
19-
x_axis_range = (0, 3000) # Assuming 3000 columns as per your description
20-
y_axis_range = (0, .3) # Adjust this based on the expected range of max values
17+
# x_axis_range = (0, 3000)
18+
y_axis_range = (0, .30)
2119

22-
# Initialize a figure for the grid of plots
2320
fig, axes = plt.subplots(len(eta_duals), len(T_0s), figsize=(24, 24))
2421
# fig.suptitle('Max Objective Values for Different $\\eta$ and $T_0$', fontsize=24)
2522

26-
# Iterate over eta_duals and T_0s to read data and plot
2723
for i, eta_dual in enumerate(eta_duals):
2824
for j, T_0 in enumerate(T_0s):
29-
# Initialize a list to store max values for each environment
3025
all_max_values = []
3126
all_obj_values = []
3227

3328
for env_id in envs:
34-
# Construct the file path
3529
file_path = f"{eval_dir}/{eta_dual}_{T_0}/obj_values_{env_id}.csv"
3630

37-
# Check if the file exists
3831
if os.path.exists(file_path):
39-
# Load the obj_values from the CSV file
4032
obj_values = np.loadtxt(file_path, delimiter=",")
4133
all_obj_values.append(obj_values)
4234

43-
# Compute the max value for each column
4435
max_values = np.max(obj_values, axis=0)
4536

46-
# Append to the list of max values
4737
all_max_values.append(max_values)
4838

49-
# Plot individual environment max values with light, thin lines
5039
if len(eta_duals) > 1 and len(T_0s) > 1:
5140
ax = axes[i, j]
5241
elif len(eta_duals) > 1:
@@ -57,24 +46,20 @@
5746
ax = axes
5847
ax.plot(max_values, color='lightgray', linewidth=0.5)
5948

60-
# Plot the average max values across all environments
61-
if all_max_values:
62-
avg_max_values = np.mean(all_max_values, axis=0)
63-
median_max_values = np.median(all_max_values, axis=0)
64-
interquartile_range = np.percentile(all_max_values, 75, axis=0) - np.percentile(all_max_values, 25, axis=0)
65-
ax.plot(avg_max_values, linewidth=2, label='Average')
66-
ax.plot(median_max_values, color='red', linewidth=2, label='Median')
67-
ax.fill_between(range(len(avg_max_values)), median_max_values - interquartile_range, median_max_values + interquartile_range, color='red', alpha=0.2, label='Interquartile Range')
68-
ax.legend(fontsize=20)
49+
avg_max_values = np.mean(all_max_values, axis=0)
50+
print("Average max values at end: ", np.mean(avg_max_values))
51+
median_max_values = np.median(all_max_values, axis=0)
52+
print("Median max values at end: ", median_max_values[-1])
53+
interquartile_range = np.percentile(all_max_values, 75, axis=0) - np.percentile(all_max_values, 25, axis=0)
54+
ax.plot(avg_max_values, linewidth=2, label='Average')
55+
ax.plot(median_max_values, color='red', linewidth=2, label='Median')
56+
ax.fill_between(range(len(avg_max_values)), median_max_values - interquartile_range, median_max_values + interquartile_range, color='red', alpha=0.2, label='Interquartile Range')
57+
ax.legend(fontsize=20)
6958

70-
# Set plot title and axis ranges
71-
# ax.set_xlim(x_axis_range)
72-
ax.set_ylim(y_axis_range)
73-
ax.set_title(rf'$\eta$: {eta_dual}, $T_0$: {T_0}', fontsize=32)
59+
ax.set_ylim(y_axis_range)
60+
ax.set_title(rf'$\eta$: {eta_dual}, $T_0$: {T_0}', fontsize=32)
7461
ax.tick_params(axis='both', which='major', labelsize=30)
7562

76-
# Adjust layout to prevent overlap
7763
plt.tight_layout()
7864
plt.subplots_adjust(top=0.95)
79-
# plt.show()
80-
plt.savefig(sys.argv[2]) # Save the figure to a file
65+
plt.savefig(sys.argv[2])

0 commit comments

Comments
 (0)