Skip to content

Commit c5372a4

Browse files
author
Saurav Agarwal
committed
Add gnn backbone
1 parent 7babd79 commit c5372a4

File tree

5 files changed

+126
-8
lines changed

5 files changed

+126
-8
lines changed

cppsrc/tests/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,7 @@ add_executable(test_maps test_maps.cpp)
5555
target_link_libraries(test_maps PRIVATE CoverageControlTorch)
5656
install(TARGETS test_maps DESTINATION ${CMAKE_INSTALL_BINDIR})
5757

# Experimental data-loader test executable; links against the project's
# CoverageControlTorch library (which provides libtorch transitively).
add_executable(torch_data_loader torch_data_loader.cpp)
target_link_libraries(torch_data_loader PRIVATE CoverageControlTorch)
install(TARGETS torch_data_loader DESTINATION ${CMAKE_INSTALL_BINDIR})

cppsrc/tests/torch_data_loader.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
#include <torch/torch.h>

#include <iostream>
#include <utility>
3+
4+
/** Experimenting with torch data loaders
5+
* We have a torch tensor as the data
6+
* How do we load it into a torch data loader?
7+
*/
8+
9+
/// Minimal in-memory dataset wrapping a pair of pre-built tensors so they can
/// be fed through a torch::data data loader.
/// data_[i] is the i-th input sample; targets_[i] is its label/target row.
class TensorDataset : public torch::data::Dataset<TensorDataset> {
	private:
		torch::Tensor data_;
		torch::Tensor targets_;
	public:
		/// Takes the tensors by value and moves them into the members.
		/// Note: torch::Tensor copies are shallow (shared storage), so the
		/// dataset views the caller's data rather than duplicating it.
		TensorDataset(torch::Tensor data, torch::Tensor targets)
			: data_(std::move(data)), targets_(std::move(targets)) {}

		/// Returns the (sample, target) pair at the given index.
		torch::data::Example<> get(size_t index) override {
			return {data_[index], targets_[index]};
		}

		/// Dataset length: the size of the first (sample) dimension.
		torch::optional<size_t> size() const override {
			return data_.size(0);
		}
};
27+
28+
int main() {
29+
30+
int M = 10; // Dataset size
31+
int kBatchSize = 5;
32+
33+
torch::Tensor data = torch::rand({M,3,3}); // 2 channel image fo 3x3
34+
torch::Tensor targets = torch::rand({M, 3}); // 3 targets for each data
35+
std::cout << data << std::endl;
36+
std::cout << targets << std::endl;
37+
38+
auto dataset = TensorDataset(data, targets).map(torch::data::transforms::Stack<>());
39+
auto data_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
40+
std::move(dataset),
41+
torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2).enforce_ordering(false));
42+
43+
for (torch::data::Example<>& batch : *data_loader) {
44+
std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
45+
for (int64_t i = 0; i < batch.data.size(0); ++i) {
46+
std::cout << batch.target[i] << " ";
47+
}
48+
std::cout << std::endl;
49+
}
50+
// In a for loop you can now use your data.
51+
/* for (auto& batch : data_loader) { */
52+
/* auto data = batch.data; */
53+
/* auto labels = batch.target; */
54+
/* std::cout << "Batch data: " << data << std::endl; */
55+
/* std::cout << "Batch labels: " << labels << std::endl; */
56+
/* } */
57+
}

cppsrc/torch/include/CoverageControlTorch/train_cnn.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,10 @@ namespace CoverageControlTorch {
221221
}
222222
config_ = YAML::LoadFile(config_file);
223223
data_dir_ = config_["pDataDir"].as<std::string>();
224-
batch_size_ = config_["BatchSize"].as<size_t>();
225-
num_epochs_ = config_["NumEpochs"].as<size_t>();
226-
learning_rate_ = config_["LearningRate"].as<float>();
227-
weight_decay_ = config_["WeightDecay"].as<float>();
224+
batch_size_ = config_["CNNTraining"]["BatchSize"].as<size_t>();
225+
num_epochs_ = config_["CNNTraining"]["NumEpochs"].as<size_t>();
226+
learning_rate_ = config_["CNNTraining"]["LearningRate"].as<float>();
227+
weight_decay_ = config_["CNNTraining"]["WeightDecay"].as<float>();
228228

229229
cnn_config_ = config_["CNN"] ;
230230
image_size_ = cnn_config_["ImageSize"].as<int>();

params/learning_params.yaml

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,18 @@ pDataDir: "/root/CoverageControl_ws/data/pure_coverage/" # Absolute location
22

33
GPUs: [4, 5]
44

5-
LearningRate: 0.001
6-
WeightDecay: 0.0001
7-
BatchSize: 10
8-
NumEpochs: 50
5+
GNNBackBone:
6+
InputDim: 7
7+
NumHops: 3
8+
NumLayers: 4
9+
LatentSize: 64
10+
OutputDim: 2
11+
12+
GNNTraining:
13+
LearningRate: 0.001
14+
WeightDecay: 0.0001
15+
BatchSize: 10
16+
NumEpochs: 50
917

1018
CNN:
1119
InputDim: 4
@@ -14,3 +22,10 @@ CNN:
1422
LatentSize: 8
1523
KernelSize: 3
1624
ImageSize: 32
25+
26+
CNNTraining:
27+
LearningRate: 0.001
28+
WeightDecay: 0.0001
29+
BatchSize: 10
30+
NumEpochs: 50
31+

python/ts_jit/gnn_backbone.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import sys
2+
import torch
3+
import math
4+
from torch import nn
5+
from torch_geometric.nn import TAGConv
6+
import yaml
7+
8+
class GNNBackBone(nn.Module):
    """
    Multi-layer graph convolutional network (TAGConv) backbone with ReLU
    non-linearities, sized by hyperparameters from the input config.

    Layer widths are [input_dim, latent_size, latent_size, ...] so the first
    conv maps input_dim -> latent_size and the rest are latent_size -> latent_size.
    """
    def __init__(self, input_dim, num_layers, num_hops, latent_size):
        """
        Args:
            input_dim: number of input features per node.
            num_layers: number of TAGConv layers.
            num_hops: K, the number of propagation hops per TAGConv layer.
            latent_size: output width of every conv layer.
        """
        super().__init__()

        self.input_dim_ = input_dim
        self.num_layers_ = num_layers
        self.num_hops_ = num_hops
        self.latent_size_ = latent_size

        # f[i] -> f[i+1] gives the in/out channels of layer i.
        f = [self.latent_size_] * self.num_layers_
        f = [self.input_dim_] + f

        # .jittable() makes each TAGConv scriptable via torch.jit.script.
        self.graph_convs = nn.ModuleList()
        for layer in range(self.num_layers_):
            self.graph_convs.append(
                TAGConv(in_channels=f[layer], out_channels=f[layer + 1],
                        K=self.num_hops_).jittable())

    def forward(self, x, edge_index, edge_weight) -> torch.Tensor:
        """Apply each conv followed by ReLU and return the node embeddings.

        NOTE(review): ReLU is applied after the final layer as well, so the
        output is non-negative — confirm this is intended for the backbone.
        """
        for conv in self.graph_convs:
            x = conv(x, edge_index, edge_weight)
            x = torch.relu(x)
        return x
33+
34+
if __name__ == "__main__":
    # Usage: gnn_backbone.py <config.yaml> <output_script_file>
    # Reads the 'GNNBackBone' section of the YAML config, builds the model,
    # TorchScript-compiles it, and saves it for loading from C++.
    config_file = str(sys.argv[1])
    script_file = str(sys.argv[2])
    with open(config_file, 'r') as stream:
        config = yaml.safe_load(stream)['GNNBackBone']
    print(config)
    scripted_model = torch.jit.script(
        GNNBackBone(config['InputDim'], config['NumLayers'],
                    config['NumHops'], config['LatentSize']))
    scripted_model.save(script_file)

0 commit comments

Comments
 (0)