FlexGEMM is a high-performance, Triton-powered GEMM backend designed for 3D sparse convolutions.
It implements Explicit, Implicit, and Masked Implicit GEMM algorithm variants, with optional Split-K parallelism for the sparse GEMM. FlexGEMM delivers state-of-the-art performance for Submanifold Convolution and voxel-based neural networks, outperforming existing sparse convolution libraries in the benchmarks reported below.
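The sketch below makes the algorithm names concrete. It is a minimal pure-Python reference for a submanifold convolution written as a GEMM, not FlexGEMM's Triton code. The following are assumptions made for illustration:

- the coordinate format (plain (x, y, z) tuples without a batch index);
- the weight layout `weight[k][co][ci]`;
- all function names.

The three code paths differ as follows:

- **Explicit:** materializes an im2col-style gathered matrix, then multiplies.
- **Masked implicit:** gathers neighbors on the fly and skips kernel offsets that have no active voxel.
- **Split-K (`k_chunks > 1`):** mimics Split-K by splitting the reduction over kernel offsets into independent partial sums.

```python
# Reference sketch only (pure Python, not FlexGEMM's Triton kernels).
# Assumed, hypothetical data layout:
#   coords: list of (x, y, z) integer voxel coordinates, no batch index
#   feats:  one list of Ci floats per voxel
#   weight: weight[k][co][ci] for each of the K = Ks**3 kernel offsets
from itertools import product
import random


def build_neighbor_map(coords, ks=3):
    """nmap[i][k] = index of the voxel at coords[i] + offset_k, or -1 if inactive."""
    index = {c: i for i, c in enumerate(coords)}
    r = ks // 2
    offsets = list(product(range(-r, r + 1), repeat=3))
    # Submanifold: outputs live only on the already-active input sites.
    return [[index.get((x + dx, y + dy, z + dz), -1) for dx, dy, dz in offsets]
            for x, y, z in coords]


def conv_explicit(feats, weight, nmap):
    """Explicit GEMM: build an [N, K*Ci] gathered matrix, then multiply by [K*Ci, Co]."""
    ci, co = len(feats[0]), len(weight[0])
    gathered = []
    for row in nmap:
        cols = []
        for j in row:
            cols.extend(feats[j] if j >= 0 else [0.0] * ci)  # zero-fill missing neighbors
        gathered.append(cols)
    # Row index k*ci + c of w_flat matches column index k*ci + c of `gathered`.
    w_flat = [[weight[k][o][c] for o in range(co)]
              for k in range(len(weight)) for c in range(ci)]
    return [[sum(a * w_flat[r][o] for r, a in enumerate(cols)) for o in range(co)]
            for cols in gathered]


def conv_masked_implicit(feats, weight, nmap, k_chunks=1):
    """Implicit GEMM: gather on the fly and skip inactive neighbors (the mask).

    k_chunks > 1 mimics Split-K: the reduction over kernel offsets is split
    into independent partial sums that are added together at the end.
    """
    co, K = len(weight[0]), len(weight)
    bounds = [K * s // k_chunks for s in range(k_chunks + 1)]
    out = []
    for row in nmap:
        partials = []
        for s in range(k_chunks):  # each chunk could run in parallel
            acc = [0.0] * co
            for k in range(bounds[s], bounds[s + 1]):
                j = row[k]
                if j < 0:
                    continue  # masked out: no work for an absent neighbor
                for o in range(co):
                    acc[o] += sum(w * f for w, f in zip(weight[k][o], feats[j]))
            partials.append(acc)
        out.append([sum(p[o] for p in partials) for o in range(co)])
    return out


if __name__ == "__main__":
    random.seed(0)
    coords = [(0, 0, 0), (1, 0, 0), (1, 1, 0), (3, 3, 3)]
    Ci, Co, Ks = 2, 3, 3
    feats = [[random.uniform(-1, 1) for _ in range(Ci)] for _ in coords]
    weight = [[[random.uniform(-1, 1) for _ in range(Ci)] for _ in range(Co)]
              for _ in range(Ks ** 3)]
    nmap = build_neighbor_map(coords, Ks)
    a = conv_explicit(feats, weight, nmap)
    b = conv_masked_implicit(feats, weight, nmap, k_chunks=4)
    print("max abs difference:", max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb)))
```

Both paths compute the same result. The implicit form avoids storing the N × (K·Ci) gathered matrix, and the mask skips the zero-filled neighbor slots that the explicit form still multiplies.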
- Deep Dive: Read the technical blog at JeffreyXiang's Blog.
- Real-world Demo: See FlexGEMM in action in the TRELLIS.2 project.
- Triton-First Architecture: Built entirely on Triton, ensuring high-performance kernel execution and cross-platform compatibility.
- Sparse-Optimized: Specifically tailored for 3D sparse tensors, efficiently handling highly irregular sparsity patterns.
- Blazing Fast: Consistently outperforms standard sparse convolution libraries (such as spconv and torchsparse) in training throughput.
- PyTorch ≥ 2.4.0
- Triton ≥ 3.2.0
[WIP] BF16 precision support is under development on this branch.
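The sketch below checks whether an environment meets these minimums using only the standard library. `check_requirements` and its helpers are hypothetical names. Some caveats apply:

- The small parser handles release strings such as 2.4.1+cu121 and ignores pre-release tags. It does not implement full PEP 440 ordering.
- The distribution names torch and triton are assumed. Some platforms ship Triton under a different distribution name.

```python
# Hypothetical pre-flight check for the minimum versions listed above.
# Standard library only; not a full PEP 440 comparison (use packaging.version for that).
from importlib import metadata
import re

MINIMUMS = {"torch": "2.4.0", "triton": "3.2.0"}


def release_tuple(version):
    """'2.4.1+cu121' -> (2, 4, 1); stops at the first non-numeric component."""
    parts = []
    for piece in version.split("+")[0].split("."):
        m = re.match(r"\d+", piece)
        if not m:
            break  # e.g. 'dev20240101'
        parts.append(int(m.group()))
        if m.end() != len(piece):
            break  # e.g. '0rc1' -> keep 0, ignore the pre-release tag
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare release tuples, padding the shorter one with zeros."""
    a, b = release_tuple(installed), release_tuple(minimum)
    n = max(len(a), len(b))
    return a + (0,) * (n - len(a)) >= b + (0,) * (n - len(b))


def check_requirements(minimums=MINIMUMS):
    """Map each distribution name to (installed_version or None, ok)."""
    report = {}
    for pkg, minimum in minimums.items():
        try:
            installed = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            report[pkg] = (None, False)
            continue
        report[pkg] = (installed, meets_minimum(installed, minimum))
    return report


if __name__ == "__main__":
    for pkg, (ver, ok) in check_requirements().items():
        print(f"{pkg}: {ver or 'not installed'} -> {'OK' if ok else 'too old or missing'}")
```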
```bash
git clone https://github.com/JeffreyXiang/FlexGEMM.git
cd FlexGEMM
pip install .
```

The wheel is pure Python (py3-none-any). The CUDA sources ship under flex_gemm/kernels/cuda/, and the pybind extension is JIT-built on first use of a native op (see flex_gemm/kernels/_cuda_jit.py). The build backend is Hatchling. PyTorch is required only at runtime, not at build time.

For an editable install, use: pip install -e .
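The native extension is compiled the first time a native op runs, so a missing CUDA toolchain would only surface at that point. The sketch below is a hypothetical pre-flight check, not part of FlexGEMM. It assumes the JIT build needs nvcc, located via CUDA_HOME, CUDA_PATH, or PATH. That is typical of torch.utils.cpp_extension-style builds, but this README does not confirm it.

```python
# Hypothetical pre-flight check (not part of FlexGEMM).
# Assumption: the JIT build of the CUDA extension needs nvcc, found via
# CUDA_HOME / CUDA_PATH or on PATH, as with typical torch.utils.cpp_extension builds.
import os
import shutil

NVCC_NAME = "nvcc.exe" if os.name == "nt" else "nvcc"


def find_nvcc(env=None, which=shutil.which):
    """Return a path to nvcc, or None if no CUDA compiler can be found."""
    env = os.environ if env is None else env
    for var in ("CUDA_HOME", "CUDA_PATH"):
        root = env.get(var)
        if root:
            candidate = os.path.join(root, "bin", NVCC_NAME)
            if os.path.isfile(candidate):
                return candidate
    return which(NVCC_NAME)  # fall back to searching PATH


if __name__ == "__main__":
    nvcc = find_nvcc()
    print(nvcc or "nvcc not found (under the assumption above, the JIT build may fail)")
```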
Here is a minimal example demonstrating how to perform a sparse submanifold convolution using FlexGEMM:
```python
import torch
import flex_gemm
from flex_gemm.ops.spconv import sparse_submanifold_conv3d
# sphere_coords is a test helper: run this from the repository root so `tests` is importable
from tests.spconv_fwd import sphere_coords

# 1. Prepare Sparse Voxel Data
# Generate a sparse voxel shell
feats, coords, shape = sphere_coords(256, 256, dtype=torch.float16, device='cuda')

# 2. Define Weights and Bias
Ci, Co = 256, 256
Ks = 3
# Weight layout: (Co, Ks, Ks, Ks, Ci)
weight = torch.randn(Co, Ks, Ks, Ks, Ci, dtype=torch.float16, device='cuda', requires_grad=True)
bias = torch.randn(Co, dtype=torch.float16, device='cuda', requires_grad=True)

# 3. Configure Algorithm
# Example: Using Masked Implicit GEMM with Split-K optimization
flex_gemm.ops.spconv.set_algorithm(
    flex_gemm.ops.spconv.Algorithm.MASKED_IMPLICIT_GEMM_SPLITK
)

# 4. Forward Pass
# Returns the output features and the neighbor cache built for this geometry
out_feats, neighbor_cache = sparse_submanifold_conv3d(
    feats, coords, shape,
    weight, bias,
)

# 5. Backward Pass
out_feats.sum().backward()
```

FlexGEMM supports torch.compile via custom op wrappers. The key idea is to separate geometry preparation from computation:

1. Build the neighbor cache once from the geometry, outside the compiled region.
2. Freeze it into a SpConvConfig.
3. Use that config inside the compiled region.
```python
import torch
import flex_gemm
from flex_gemm.ops.spconv import sparse_submanifold_conv3d
from flex_gemm.ops.spconv.submanifold_conv3d import SubMConv3dFunction

Ci, Co, Ks = 256, 256, 3  # channel widths and kernel size, as in the previous example

# --- Phase 1: Preparation (outside torch.compile, run once) ---
feats, coords, shape = ...  # your sparse voxel data
weight = torch.randn(Co, Ks, Ks, Ks, Ci, device='cuda', requires_grad=True)
bias = torch.randn(Co, device='cuda', requires_grad=True)

# Build neighbor cache directly from geometry (no forward pass needed).
# Uses the default algorithm (MASKED_IMPLICIT_GEMM_SPLITK).
neighbor_cache = SubMConv3dFunction._compute_neighbor_cache(
    coords, shape, (Ks, Ks, Ks), (1, 1, 1),
)

# Freeze: pre-computes all block-size variants, returns a compile-friendly config
config = neighbor_cache.freeze()

# --- Phase 2: Compiled training loop ---
@torch.compile
def train_step(feats, weight, bias):
    # Pass config= to use the compiled path (returns output only, no cache)
    out = sparse_submanifold_conv3d(feats, weight=weight, bias=bias, config=config)
    return out.sum()

loss = train_step(feats, weight, bias)
loss.backward()
```

Note: The config= path is only needed for torch.compile. The legacy API (sparse_submanifold_conv3d(feats, coords, shape, weight, bias)) continues to work unchanged for eager execution.
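The preparation phase above runs once per geometry. If a workload reuses a limited set of sparsity patterns, one option is to memoize frozen configs keyed by geometry. The sketch below is a hypothetical helper, not part of FlexGEMM; `FrozenConfigCache` and its key scheme are assumptions. It has three caveats:

- Hashing coordinates costs a device-to-host copy on every lookup.
- It assumes `shape` is an iterable of ints.
- The README does not specify whether handing the compiled function a different config object triggers recompilation.

```python
# Hypothetical helper (not a FlexGEMM API): memoize one frozen config per
# distinct geometry, so preparation runs once per sparsity pattern rather
# than once per training step.
import hashlib


class FrozenConfigCache:
    def __init__(self):
        self._configs = {}
        self.builds = 0  # number of times a config was actually built

    @staticmethod
    def key(coords_bytes, shape, kernel_size):
        # Hash the raw coordinate bytes; assumes `shape` is an iterable of ints.
        return (hashlib.sha1(coords_bytes).hexdigest(), tuple(shape), tuple(kernel_size))

    def get(self, coords_bytes, shape, kernel_size, build):
        """Return the cached config for this geometry, calling build() on a miss."""
        k = self.key(coords_bytes, shape, kernel_size)
        if k not in self._configs:
            self._configs[k] = build()
            self.builds += 1
        return self._configs[k]


# Possible use with FlexGEMM, outside torch.compile (names from the example above):
# cache = FrozenConfigCache()
# config = cache.get(
#     coords.cpu().numpy().tobytes(), shape, (Ks, Ks, Ks),
#     lambda: SubMConv3dFunction._compute_neighbor_cache(
#         coords, shape, (Ks, Ks, Ks), (1, 1, 1)).freeze(),
# )
```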
FlexGEMM demonstrates significant speed improvements over existing baselines.
Test Environment:
- GPU: NVIDIA A100 80GB PCIe
- Software: PyTorch 2.4.1, CUDA 12.0, Triton 3.2.0
Note: FlexGEMM achieves roughly 2× acceleration over previous state-of-the-art methods when using efficient data formats such as FP16 and TF32.
- SOTA Speed: Consistently outperforms spconv, torchsparse, and fvdb.
- Scalability: Robust performance across various channel widths (C=64 to C=1024) and resolutions (RES=8 to RES=1024).
- Memory Efficient: Delivers higher throughput without increasing GPU memory overhead.
- Application Ready: Ideal for high-resolution voxelized point clouds, submanifold convolutions, and large-scale 3D networks.
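Comparing libraries on your own hardware comes down to two quantities: per-call latency and the work done per call. The sketch below is a hypothetical, framework-agnostic harness, not the one behind the numbers above. A few assumptions are built in:

- `benchmark` and `subm_conv_tflops` are names introduced here.
- The FLOP model (2·Ci·Co per active neighbor pair) is a common way to count work for a gather-GEMM sparse convolution, not a figure from the FlexGEMM benchmarks.
- Accurate GPU timing relies on passing a synchronize callback.

```python
# Minimal timing sketch, not the harness behind the numbers reported above.
import statistics
import time


def benchmark(fn, warmup=3, iters=10, sync=None):
    """Median wall-clock seconds per call of fn().

    For CUDA work pass sync=torch.cuda.synchronize so that kernels queued
    asynchronously are finished before each timer read.
    """
    for _ in range(warmup):  # warm-up absorbs one-time costs such as JIT builds
        fn()
    if sync is not None:
        sync()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def subm_conv_tflops(num_pairs, ci, co, seconds):
    """Effective TFLOP/s of a sparse conv forward pass.

    num_pairs counts active (output voxel, kernel offset) pairs, i.e. the valid
    neighbor-map entries; each costs ci * co multiply-adds (2 FLOPs each).
    """
    return 2.0 * num_pairs * ci * co / seconds / 1e12


if __name__ == "__main__":
    # Hypothetical numbers: 1e6 active pairs, C=256 in and out, 1 ms per call.
    print(subm_conv_tflops(1_000_000, 256, 256, 1e-3))  # ~131.07
```

With those hypothetical inputs, the count is 2 × 10^6 × 256 × 256 FLOPs in 10^-3 s, or about 131 TFLOP/s. Reporting the median rather than the mean keeps one slow outlier call from skewing the result.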
We welcome contributions to make FlexGEMM faster and more robust!
- Report Bugs: Open an issue describing the bug and how to reproduce it.
- Suggest Features: Have an idea for a new algorithm or optimization? Let us know!
- Submit Pull Requests:
  - Fork the repository and create your branch from main.
  - Ensure your code follows the project's style.
  - Run the tests in the tests/ directory to ensure no regressions.
  - Open a Pull Request with a detailed description.
We appreciate all contributors who help improve this project!
This project is released under the MIT License.


