-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathptToJson.py
More file actions
22 lines (16 loc) · 742 Bytes
/
ptToJson.py
File metadata and controls
22 lines (16 loc) · 742 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# MACOS
# pip3 install torch torchvision torchaudio
import torch
import json
# Export every tensor stored in a PyTorch checkpoint to a JSON file so the
# values can be consumed without PyTorch.
pt_path = './Resources/output/weights/mnist_fc128_relu_fc10_softmax'
json_path = pt_path + '_weights_biases.json'

# NOTE(security): torch.load unpickles the file and can therefore execute
# arbitrary code; only run this script on checkpoints from a trusted source.
# map_location remaps all storages to the CPU so the script also runs on
# machines without CUDA. On a CUDA machine the argument may be omitted.
# Assumes the .pt file holds a name -> tensor mapping (a state_dict), not a
# pickled model object -- TODO confirm against the training script.
state_dict = torch.load(pt_path + '.pt', map_location=torch.device('cpu'))

# Convert each tensor to nested Python lists, which json can serialize.
# detach() comes first because Tensor.numpy() raises RuntimeError on a tensor
# that still requires grad.
weights_biases = {
    name: param.detach().numpy().tolist()
    for name, param in state_dict.items()
}

# Write the converted values next to the checkpoint.
with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(weights_biases, f, indent=4)

print(f"Weights and biases successfully saved to {json_path}")