
wy17646051/tensorrt_scatter


TensorRT Scatter

TensorRT plugins for the corresponding PyTorch Scatter operators.


At present, the project has only been tested with TensorRT 8.5.x and CUDA 11.6. Other versions may work, but use them with caution.

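| Supported operator                      | TensorRT version | CUDA version |
|-----------------------------------------|------------------|--------------|
| scatter (sum, add, mean, mul, min, max) | 8.5.x            | 11.6         |
| segment_coo (sum, add, mean, min, max)  | 8.5.x            | 11.6         |
| gather_coo                              | 8.5.x            | 11.6         |
| segment_csr (sum, add, mean, min, max)  | 8.5.x            | 11.6         |
| gather_csr                              | 8.5.x            | 11.6         |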
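To make the operator semantics in the table concrete, here is a minimal pure-Python reference sketch of the 1-D case (reduction along dimension 0 of a flat list). It is not the plugins' CUDA implementation and is not part of this repository: the function names only mirror the operator names, `coo_to_indptr` is a hypothetical helper, and the fill value for empty buckets (1 for `mul`, 0 otherwise) is an assumption. A sketch like this can be handy for checking plugin outputs on small inputs.

```python
from typing import List, Sequence

REDUCTIONS = ("sum", "add", "mean", "mul", "min", "max")


def _reduce(values: List[float], reduce: str) -> float:
    """Reduce one bucket/segment of values with the given reduction."""
    if reduce not in REDUCTIONS:
        raise ValueError(f"unsupported reduce: {reduce}")
    if not values:
        # Assumption: empty buckets yield 1 for "mul" and 0 otherwise.
        return 1.0 if reduce == "mul" else 0.0
    if reduce in ("sum", "add"):
        return sum(values)
    if reduce == "mean":
        return sum(values) / len(values)
    if reduce == "mul":
        out = 1.0
        for v in values:
            out *= v
        return out
    return min(values) if reduce == "min" else max(values)


def scatter(src: Sequence[float], index: Sequence[int], dim_size: int,
            reduce: str = "sum") -> List[float]:
    """out[i] = reduce(src[j] for all j with index[j] == i); index may be unsorted."""
    buckets: List[List[float]] = [[] for _ in range(dim_size)]
    for value, i in zip(src, index):
        buckets[i].append(value)
    return [_reduce(b, reduce) for b in buckets]


def segment_coo(src: Sequence[float], index: Sequence[int], dim_size: int,
                reduce: str = "sum") -> List[float]:
    """Same result as scatter, but index must be sorted (non-decreasing)."""
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("segment_coo requires a sorted index")
    return scatter(src, index, dim_size, reduce)


def gather_coo(src: Sequence[float], index: Sequence[int]) -> List[float]:
    """out[j] = src[index[j]]: broadcast per-segment values back to elements."""
    return [src[i] for i in index]


def coo_to_indptr(index: Sequence[int], dim_size: int) -> List[int]:
    """Hypothetical helper: turn a sorted COO index into CSR pointers."""
    counts = [0] * dim_size
    for i in index:
        counts[i] += 1
    indptr = [0]
    for c in counts:
        indptr.append(indptr[-1] + c)
    return indptr


def segment_csr(src: Sequence[float], indptr: Sequence[int],
                reduce: str = "sum") -> List[float]:
    """Segment k covers src[indptr[k]:indptr[k + 1]]."""
    return [_reduce(list(src[indptr[k]:indptr[k + 1]]), reduce)
            for k in range(len(indptr) - 1)]


def gather_csr(src: Sequence[float], indptr: Sequence[int]) -> List[float]:
    """Repeat src[k] once for every element of segment k."""
    out: List[float] = []
    for k in range(len(indptr) - 1):
        out.extend([src[k]] * (indptr[k + 1] - indptr[k]))
    return out
```

For example, with `src = [1.0, 2.0, 3.0, 4.0, 5.0]` and `index = [0, 0, 1, 2, 2]`, `scatter(src, index, 3, "sum")` and `segment_csr(src, coo_to_indptr(index, 3), "sum")` both give `[3.0, 3.0, 9.0]`. The CSR form replaces the per-element index with segment boundaries, which is why `segment_csr`/`gather_csr` need a sorted layout while `scatter` does not.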

Installation

Before installing the project, make sure your CUDA environment matches the support table above and that you have downloaded TensorRT.

TensorRT Plugin

Build the project with CMake as follows:

```shell
mkdir build && cd build
cmake .. -DTENSORRT_PREFIX_PATH="/The/TensorRT/path/you/downloaded" && make
```

PyTorch Symbolic

The project also provides, in example/script/symbolic.py, the symbolic functions that PyTorch needs in order to export the pytorch_scatter operators to ONNX.

Make sure to register the symbolic functions before calling torch.onnx.export to export the ONNX model, e.g.:

```python
from example.script.symbolic import register_symbolic
register_symbolic(op_name=None, opset_version=9)
```

Example

The project provides some simple example models built with PyTorch, along with some test data. The test data is a 3D point cloud of shape [N, 5]: the first three columns are 3D coordinates and the last two are point attributes. The model and data-loading logic are implemented in example/script/model.py.
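As an illustration of how scatter operators typically apply to this kind of data, the sketch below groups [N, 5] points into voxels and averages the two attribute columns per voxel, i.e. a scatter mean keyed by voxel id. This is a hypothetical example, not the logic in example/script/model.py; the function name `voxel_mean` and the `voxel_size` parameter are assumptions made for illustration.

```python
import math
from typing import Dict, List, Sequence, Tuple


def voxel_mean(points: Sequence[Sequence[float]],
               voxel_size: float = 1.0) -> Tuple[List[int], List[List[float]]]:
    """Hypothetical helper: scatter-mean the two attribute columns of
    [N, 5] points (x, y, z, attr0, attr1) into voxels of edge voxel_size.

    Returns (inverse, means): inverse[n] is the voxel id of point n, and
    means[v] holds the averaged [attr0, attr1] of voxel v.
    """
    key_to_id: Dict[Tuple[int, ...], int] = {}
    inverse: List[int] = []
    for p in points:
        # Integer voxel coordinates from the first three columns.
        key = tuple(math.floor(c / voxel_size) for c in p[:3])
        if key not in key_to_id:
            key_to_id[key] = len(key_to_id)  # ids in first-appearance order
        inverse.append(key_to_id[key])

    num_voxels = len(key_to_id)
    sums = [[0.0, 0.0] for _ in range(num_voxels)]
    counts = [0] * num_voxels
    # Scatter-sum the attributes and the per-voxel point counts.
    for p, v in zip(points, inverse):
        sums[v][0] += p[3]
        sums[v][1] += p[4]
        counts[v] += 1
    # Mean = sum / count; every voxel has at least one point by construction.
    means = [[s[0] / c, s[1] / c] for s, c in zip(sums, counts)]
    return inverse, means
```

Here `inverse` plays the role of the `index` argument of a scatter operator. Voxel ids are assigned in first-appearance order; a PyTorch model would more likely obtain them from `torch.unique(..., return_inverse=True)`, which orders ids by sorted key instead.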

In addition, example/script/export.py provides PyTorch -> ONNX and ONNX -> TensorRT conversion scripts for the example models, which can be run as follows:

```shell
export PYTHONPATH="example"
python example/script/export.py --model scatter_example --trt --onnx
```
