[Mlir-commits] [mlir] [mlir][sparse][gpu] fix sparse GPU codegen out buffer (PR #189221)
Vito Secona
llvmlistbot at llvm.org
Wed Apr 1 21:14:50 PDT 2026
https://github.com/secona updated https://github.com/llvm/llvm-project/pull/189221
>From 673e31c1e3b7aea687691175a1a72d649ba76e43 Mon Sep 17 00:00:00 2001
From: Vito Secona <secona00 at gmail.com>
Date: Sat, 28 Mar 2026 23:40:55 +0700
Subject: [PATCH 1/2] fix sparse GPU codegen for out buffer
---
.../Transforms/SparseGPUCodegen.cpp | 35 ++++++++++++++-----
.../GPU/gpu_codegen_out_buffer.mlir | 35 +++++++++++++++++++
.../Dialect/SparseTensor/GPU/gpu_combi.mlir | 29 ++++++++++-----
.../Dialect/SparseTensor/GPU/gpu_matmul.mlir | 16 ++++++---
4 files changed, 94 insertions(+), 21 deletions(-)
create mode 100644 mlir/test/Dialect/SparseTensor/GPU/gpu_codegen_out_buffer.mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 0bd1d34c3504b..377a9191f0a11 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -27,6 +27,8 @@
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/Support/Casting.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
@@ -249,19 +251,20 @@ static void genParametersOut(OpBuilder &builder, Location loc, Value out,
Value kernelToken, SmallVectorImpl<Value> &scalars,
SmallVectorImpl<Value> &buffers,
SmallVectorImpl<Value> &args,
- SmallVectorImpl<Value> &tokens) {
+ SmallVectorImpl<Value> &tokens,
+ ArrayRef<bool> copyBack) {
unsigned base = scalars.size();
for (unsigned i = base, e = args.size(); i < e; i++) {
+ unsigned bufIdx = i - base;
Value firstToken;
- if (i == base) {
- // Assumed output parameter: unregister or copy-out.
- if (out) {
+ if (copyBack[bufIdx]) {
+ if (out && bufIdx == 0) {
genHostUnregisterMemref(builder, loc, out);
out = Value();
continue;
}
firstToken =
- genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
+ genCopyMemRef(builder, loc, buffers[bufIdx], args[i], kernelToken);
} else {
firstToken = genFirstWait(builder, loc);
}
@@ -1202,15 +1205,31 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
SmallVector<Value> constants;
SmallVector<Value> scalars;
SmallVector<Value> buffers;
+ SmallVector<bool> copyBack;
for (Value val : invariants) {
Type tp = val.getType();
if (val.getDefiningOp<arith::ConstantOp>())
constants.push_back(val);
else if (isa<FloatType>(tp) || tp.isIntOrIndex())
scalars.push_back(val);
- else if (isa<MemRefType>(tp))
+ else if (isa<MemRefType>(tp)) {
buffers.push_back(val);
- else
+
+ bool isWrite = false;
+ for (Operation *user : val.getUsers()) {
+ if (isa<memref::StoreOp>(user)) {
+ isWrite = true;
+ break;
+ }
+ if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(user)) {
+ if (memInterface.getEffectOnValue<MemoryEffects::Write>(val)) {
+ isWrite = true;
+ break;
+ }
+ }
+ }
+ copyBack.push_back(isWrite);
+ } else
return failure(); // don't know how to share
}
// Pass outlined non-constant values.
@@ -1239,7 +1258,7 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
// Finalize the outlined arguments.
genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
- tokens);
+ tokens, copyBack);
genBlockingWait(rewriter, loc, tokens);
rewriter.eraseOp(forallOp);
return success();
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_codegen_out_buffer.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_codegen_out_buffer.mlir
new file mode 100644
index 0000000000000..e4ee818b08799
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_codegen_out_buffer.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN: --pre-sparsification-rewrite \
+// RUN: --sparse-reinterpret-map \
+// RUN: --sparsification="parallelization-strategy=dense-outer-loop" \
+// RUN: --sparse-gpu-codegen | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+
+// CHECK-LABEL: func.func @tensor_add
+// CHECK: %[[TENSOR_EMPTY:.*]] = tensor.empty()
+// CHECK: %[[OUT_BUF:.*]] = bufferization.to_buffer %[[TENSOR_EMPTY]]
+// CHECK: %[[GPU_OUT_BUF:.*]], %[[T0:.*]] = gpu.alloc async [{{.*}}] ()
+// CHECK: gpu.memcpy async [%[[T0]]] %[[GPU_OUT_BUF]], %[[OUT_BUF]]
+// CHECK: %[[T1:.*]] = gpu.launch_func async @sparse_kernels::@kernel0 blocks
+// CHECK: %[[M0:.*]] = gpu.memcpy async [%[[T1]]] %[[OUT_BUF]], %[[GPU_OUT_BUF]]
+// CHECK: gpu.dealloc async [%[[M0]]] %[[GPU_OUT_BUF]]
+
+func.func @tensor_add(%arg0: tensor<32x32xf32, #CSR>,
+ %arg1: tensor<32x32xf32, #CSR>) -> tensor<32x32xf32> {
+ %empty = tensor.empty() : tensor<32x32xf32>
+ %res = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>
+ ],
+ iterator_types = ["parallel", "parallel"]
+ } ins(%arg0, %arg1 : tensor<32x32xf32, #CSR>, tensor<32x32xf32, #CSR>)
+ outs(%empty : tensor<32x32xf32>) {
+ ^bb0(%in1: f32, %in2: f32, %out: f32):
+ %sum = arith.addf %in1, %in2 : f32
+ linalg.yield %sum : f32
+ } -> tensor<32x32xf32>
+ return %res : tensor<32x32xf32>
+}
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
index b12bad685b49b..e2c341151ea13 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
@@ -11,43 +11,56 @@
// CHECK: gpu.func @kernel1
// CHECK: gpu.func @kernel0
//
-// CHECK-LABEL: func.func @matmuls
-// CHECK: gpu.alloc async
-// CHECK: gpu.memcpy async
+// CHECK-LABEL: func.func @matmuls(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1024x8xf64>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<8x1024xf64, #sparse>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<1024x1024xf64, #sparse>)
+// CHECK-SAME: -> tensor<1024x1024xf64> {
+// CHECK: %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
+// CHECK: %[[OUT_BUF0:.*]] = bufferization.to_buffer %[[ZERO]]
// CHECK: gpu.alloc async
// CHECK: gpu.memcpy async
// CHECK: gpu.alloc async
// CHECK: gpu.memcpy async
// CHECK: gpu.alloc async
// CHECK: gpu.memcpy async
+// CHECK: %[[GPU_OUT_BUF0:.*]], %[[T0:.*]] = gpu.alloc async
+// CHECK: gpu.memcpy async [%[[T0]]] %[[GPU_OUT_BUF0]], %[[OUT_BUF0]]
// CHECK: gpu.alloc async
// CHECK: gpu.memcpy async
// CHECK: %[[T1:.*]] = gpu.launch_func async @sparse_kernels::@kernel1 blocks
-// CHECK: gpu.memcpy async [%[[T1]]]
-// CHECK: gpu.dealloc async
// CHECK: gpu.dealloc async
// CHECK: gpu.dealloc async
// CHECK: gpu.dealloc async
+// CHECK: %[[T2:.*]] = gpu.memcpy async [%[[T1]]] %[[OUT_BUF0]], %[[GPU_OUT_BUF0]]
+// CHECK: gpu.dealloc async [%[[T2]]] %[[GPU_OUT_BUF0]]
// CHECK: gpu.dealloc async
// CHECK: gpu.wait
+// CHECK: %[[OUT_BUF1:.*]] = bufferization.to_buffer %[[ZERO]]
// CHECK: gpu.alloc async
// CHECK: gpu.memcpy async
// CHECK: gpu.alloc async
// CHECK: gpu.memcpy async
// CHECK: gpu.alloc async
// CHECK: gpu.memcpy async
-// CHECK: gpu.alloc async
-// CHECK: gpu.memcpy async
+// CHECK: %[[GPU_OUT_BUF1:.*]], %[[T4:.*]] = gpu.alloc async
+// CHECK: gpu.memcpy async [%[[T4]]] %[[GPU_OUT_BUF1]], %[[OUT_BUF1]]
// CHECK: gpu.alloc async
// CHECK: gpu.memcpy async
// CHECK: %[[T0:.*]] = gpu.launch_func async @sparse_kernels::@kernel0 blocks
// CHECK: gpu.memcpy async [%[[T0]]]
// CHECK: gpu.dealloc async
+// CHECK: gpu.wait async
// CHECK: gpu.dealloc async
+// CHECK: gpu.wait async
// CHECK: gpu.dealloc async
-// CHECK: gpu.dealloc async
+// CHECK: %[[T5:.*]] = gpu.memcpy async [%[[T0]]] %[[OUT_BUF1]], %[[GPU_OUT_BUF1]]
+// CHECK: gpu.dealloc async [%[[T5]]] %[[GPU_OUT_BUF1]]
+// CHECK: gpu.wait async
// CHECK: gpu.dealloc async
// CHECK: gpu.wait
+// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[OUT_BUF1]]
+// CHECK: return %[[OUT_TENSOR]]
//
func.func @matmuls(%A: tensor<1024x8xf64>,
%B: tensor<8x1024xf64, #CSR>,
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
index 2c236d48ae78e..c5498aec0f522 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
@@ -47,7 +47,11 @@
// CHECK: }
//
//
-// CHECK-LABEL: func.func @matmul
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf64, #sparse>,
+// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?xf64>,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<?x?xf64>) -> tensor<?x?xf64> {
+// CHECK: %[[OUT_BUF:.*]] = bufferization.to_buffer %[[ARG2]]
// CHECK: gpu.wait async
// CHECK: gpu.alloc async
// CHECK: %[[S0:.*]] = gpu.memcpy async
@@ -58,24 +62,26 @@
// CHECK: gpu.alloc async
// CHECK: %[[S2:.*]] = gpu.memcpy async
// CHECK: gpu.wait async
-// CHECK: gpu.alloc async
-// CHECK: %[[S3:.*]] = gpu.memcpy async
+// CHECK: %[[GPU_OUT_BUF:.*]], %[[T0:.*]] = gpu.alloc async
+// CHECK: %[[S3:.*]] = gpu.memcpy async [%[[T0]]] %[[GPU_OUT_BUF]], %[[OUT_BUF]]
// CHECK: gpu.wait async
// CHECK: gpu.alloc async
// CHECK: %[[S4:.*]] = gpu.memcpy async
// CHECK: gpu.wait [%[[S0]], %[[S1]], %[[S2]], %[[S3]], %[[S4]]
// CHECK: %[[T0:.*]] = gpu.launch_func async @sparse_kernels::@kernel0 blocks
-// CHECK: %[[M0:.*]] = gpu.memcpy async [%[[T0]]]
+// CHECK: %[[M0:.*]] = gpu.wait async
// CHECK: %[[M1:.*]] = gpu.dealloc async [%[[M0]]]
// CHECK: %[[M2:.*]] = gpu.wait async
// CHECK: %[[M3:.*]] = gpu.dealloc async [%[[M2]]]
// CHECK: %[[M4:.*]] = gpu.wait async
// CHECK: %[[M5:.*]] = gpu.dealloc async [%[[M4]]]
-// CHECK: %[[M6:.*]] = gpu.wait async
+// CHECK: %[[M6:.*]] = gpu.memcpy async [%[[T0]]] %[[OUT_BUF]], %[[GPU_OUT_BUF]]
// CHECK: %[[M7:.*]] = gpu.dealloc async [%[[M6]]]
// CHECK: %[[M8:.*]] = gpu.wait async
// CHECK: %[[M9:.*]] = gpu.dealloc async [%[[M8]]]
// CHECK: gpu.wait [%[[M1]], %[[M3]], %[[M5]], %[[M7]], %[[M9]]
+// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[OUT_BUF]]
+// CHECK: return %[[OUT_TENSOR]]
//
func.func @matmul(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xf64>, %C_in: tensor<?x?xf64>) -> tensor<?x?xf64> {
%C_out = linalg.matmul
>From 787889fb136199b38bc6e9a7adf508d4f186bdd3 Mon Sep 17 00:00:00 2001
From: Vito Secona <secona00 at gmail.com>
Date: Thu, 2 Apr 2026 10:45:20 +0700
Subject: [PATCH 2/2] document the changes to buffer copy back
---
.../SparseTensor/Transforms/SparseGPUCodegen.cpp | 15 +++++++++++++++
1 file changed, 15 insertions(+)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 377a9191f0a11..f15a52fa010a7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -247,6 +247,10 @@ static Value genParametersIn(OpBuilder &builder, Location loc,
/// Finalizes the outlined arguments. The output buffer is copied depending
/// on the kernel token and then deallocated. All other buffers are simply
/// deallocated. Then we wait for all operations to complete.
+///
+/// `copyBack` maps 1:1 to the `buffers` array. It tracks which buffers were
+/// mutated by the kernel and require a device-to-host copy. It must hold one
+/// entry per buffer, because it is indexed for every buffer in `args`.
static void genParametersOut(OpBuilder &builder, Location loc, Value out,
Value kernelToken, SmallVectorImpl<Value> &scalars,
SmallVectorImpl<Value> &buffers,
@@ -254,9 +258,15 @@ static void genParametersOut(OpBuilder &builder, Location loc, Value out,
SmallVectorImpl<Value> &tokens,
ArrayRef<bool> copyBack) {
unsigned base = scalars.size();
+
+ // `args` stores scalars followed by buffers. `base` is the index of the first
+ // buffer. `bufIdx` maps the current buffer to its exact 1:1 counterpart in
+ // the `copyBack` mask.
for (unsigned i = base, e = args.size(); i < e; i++) {
unsigned bufIdx = i - base;
Value firstToken;
+
+ // Checks if the current buffer needs a device-to-host copy.
if (copyBack[bufIdx]) {
if (out && bufIdx == 0) {
genHostUnregisterMemref(builder, loc, out);
@@ -1205,6 +1215,9 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
SmallVector<Value> constants;
SmallVector<Value> scalars;
SmallVector<Value> buffers;
+ // A boolean mask aligned 1:1 with the `buffers` array, tracking which
+ // of those buffers were mutated by the loop. If true, the corresponding
+ // buffer needs to be "copied back" using a device-to-host copy.
SmallVector<bool> copyBack;
for (Value val : invariants) {
Type tp = val.getType();
@@ -1215,6 +1228,8 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
else if (isa<MemRefType>(tp)) {
buffers.push_back(val);
+ // Determine if the buffer needs to be "copied back" from device
+ // to host by checking for `memref.store` and the write memory effect.
bool isWrite = false;
for (Operation *user : val.getUsers()) {
if (isa<memref::StoreOp>(user)) {
More information about the Mlir-commits mailing list