Disable __torch_function__ wrapping if __torch_dispatch__ is defined #73942
ezyang wants to merge 2 commits into gh/ezyang/1096/base
Conversation
This is never what you want, and has led us to have to spray `__torch_function__ = _disabled_torch_function_impl` everywhere. Fortunately, the fix is simple. You might still want to override `__torch_function__` for performance reasons, since you can bypass running the Python code associated with the default `__torch_function__`. It might be a good performance optimization to bypass this by default when we detect both the default `__torch_function__` and a `__torch_dispatch__`, but that is left for future work.

Signed-off-by: Edward Z. Yang <[email protected]>
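To illustrate the rule this PR implements, here is a minimal pure-Python sketch; the names (`BaseTensor`, `should_wrap_torch_function`, and so on) are hypothetical stand-ins, not PyTorch's actual internals. The idea: if a subclass defines its own `__torch_dispatch__`, skip `__torch_function__` wrapping entirely, so the subclass no longer needs the `__torch_function__ = _disabled_torch_function_impl` workaround.

```python
# Hypothetical sketch of the dispatch rule; all names here are
# illustrative, not PyTorch's real implementation.

class BaseTensor:
    """Stand-in for torch.Tensor with the default protocol hook."""
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

def has_custom_torch_dispatch(t):
    # The base class deliberately does not define __torch_dispatch__,
    # so finding the attribute means a subclass supplied it.
    return getattr(type(t), "__torch_dispatch__", None) is not None

def should_wrap_torch_function(t):
    # The rule from this PR: defining __torch_dispatch__ disables
    # __torch_function__ wrapping, so subclasses no longer need
    #   __torch_function__ = _disabled_torch_function_impl
    return not has_custom_torch_dispatch(t)

class DispatchTensor(BaseTensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        return func(*args, **(kwargs or {}))

class PlainTensor(BaseTensor):
    pass

print(should_wrap_torch_function(PlainTensor()))     # True
print(should_wrap_torch_function(DispatchTensor()))  # False
```

The "future work" mentioned above would go one step further: even when `__torch_function__` is present, detect that it is the *default* implementation and bypass it for speed.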
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
bdhirsh left a comment:
LGTM! I was actually thinking of something like updating `check_has_torch_function()` to return false by default when `__torch_dispatch__` is defined (and maybe renaming it to `check_should_call_torch_function()`), but this seems cleaner (though more expensive).
It also doesn't work in all cases, because the `DisableTorchFunction` in the default implementation means that I cannot call FX proxies anymore (as they work on torch function). Need #55093
If `__torch_function__` was disabled, this TLS should propagate to other threads. Although I was thinking about #73942 when I did this, it doesn't actually help solve that problem, because when I disable `__torch_function__` as part of the disabled `__torch_function__` implementation, that happens prior to when snapshotting occurs (also, snapshotting only happens for Python tensors anyway). I intend to add some more TLS to this struct soon, which is why it's a struct and not just a bool. Testing is not easy because on CPU there isn't a straightforward way to get Python code running in another thread.

Signed-off-by: Edward Z. Yang <[email protected]>
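The snapshot-and-restore pattern described above can be sketched in pure Python with `threading.local`. The `ThreadLocalState` struct here is illustrative only, not PyTorch's actual C++ `ThreadLocalState`; it mirrors the design note that it is a struct (not just a bool) so more fields can be added later.

```python
# Hypothetical sketch: snapshot thread-local "torch function disabled"
# state on the launching thread and restore it on the worker thread.
import threading

_tls = threading.local()

def torch_function_disabled():
    # Worker threads start with fresh TLS, hence the False default.
    return getattr(_tls, "disabled", False)

class ThreadLocalState:
    """Snapshot of the TLS. A struct rather than a bare bool so that
    additional thread-local fields can be added later."""
    def __init__(self):
        self.disabled = torch_function_disabled()

    def apply(self):
        _tls.disabled = self.disabled

def with_propagated_tls(fn):
    state = ThreadLocalState()      # snapshot on the launching thread
    def wrapper(*args, **kwargs):
        state.apply()               # restore on the worker thread
        return fn(*args, **kwargs)
    return wrapper

_tls.disabled = True
result = []
t = threading.Thread(
    target=with_propagated_tls(lambda: result.append(torch_function_disabled()))
)
t.start()
t.join()
print(result)  # [True] -- the disabled flag crossed the thread boundary
```

Without the `with_propagated_tls` wrapper, the worker thread would see the default `False`, since `threading.local` state is per-thread.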
Pull Request resolved: #75110
Approved by: https://github.com/albanD
Reviewed By: atalman
Differential Revision: D35359970
cc @ezyang, anything blocking this?
yeah, the PR doesn't work because reentrant torch function is broken T.T
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
Stack from ghstack (oldest at bottom):
Signed-off-by: Edward Z. Yang [email protected]
Differential Revision: D34730883