UnsupportedOperatorError: aten::scatter_reduce when include_self=False #126660

Open
LTsommer opened this issue May 20, 2024 · 6 comments
Labels
module: onnx Related to torch.onnx onnx-triaged triaged by ONNX team triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments


LTsommer commented May 20, 2024

🐛 Describe the bug

Hello, experts,
I am trying to export a .pt model to an .onnx model with torch.onnx.export, but an error occurred: "torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scatter_reduce' to ONNX opset version 11 is not supported.". I have tried opset versions 11 through 18, and none of them work. aten::scatter_reduce is not used explicitly in my code, so I guess it is called from inside another function.
The code is as follows:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing, max_pool, avg_pool
from torch_geometric.utils import add_self_loops, remove_self_loops

from mlp_layer import MLP
import onnx
import os


class GraphData(Data):
    """
    override key `cluster` indicating which polyline_id is for the vector
    """

    def __inc__(self, key, value, *args, **kwargs):
        if key == 'edge_index':
            return self.x.size(0)
        elif key == 'cluster':
            return int(self.cluster.max().item()) + 1
        else:
            return 0


class Model(nn.Module):

    def __init__(self, in_channels, num_subgraph_layers=3, hidden_unit=64):
        super().__init__()
        self.num_subgraph_layers = num_subgraph_layers
        self.hidden_unit = hidden_unit
        self.out_channels = hidden_unit

        self.layer_seq = nn.Sequential()
        for i in range(num_subgraph_layers):
            self.layer_seq.add_module(
                f'glp_{i}',
                MLP(in_channels, hidden_unit, hidden_unit)
            )
            in_channels = hidden_unit * 2

        self.linear = nn.Linear(hidden_unit * 2, hidden_unit)

    def forward(self, x, cluster, edge_index):
        data = GraphData(x=x,
                         cluster=cluster,
                         edge_index=edge_index)

        for name, layer in self.layer_seq.named_modules():
            if isinstance(layer, MLP):
                x = layer(x)
                data.x = x
                agg_data = max_pool(data.cluster.long(), data)

                x = torch.cat([x, agg_data.x[data.cluster.long()]], dim=-1)

        x = self.linear(x)
        data.x = x
        out = max_pool(data.cluster, data)
        x = out.x

        return F.normalize(x, p=2.0, dim=1)

The error information is as follows:
Traceback (most recent call last):
File "/Desktop/Studien/VN/core/model/layers/subgraph.py", line 141, in
torch.onnx.export(
File "/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/utils.py", line 506, in export
_export(
File "/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/utils.py", line 1548, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/utils.py", line 1117, in _model_to_graph
graph = _optimize_graph(
File "/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/utils.py", line 665, in _optimize_graph
graph = _C._jit_pass_onnx(graph, operator_export_type)
File "/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/utils.py", line 1901, in _run_symbolic_function
raise errors.UnsupportedOperatorError(
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scatter_reduce' to ONNX opset version 11 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
The packages installed in my virtual env:

name: torch
channels:
  - conda-forge
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  - defaults
dependencies:
  - absl-py=2.0.0=pyhd8ed1ab_0
  - aiohttp=3.8.5=py39h0f82c59_0
  - aiosignal=1.3.1=pyhd8ed1ab_0
  - appnope=0.1.3=pyhd8ed1ab_0
  - asttokens=2.2.1=pyhd8ed1ab_0
  - async-timeout=4.0.3=pyhd8ed1ab_0
  - attrs=23.1.0=pyh71513ae_1
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=pyhd8ed1ab_3
  - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
  - blinker=1.6.2=pyhd8ed1ab_0
  - brotli-python=1.1.0=py39hb198ff7_0
  - bzip2=1.0.8=h3422bc3_4
  - c-ares=1.19.1=hb547adb_0
  - ca-certificates=2023.7.22=hf0a4a13_0
  - cachetools=5.3.1=pyhd8ed1ab_0
  - cffi=1.15.1=py39he153c15_5
  - click=8.1.7=unix_pyh707e725_0
  - comm=0.1.3=pyhd8ed1ab_0
  - cryptography=41.0.4=py39had97604_0
  - debugpy=1.6.7=py39h23fbdae_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - executing=1.2.0=pyhd8ed1ab_0
  - frozenlist=1.4.0=py39h0f82c59_1
  - google-auth=2.23.1=pyhca7485f_0
  - google-auth-oauthlib=1.0.0=pyhd8ed1ab_1
  - grpcio=1.57.0=py39hbad4f83_1
  - importlib-metadata=6.8.0=pyha770c72_0
  - importlib_metadata=6.8.0=hd8ed1ab_0
  - jedi=0.18.2=pyhd8ed1ab_0
  - jupyter_client=8.3.0=pyhd8ed1ab_0
  - jupyter_core=5.3.1=py39h2804cbe_0
  - libabseil=20230802.1=cxx17_h13dd4ca_0
  - libblas=3.9.0=18_osxarm64_openblas
  - libcblas=3.9.0=18_osxarm64_openblas
  - libcxx=16.0.6=h4653b0c_0
  - libffi=3.4.2=h3422bc3_5
  - libgfortran=5.0.0=13_2_0_hd922786_1
  - libgfortran5=13.2.0=hf226fd6_1
  - libgrpc=1.57.0=hdbe17d8_1
  - liblapack=3.9.0=18_osxarm64_openblas
  - libopenblas=0.3.24=openmp_hd76b1f2_0
  - libprotobuf=4.23.4=hf590ac1_6
  - libsodium=1.0.18=h27ca646_1
  - libsqlite=3.40.0=h76d750c_0
  - libzlib=1.2.13=h03a7124_4
  - llvm-openmp=16.0.6=h1c12783_0
  - markdown=3.4.4=pyhd8ed1ab_0
  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
  - multidict=6.0.4=py39h02fc5c5_0
  - ncurses=6.3=h07bb92c_1
  - nest-asyncio=1.5.6=pyhd8ed1ab_0
  - oauthlib=3.2.2=pyhd8ed1ab_0
  - openssl=3.1.3=h53f4e23_0
  - packaging=23.1=pyhd8ed1ab_0
  - parso=0.8.3=pyhd8ed1ab_0
  - pexpect=4.8.0=pyh1a96a4e_2
  - pickleshare=0.7.5=py_1003
  - pip=23.0.1=pyhd8ed1ab_0
  - platformdirs=3.9.1=pyhd8ed1ab_0
  - prompt-toolkit=3.0.39=pyha770c72_0
  - prompt_toolkit=3.0.39=hd8ed1ab_0
  - psutil=5.9.5=py39h02fc5c5_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pyasn1=0.5.0=pyhd8ed1ab_0
  - pyasn1-modules=0.3.0=pyhd8ed1ab_0
  - pycparser=2.21=pyhd8ed1ab_0
  - pygments=2.15.1=pyhd8ed1ab_0
  - pyjwt=2.8.0=pyhd8ed1ab_0
  - pyopenssl=23.2.0=pyhd8ed1ab_1
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.9.16=hea58f1e_0_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.9=3_cp39
  - pyu2f=0.1.5=pyhd8ed1ab_0
  - pyzmq=25.1.0=py39h1e134f0_0
  - re2=2023.03.02=hc5e2d97_0
  - readline=8.1.2=h46ed386_0
  - requests-oauthlib=1.3.1=pyhd8ed1ab_0
  - rsa=4.9=pyhd8ed1ab_0
  - setuptools=67.6.0=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tensorboard=2.14.0=pyhd8ed1ab_0
  - tensorboard-data-server=0.7.0=py39had97604_1
  - tk=8.6.12=he1e0b03_0
  - tornado=6.3.2=py39h0f82c59_0
  - traitlets=5.9.0=pyhd8ed1ab_0
  - wcwidth=0.2.6=pyhd8ed1ab_0
  - werkzeug=2.3.7=pyhd8ed1ab_0
  - wheel=0.40.0=pyhd8ed1ab_0
  - xz=5.2.6=h57fd34a_0
  - yarl=1.9.2=py39h0f82c59_0
  - zeromq=4.3.4=hbdafb3b_1
  - zipp=3.16.2=pyhd8ed1ab_0
  - pip:
      - accelerate==0.22.0
      - antlr4-python3-runtime==4.8
      - anyio==3.7.1
      - anykeystore==0.2
      - apex==0.9.10.dev0
      - argon2-cffi==21.3.0
      - argon2-cffi-bindings==21.2.0
      - argoverse==1.1.0
      - arrow==1.2.3
      - ase==3.22.1
      - beautifulsoup4==4.12.2
      - bleach==6.0.0
      - certifi==2022.12.7
      - chardet==4.0.0
      - charset-normalizer==3.1.0
      - coloredlogs==15.0.1
      - colour==0.1.5
      - contourpy==1.1.0
      - cprint==1.2.2
      - cryptacular==1.6.2
      - cycler==0.11.0
      - cython==3.0.0
      - defusedxml==0.7.1
      - descartes==1.1.0
      - docopt==0.6.2
      - exceptiongroup==1.1.2
      - fastjsonschema==2.17.1
      - filelock==3.10.0
      - flatbuffers==23.5.26
      - fonttools==4.41.0
      - fqdn==1.5.1
      - fsspec==2024.5.0
      - fvcore==0.1.5.post20221221
      - googledrivedownloader==0.4
      - h5py==3.9.0
      - humanfriendly==10.0
      - hupper==1.12
      - hydra-core==1.1.0
      - idna==2.10
      - imageio==2.31.2
      - importlib-resources==6.0.0
      - iopath==0.1.10
      - ipykernel==6.24.0
      - ipython==8.12.3
      - ipython-genutils==0.2.0
      - ipywidgets==8.0.7
      - isodate==0.6.1
      - isoduration==20.11.0
      - jinja2==3.1.2
      - joblib==1.3.1
      - jsonpointer==2.4
      - jsonschema==4.18.4
      - jsonschema-specifications==2023.7.1
      - jupyter==1.0.0
      - jupyter-console==6.6.3
      - jupyter-events==0.6.3
      - jupyter-server==2.7.0
      - jupyter-server-terminals==0.4.4
      - jupyterlab-pygments==0.2.2
      - jupyterlab-widgets==3.0.8
      - kiwisolver==1.4.4
      - lapsolver==1.1.0
      - llvmlite==0.40.1
      - markupsafe==2.1.2
      - matplotlib==3.7.2
      - mistune==3.0.1
      - motmetrics==1.1.3
      - mpmath==1.3.0
      - mxnet==1.6.0
      - nbclassic==1.0.0
      - nbclient==0.8.0
      - nbconvert==7.16.4
      - nbformat==5.9.1
      - networkx==3.0
      - notebook==6.5.4
      - notebook-shim==0.2.3
      - numba==0.57.1
      - numdifftools==0.9.41
      - numpy==1.24.4
      - omegaconf==2.1.0
      - onnx==1.14.1
      - onnxruntime==1.16.0
      - opencv-python==4.8.0.76
      - overrides==7.3.1
      - pandas==2.0.3
      - pandocfilters==1.5.0
      - pastedeploy==3.1.0
      - pbkdf2==1.3
      - pillow==9.4.0
      - pipreqs==0.5.0
      - plaster==1.1.2
      - plaster-pastedeploy==1.0.1
      - plyfile==1.0.1
      - polars==0.18.15
      - portalocker==2.7.0
      - prometheus-client==0.17.1
      - protobuf==4.24.3
      - pyntcloud==0.3.1
      - pyparsing==3.0.9
      - pyramid==2.0.2
      - pyramid-mailer==0.15.1
      - python-graphviz==0.8.4
      - python-json-logger==2.0.7
      - python3-openid==3.2.0
      - pytz==2023.3
      - pywavelets==1.4.1
      - pyyaml==6.0.1
      - qtconsole==5.4.3
      - qtpy==2.3.1
      - rdflib==7.0.0
      - referencing==0.30.0
      - repoze-sendmail==4.4.1
      - requests==2.25.1
      - rfc3339-validator==0.1.4
      - rfc3986-validator==0.1.1
      - rpds-py==0.9.2
      - scikit-learn==1.3.0
      - scipy==1.11.1
      - seaborn==0.12.2
      - send2trash==1.8.2
      - shapely==2.0.1
      - sniffio==1.3.0
      - soupsieve==2.4.1
      - sqlalchemy==2.0.25
      - sympy==1.11.1
      - tabulate==0.9.0
      - termcolor==2.3.0
      - terminado==0.17.1
      - thop==0.1.1-2209072238
      - threadpoolctl==3.2.0
      - tinycss2==1.2.1
      - torch==2.0.0
      - torch-cluster==1.5.4
      - torch-geometric==2.4.0
      - torch-scatter==2.0.4
      - torch-sparse==0.6.18
      - torch-spline-conv==1.2.0
      - torch-summary==1.4.5
      - torchaudio==2.0.1
      - torchfile==0.1.0
      - torchkeras==3.9.3
      - torchvision==0.15.1
      - tqdm==4.65.0
      - transaction==4.0
      - translationstring==1.4
      - typing-extensions==4.11.0
      - tzdata==2023.3
      - uri-template==1.3.0
      - urllib3==1.26.15
      - velruse==1.1.1
      - venusian==3.1.0
      - webcolors==1.13
      - webencodings==0.5.1
      - webob==1.8.7
      - websocket-client==1.6.1
      - widgetsnbextension==4.0.8
      - wtforms==3.1.2
      - wtforms-recaptcha==0.3.2
      - yacs==0.1.8
      - yarg==0.1.9
      - zope-deprecation==5.0
      - zope-interface==6.1
      - zope-sqlalchemy==3.1

### Versions

wget https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py

cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
@LTsommer
Author

MLP is a fully connected layer, defined as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
import onnxruntime as ort


# MLP
class MLP(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        hidden=64,
        bias=True,
        activation="relu",
        norm="layer",
    ):
        super(MLP, self).__init__()

        # define the activation function
        if activation == "relu":
            act_layer = nn.ReLU
        elif activation == "relu6":
            act_layer = nn.ReLU6
        elif activation == "leaky":
            act_layer = nn.LeakyReLU
        elif activation == "prelu":
            act_layer = nn.PReLU
        else:
            raise NotImplementedError

        # define the normalization function
        if norm == "layer":
            norm_layer = nn.LayerNorm
        elif norm == "batch":
            norm_layer = nn.BatchNorm1d
        else:
            raise NotImplementedError

        # insert the layers
        self.linear1 = nn.Linear(in_channel, hidden, bias=bias)
        self.linear1.apply(self._init_weights)
        self.linear2 = nn.Linear(hidden, out_channel, bias=bias)
        self.linear2.apply(self._init_weights)

        self.norm1 = norm_layer(hidden)
        self.norm2 = norm_layer(out_channel)

        self.act1 = act_layer(inplace=True)
        self.act2 = act_layer(inplace=True)

        self.shortcut = None
        if in_channel != out_channel:
            self.shortcut = nn.Sequential(
                nn.Linear(in_channel, out_channel, bias=bias),
                norm_layer(out_channel)
            )
        # self.layers = nn.Sequential(
        #     self.linear1,
        #     self.norm1,
        #     self.act1,
        #     self.linear2,
        #     self.norm2,
        # )

    @staticmethod
    def _init_weights(m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:  # bias is None when bias=False
                m.bias.data.fill_(0.01)

    def forward(self, x):
        # print("\nMLP")
        out = self.linear1(x)
        out = self.norm1(out)
        out = self.act1(out)
        out = self.linear2(out)
        out = self.norm2(out)
        # print("\nx {} \nout {}".format(x.shape, out.shape))
        # out = self.layers(x)
        # print(self.shortcut)
        if self.shortcut is not None:
            out += self.shortcut(x)
        else:
            out += x
        # print("\nx {} \nout {}".format(x.shape, out.shape))
        return self.act2(out)
It can be exported to ONNX directly.

@tugsbayasgalan tugsbayasgalan added the module: onnx Related to torch.onnx label May 21, 2024
@titaiwangms
Collaborator

titaiwangms commented May 21, 2024

scatter_reduce should be supported from opset_version=16 onward:

@_onnx_symbolic("aten::scatter_reduce")

What do you get when you export it with 16? Please try a newer version of PyTorch. It looks like you are using torch==2.0.0, which is too old.

Opset 11 indeed does not support scatter_reduce.

@titaiwangms titaiwangms self-assigned this May 21, 2024
@LTsommer
Author

Thanks for the reply. Opset 16 doesn't work either. I will update torch to the newest version and try again.

@LTsommer
Author

Another error occurred:

/Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch_geometric/typing.py:110: UserWarning: An issue occurred while importing 'torch-sparse'. Disabling its usage. Stacktrace: dlopen(/Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch_sparse/_version_cpu.so, 0x0006): Symbol not found: __ZN3c1017RegisterOperatorsD1Ev
  Referenced from: <24CB9FE0-FFDC-3215-B7D7-A33CF4EE7F2D> /Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch_sparse/_version_cpu.so
  Expected in:     <43889F86-100F-3086-90C3-D4AE08235BA7> /Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib
  warnings.warn(f"An issue occurred while importing 'torch-sparse'. "
/Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch_geometric/typing.py:110: UserWarning: An issue occurred while importing 'torch-sparse'. Disabling its usage. Stacktrace: dlopen(/Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch_sparse/_version_cpu.so, 0x0006): Symbol not found: __ZN3c1017RegisterOperatorsD1Ev
  Referenced from: <24CB9FE0-FFDC-3215-B7D7-A33CF4EE7F2D> /Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch_sparse/_version_cpu.so
  Expected in:     <43889F86-100F-3086-90C3-D4AE08235BA7> /Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib
  warnings.warn(f"An issue occurred while importing 'torch-sparse'. "
/Users/liaotianzhihao/Desktop/Studien/VN/core/model/layers/global_graph.py:84: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if valid_lens.shape[0] != shape[0]:
/Users/liaotianzhihao/Desktop/Studien/VN/core/model/layers/global_graph.py:92: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for batch_id, cnt in enumerate(valid_len):
/Users/liaotianzhihao/Desktop/Studien/VN/core/model/layers/global_graph.py:93: TracerWarning: Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  cnt = int(cnt.detach().cpu().numpy())
Traceback (most recent call last):
  File "/Users/liaotianzhihao/Desktop/Studien/VN/core/generate_onnx.py", line 47, in <module>
    torch.onnx.export(
  File "/Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/utils.py", line 1612, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/utils.py", line 1138, in _model_to_graph
    graph = _optimize_graph(
  File "/Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/utils.py", line 1956, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
    return fn(g, *args, **kwargs)
  File "/Users/liaotianzhihao/opt/anaconda3/envs/torch/lib/python3.9/site-packages/torch/onnx/symbolic_opset16.py", line 137, in scatter_reduce
    raise errors.OnnxExporterError(
torch.onnx.errors.OnnxExporterError: ONNX does not support include_self=False for scatter_reduce

@LTsommer
Author

Uninstalling and reinstalling torch-sparse resolves the first two UserWarnings; I will try to figure out the TracerWarnings that follow. A newer PyTorch version does support the aten::scatter_reduce operator, but there is still an error.

@titaiwangms
Collaborator

Hi @LTsommer,

Unfortunately, include_self=False is not supported in the ONNX spec: onnx/onnx#5100. You would have to modify the model code to work around that.
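
One possible way to modify the model code, as a rough sketch rather than a confirmed fix: replace the pooling call with a per-cluster max that initializes the output to -inf, so include_self can stay at its default of True. The helper name segment_max and the fixed num_segments argument are assumptions made for this sketch; they are not part of torch_geometric. Only the eager-mode result is checked here, and whether the export itself succeeds will still depend on the PyTorch version and opset (as far as I know, ONNX ScatterElements only gained max/min reductions in opset 18).

```python
import torch


def segment_max(src, cluster, num_segments):
    """Per-cluster max over the rows of `src` (hypothetical helper)."""
    # scatter_reduce needs an index with the same shape as src.
    index = cluster.long().unsqueeze(-1).expand_as(src)
    # Start from -inf so the initial values can never win the max;
    # this lets include_self stay at its default of True.
    init = src.new_full((num_segments, src.size(-1)), float("-inf"))
    return init.scatter_reduce(0, index, src, reduce="amax", include_self=True)


x = torch.tensor([[1.0, -5.0], [3.0, -2.0], [-4.0, 0.5]])
cluster = torch.tensor([0, 0, 1])

pooled = segment_max(x, cluster, num_segments=2)

# Eager-mode reference using the flag that the exporter rejects.
index = cluster.unsqueeze(-1).expand_as(x)
reference = torch.zeros(2, 2).scatter_reduce(
    0, index, x, reduce="amax", include_self=False
)

print(pooled)
print(torch.equal(pooled, reference))
```

Note that a cluster with no members ends up as -inf here instead of keeping the initial value, and that num_segments has to be known up front; a value derived from cluster.max() would be baked in as a constant when tracing.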

@titaiwangms titaiwangms changed the title UnsupportedOperatorError: aten::scatter_reduce UnsupportedOperatorError: aten::scatter_reduce when include_self=False May 22, 2024
@titaiwangms titaiwangms removed their assignment May 22, 2024
@titaiwangms titaiwangms added onnx-triaged triaged by ONNX team and removed oncall: export labels May 23, 2024
@bdhirsh bdhirsh added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 23, 2024