-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_PS.py
More file actions
124 lines (94 loc) · 3.76 KB
/
train_PS.py
File metadata and controls
124 lines (94 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# code for parameter sampling
import env
import copy
import numpy as np
import torch
import random
from collections import deque
import time
import matplotlib.pyplot as plt
import os
import json
import pandas as pd
import agent_structure
epi = 1000  # number of episodes per training run
CFenv = env.ConnectFourEnv()  # create the Connect-Four environment
opAgent = agent_structure.HeuristicAgent()  # opponent agent (heuristic baseline)
optimization_trial = 10  # number of sampling attempts; NOTE(review): not referenced by the grid search below — confirm if still needed
def dict2json(data, filename='parameter sampling.json'):
    """Serialize *data* to JSON under the loss_plot/ directory.

    Bug fix: the original dumped the module-global ``results`` instead of
    the ``data`` argument, so the parameter was silently ignored.

    Args:
        data: any JSON-serializable object.
        filename: output file name inside ``loss_plot/``.
    """
    # Make sure the target directory exists so the open() below cannot fail
    # just because the script is run in a fresh working directory.
    os.makedirs('loss_plot', exist_ok=True)
    with open('loss_plot/'+filename, 'w') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
# Find the first unused experiment folder index and create
# loss_plot/experiment_<num> for this run's artifacts.
num = 1
while os.path.exists("loss_plot/experiment_{}".format(num)):
    num += 1
folder_path = "loss_plot/experiment_{}".format(num)
os.makedirs(folder_path)
print(folder_path+" 에 폴더를 만들었습니다.")
batch_size=64  # minibatch size (held fixed during this search)
memory_len = 2000  # replay-buffer capacity (held fixed)
repeat_reward = 1  # reward repetition factor (held fixed)
# The two dimensions actually searched: learning rate x target-network update interval.
lrs = [0.01, 0.005, 0.001, 0.0005, 0.0001, 0.00005, 0.00001]
target_updates = [50, 100, 500, 1000]
i=0  # running experiment index, used in plot filenames and result keys
# Hyperparameter search ======================================
results = {}  # maps a tuple of "name:value" strings -> win rate vs. the heuristic agent
# Train one agent per (learning rate, target-update interval) pair and
# record its win rate against the heuristic opponent.
for i, (lr, target_update) in enumerate(
        (l, t) for l in lrs for t in target_updates):
    CFenv.reset()
    # Ranges previously used for random hyperparameter sampling =====
    # lr = 10 ** np.random.uniform(-5, -2)
    # batch_size = int(2 ** np.random.uniform(4, 10))
    # target_update = int(10 ** np.random.uniform(0, 4))
    # memory_len = int(2 ** np.random.uniform(11, 15))
    # repeat_reward = int(10 ** np.random.uniform(0,1))
    # ================================================
    Qagent = agent_structure.ConnectFourDQNAgent(
        lr=lr, batch_size=batch_size, target_update=target_update,
        memory_len=memory_len, repeat_reward=repeat_reward)
    Qagent.train(epi=epi, env=CFenv, op_model=opAgent)

    # Save this configuration's training-loss curve.
    plt.clf()
    plt.plot(Qagent.losses)
    plt.savefig('loss_plot/experiment_{}/train_PS_loss_{}.png'.format(num,i))

    # Evaluate: 100 head-to-head games against the heuristic agent.
    win, loss, draw = env.compare_model(Qagent, opAgent, n_battle=100)
    win_rate = win / (win + loss + draw)

    settings_key = (
        "order:"+str(i),
        "lr:"+str(lr),
        "batch_size:"+str(batch_size),
        "target_update:"+str(target_update),
        "memory_len:"+str(memory_len),
        "repeat_reward:"+str(repeat_reward)
    )
    results[settings_key] = win_rate

    # Echo the configuration and its score to the console.
    for label, value in (("lr", lr),
                         ("batch_size", batch_size),
                         ("target_update", target_update),
                         ("memory_len", memory_len),
                         ("repeat_reward", repeat_reward),
                         ("win_rate", win_rate)):
        print(label + ": " + str(value))
# Rank the configurations by win rate (best first) and persist the ranking
# as JSON inside this run's experiment folder.
ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
results = ranked
print(results)
filename = 'parameter sampling.json'
out_path = 'loss_plot/experiment_{}/'.format(num) + filename
with open(out_path, 'w') as f:
    json.dump(results, f, indent=4, ensure_ascii=False)
# Inlined from json2excel.py so it doesn't have to be run separately:
# convert the ranking JSON into an Excel summary sheet.

# Read the ranking file written above.
with open('loss_plot/experiment_{}/'.format(num)+filename) as f:
    json_data = json.load(f)

# Each entry is [["order:0", "lr:0.01", ...], win_rate]; strip the
# "label:" prefixes and keep the numeric values.
data_list = []
for block in json_data:
    data = []
    for ele in block[0]:
        raw = ele.split(':')[1]
        try:
            val = int(raw)
        except ValueError:
            # Non-integer fields (e.g. lr = 0.01) fall back to float.
            # Fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and masked real parsing errors.
            val = float(raw)
        data.append(val)
    data.append(float(block[1]))
    data_list.append(data)

# Tabulate and export the summary.
df = pd.DataFrame(data_list, columns=['order', 'lr', 'batch_size',
                                      'target_update', 'memory_len',
                                      'repeat_reward', 'win_rate'])
df.to_excel('loss_plot/experiment_{}/summary.xlsx'.format(num))
print(df)