
quat2unit modifies tensors in-place #371

@lahavlipson


🐛 Describe the bug

The function pypose.quat2unit modifies its input tensor in place. For example, the following script:

import pypose

a = pypose.randn_SE3(requires_grad=True)
b = pypose.quat2unit(a)
b.tensor().mean().backward()

raises the following error inside the quat2unit call itself, before backward() is reached:

Traceback (most recent call last):
  File "/root/test.py", line 5, in <module>
    b = pypose.quat2unit(a)
  File "/miniforge/envs/pypose/lib/python3.10/site-packages/pypose/lietensor/convert.py", line 855, in quat2unit
    data[..., 3:7] = normalize(data[..., 3:7], p=2, dim=-1, eps=eps)
RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
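The same error can be reproduced with plain PyTorch, without pypose. The sketch below assumes a batch of 7-element SE3 tensors with the quaternion in columns 3:7, as the data[..., 3:7] slice in the traceback suggests; the tensor shape and variable names are illustrative only.

```python
import torch

# A leaf tensor that requires grad, laid out like a batch of SE3 tensors:
# 3 translation components followed by 4 quaternion components.
pose = torch.randn(2, 7, requires_grad=True)

try:
    # Assigning into a slice writes into a view of the leaf tensor,
    # which is the pattern used on the line shown in the traceback.
    pose[..., 3:7] = torch.nn.functional.normalize(pose[..., 3:7], p=2, dim=-1)
    error = None
except RuntimeError as e:
    error = str(e)

print(error)
```

Autograd refuses this write because the leaf's original values would be lost, so the failure comes from the slice assignment and not from anything pypose-specific.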

A potential fix is to change these lines so that a new tensor is built instead of writing into a view of the input:

if input.ltype in [SO3_type, RxSO3_type]:
    quat, *scale = torch.split(data, [4, input.ltype.dimension[0] - 4], -1)
    quat = normalize(quat, p=2, dim=-1, eps=eps)
    data = torch.concatenate((quat, *scale), -1)
elif input.ltype in [SE3_type, Sim3_type]:
    t, quat, *scale = torch.split(data, [3, 4, input.ltype.dimension[0] - 7], -1)
    quat = normalize(quat, p=2, dim=-1, eps=eps)
    data = torch.concatenate((t, quat, *scale), -1)
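The idea behind the proposed fix can be checked in isolation. The sketch below is a hypothetical helper, normalize_se3_quat, that works on a raw tensor rather than a LieTensor and handles only the SE3 layout (translation in columns 0:3, quaternion in columns 3:7); it is not pypose's API. It verifies that the input is left untouched, that the quaternion part comes out with unit norm, and that gradients reach the leaf.

```python
import torch
import torch.nn.functional as F


def normalize_se3_quat(data: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical helper (not part of pypose): return a new tensor whose
    quaternion part (columns 3:7) has unit L2 norm, leaving `data` untouched."""
    t, quat = torch.split(data, [3, 4], dim=-1)      # views into `data`
    quat = F.normalize(quat, p=2, dim=-1, eps=eps)   # allocates a new tensor
    return torch.cat((t, quat), dim=-1)              # allocates a new tensor


pose = torch.randn(2, 7, requires_grad=True)
before = pose.detach().clone()

unit = normalize_se3_quat(pose)
unit.mean().backward()

print(torch.equal(pose.detach(), before))  # the input was not modified
print(unit[..., 3:7].norm(dim=-1))         # quaternion norms after normalization
print(pose.grad)
```

Because torch.split only returns views and both F.normalize and torch.cat allocate fresh tensors, nothing is ever written into the leaf or into a tensor autograd saved for backward.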

Versions

PyTorch version: 2.4.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.2 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: version 3.30.5
Libc version: N/A

Python version: 3.10.15 (main, Sep  9 2024, 22:43:48) [Clang 18.1.8 ] (64-bit runtime)
Python platform: macOS-15.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-boto3-s3==1.35.32
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] torch==2.4.1
[pip3] torch_scatter==2.1.2
[pip3] torch_sparse==0.6.18
[pip3] torch-tb-profiler==0.4.3
[pip3] torchmetrics==1.5.1
[pip3] torchvision==0.19.1
[pip3] torchviz==0.0.2
[conda] numpy                     1.23.5                   pypi_0    pypi
[conda] numpydoc                  1.7.0           py310hca03da5_0
[conda] torch                     2.3.1                    pypi_0    pypi
[conda] torch-tb-profiler         0.4.3                    pypi_0    pypi
[conda] torchvision               0.18.1                   pypi_0    pypi
