|
7 | 7 | plt.rc('text', usetex=True) |
8 | 8 | plt.rc('font', family='serif') |
9 | 9 |
|
10 | | -# Define the parameters |
11 | 10 | # eta_duals = [0.1, 1.0, 10.0, 100.0] |
12 | 11 | # T_0s = [25, 50, 75, 100] |
13 | 12 | eta_duals = [1] |
14 | 13 | T_0s = [25] |
15 | 14 | envs = list(range(100)) # Assuming this corresponds to the environment IDs |
16 | 15 | eval_dir = sys.argv[1] # Path to the directory containing the evaluation results |
17 | 16 |
|
18 | | -# Set fixed axis ranges |
19 | | -x_axis_range = (0, 3000) # Assuming 3000 columns as per your description |
20 | | -y_axis_range = (0, .3) # Adjust this based on the expected range of max values |
| 17 | +# x_axis_range = (0, 3000) |
| 18 | +y_axis_range = (0, .30) |
21 | 19 |
|
22 | | -# Initialize a figure for the grid of plots |
23 | 20 | fig, axes = plt.subplots(len(eta_duals), len(T_0s), figsize=(24, 24)) |
24 | 21 | # fig.suptitle('Max Objective Values for Different $\\eta$ and $T_0$', fontsize=24) |
25 | 22 |
|
26 | | -# Iterate over eta_duals and T_0s to read data and plot |
27 | 23 | for i, eta_dual in enumerate(eta_duals): |
28 | 24 | for j, T_0 in enumerate(T_0s): |
29 | | - # Initialize a list to store max values for each environment |
30 | 25 | all_max_values = [] |
31 | 26 | all_obj_values = [] |
32 | 27 |
|
33 | 28 | for env_id in envs: |
34 | | - # Construct the file path |
35 | 29 | file_path = f"{eval_dir}/{eta_dual}_{T_0}/obj_values_{env_id}.csv" |
36 | 30 |
|
37 | | - # Check if the file exists |
38 | 31 | if os.path.exists(file_path): |
39 | | - # Load the obj_values from the CSV file |
40 | 32 | obj_values = np.loadtxt(file_path, delimiter=",") |
41 | 33 | all_obj_values.append(obj_values) |
42 | 34 |
|
43 | | - # Compute the max value for each column |
44 | 35 | max_values = np.max(obj_values, axis=0) |
45 | 36 |
|
46 | | - # Append to the list of max values |
47 | 37 | all_max_values.append(max_values) |
48 | 38 |
|
49 | | - # Plot individual environment max values with light, thin lines |
50 | 39 | if len(eta_duals) > 1 and len(T_0s) > 1: |
51 | 40 | ax = axes[i, j] |
52 | 41 | elif len(eta_duals) > 1: |
|
57 | 46 | ax = axes |
58 | 47 | ax.plot(max_values, color='lightgray', linewidth=0.5) |
59 | 48 |
|
60 | | - # Plot the average max values across all environments |
61 | | - if all_max_values: |
62 | | - avg_max_values = np.mean(all_max_values, axis=0) |
63 | | - median_max_values = np.median(all_max_values, axis=0) |
64 | | - interquartile_range = np.percentile(all_max_values, 75, axis=0) - np.percentile(all_max_values, 25, axis=0) |
65 | | - ax.plot(avg_max_values, linewidth=2, label='Average') |
66 | | - ax.plot(median_max_values, color='red', linewidth=2, label='Median') |
67 | | - ax.fill_between(range(len(avg_max_values)), median_max_values - interquartile_range, median_max_values + interquartile_range, color='red', alpha=0.2, label='Interquartile Range') |
68 | | - ax.legend(fontsize=20) |
| 49 | + avg_max_values = np.mean(all_max_values, axis=0) |
| 50 | + print("Average max values at end: ", np.mean(avg_max_values)) |
| 51 | + median_max_values = np.median(all_max_values, axis=0) |
| 52 | + print("Median max values at end: ", median_max_values[-1]) |
| 53 | + interquartile_range = np.percentile(all_max_values, 75, axis=0) - np.percentile(all_max_values, 25, axis=0) |
| 54 | + ax.plot(avg_max_values, linewidth=2, label='Average') |
| 55 | + ax.plot(median_max_values, color='red', linewidth=2, label='Median') |
| 56 | + ax.fill_between(range(len(avg_max_values)), median_max_values - interquartile_range, median_max_values + interquartile_range, color='red', alpha=0.2, label='Interquartile Range') |
| 57 | + ax.legend(fontsize=20) |
69 | 58 |
|
70 | | - # Set plot title and axis ranges |
71 | | - # ax.set_xlim(x_axis_range) |
72 | | - ax.set_ylim(y_axis_range) |
73 | | - ax.set_title(rf'$\eta$: {eta_dual}, $T_0$: {T_0}', fontsize=32) |
| 59 | + ax.set_ylim(y_axis_range) |
| 60 | + ax.set_title(rf'$\eta$: {eta_dual}, $T_0$: {T_0}', fontsize=32) |
74 | 61 | ax.tick_params(axis='both', which='major', labelsize=30) |
75 | 62 |
|
76 | | -# Adjust layout to prevent overlap |
77 | 63 | plt.tight_layout() |
78 | 64 | plt.subplots_adjust(top=0.95) |
79 | | -# plt.show() |
80 | | -plt.savefig(sys.argv[2]) # Save the figure to a file |
| 65 | +plt.savefig(sys.argv[2]) |
0 commit comments