def get_relu_coefs(self, x):
    print(x.shape)
    # axis?
    theta = torch.mean(x, dim=-1)
    if self.conv_type == '2d':
        # axis?
        theta = torch.mean(theta, dim=-1)
    theta = self.fc1(theta)
    theta = self.relu(theta)
    theta = self.fc2(theta)
    theta = 2 * self.sigmoid(theta) - 1
    return theta
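To show how the shapes flow through this method, here is a minimal runnable sketch. The method only references `self.conv_type`, `self.fc1`, `self.relu`, `self.fc2` and `self.sigmoid`. The wrapper class `CoefHead`, its layer sizes (`channels -> channels // reduction -> 2 * k`) and the `(N, C, H, W)` input layout are all hypothetical assumptions, not taken from the original code.

```python
import torch
import torch.nn as nn


class CoefHead(nn.Module):
    """Hypothetical module holding the layers get_relu_coefs uses.

    Layer sizes (channels -> channels // reduction -> 2 * k) are assumptions.
    """

    def __init__(self, channels, reduction=4, k=2, conv_type='2d'):
        super().__init__()
        self.conv_type = conv_type
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(channels // reduction, 2 * k)
        self.sigmoid = nn.Sigmoid()

    def get_relu_coefs(self, x):
        # Assumed input: (N, C, L) for '1d' or (N, C, H, W) for '2d'
        theta = torch.mean(x, dim=-1)          # average over the last axis
        if self.conv_type == '2d':
            theta = torch.mean(theta, dim=-1)  # average over the other spatial axis -> (N, C)
        theta = self.fc1(theta)                # (N, C // reduction)
        theta = self.relu(theta)
        theta = self.fc2(theta)                # (N, 2 * k)
        theta = 2 * self.sigmoid(theta) - 1    # rescale sigmoid output to the range -1..1
        return theta


head = CoefHead(channels=16)
out = head.get_relu_coefs(torch.randn(8, 16, 32, 32))
print(out.shape)  # torch.Size([8, 4])
```

Under these assumptions, both `dim=-1` reductions remove the spatial axes, so `fc1` receives one value per channel.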
My torch version is 1.1.0.
torch.mean(input, dim, out=None) → Tensor
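A quick way to check what the two `dim=-1` calls do is to compare them on a random tensor. This sketch assumes an `(N, C, H, W)` layout. It shows that `dim` can be passed positionally or by keyword, and that reducing the last axis twice gives the same result as averaging over both spatial axes at once. The tuple form `dim=(-2, -1)` is included only for comparison, and whether 1.1.0 accepts a tuple there is an assumption worth checking.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 4, 5)  # assumed (N, C, H, W) layout

# Positional and keyword forms of dim are the same call.
a = torch.mean(x, -1)
b = torch.mean(x, dim=-1)

# Two successive reductions over the last axis (W, then H) ...
twice = torch.mean(torch.mean(x, dim=-1), dim=-1)
# ... match a single reduction over both spatial axes.
once = torch.mean(x, dim=(-2, -1))

print(a.shape, twice.shape)  # torch.Size([2, 3, 4]) torch.Size([2, 3])
print(torch.equal(a, b), torch.allclose(twice, once, atol=1e-6))  # True True
```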