[Mlir-commits] [mlir] [mlir][sparse][gpu] fix sparse GPU codegen out buffer (PR #189221)

Vito Secona llvmlistbot at llvm.org
Wed Apr 1 21:14:50 PDT 2026


https://github.com/secona updated https://github.com/llvm/llvm-project/pull/189221

>From 673e31c1e3b7aea687691175a1a72d649ba76e43 Mon Sep 17 00:00:00 2001
From: Vito Secona <secona00 at gmail.com>
Date: Sat, 28 Mar 2026 23:40:55 +0700
Subject: [PATCH 1/2] fix sparse GPU codegen for out buffer

---
 .../Transforms/SparseGPUCodegen.cpp           | 35 ++++++++++++++-----
 .../GPU/gpu_codegen_out_buffer.mlir           | 35 +++++++++++++++++++
 .../Dialect/SparseTensor/GPU/gpu_combi.mlir   | 29 ++++++++++-----
 .../Dialect/SparseTensor/GPU/gpu_matmul.mlir  | 16 ++++++---
 4 files changed, 94 insertions(+), 21 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/GPU/gpu_codegen_out_buffer.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 0bd1d34c3504b..377a9191f0a11 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -27,6 +27,8 @@
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/Support/Casting.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
@@ -249,19 +251,20 @@ static void genParametersOut(OpBuilder &builder, Location loc, Value out,
                              Value kernelToken, SmallVectorImpl<Value> &scalars,
                              SmallVectorImpl<Value> &buffers,
                              SmallVectorImpl<Value> &args,
-                             SmallVectorImpl<Value> &tokens) {
+                             SmallVectorImpl<Value> &tokens,
+                             ArrayRef<bool> copyBack) {
   unsigned base = scalars.size();
   for (unsigned i = base, e = args.size(); i < e; i++) {
+    unsigned bufIdx = i - base;
     Value firstToken;
-    if (i == base) {
-      // Assumed output parameter: unregister or copy-out.
-      if (out) {
+    if (copyBack[bufIdx]) {
+      if (out && bufIdx == 0) {
         genHostUnregisterMemref(builder, loc, out);
         out = Value();
         continue;
       }
       firstToken =
-          genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
+          genCopyMemRef(builder, loc, buffers[bufIdx], args[i], kernelToken);
     } else {
       firstToken = genFirstWait(builder, loc);
     }
@@ -1202,15 +1205,31 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
     SmallVector<Value> constants;
     SmallVector<Value> scalars;
     SmallVector<Value> buffers;
+    SmallVector<bool> copyBack;
     for (Value val : invariants) {
       Type tp = val.getType();
       if (val.getDefiningOp<arith::ConstantOp>())
         constants.push_back(val);
       else if (isa<FloatType>(tp) || tp.isIntOrIndex())
         scalars.push_back(val);
-      else if (isa<MemRefType>(tp))
+      else if (isa<MemRefType>(tp)) {
         buffers.push_back(val);
-      else
+
+        bool isWrite = false;
+        for (Operation *user : val.getUsers()) {
+          if (isa<memref::StoreOp>(user)) {
+            isWrite = true;
+            break;
+          }
+          if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(user)) {
+            if (memInterface.getEffectOnValue<MemoryEffects::Write>(val)) {
+              isWrite = true;
+              break;
+            }
+          }
+        }
+        copyBack.push_back(isWrite);
+      } else
         return failure(); // don't know how to share
     }
     // Pass outlined non-constant values.
@@ -1239,7 +1258,7 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
         genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
     // Finalize the outlined arguments.
     genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
-                     tokens);
+                     tokens, copyBack);
     genBlockingWait(rewriter, loc, tokens);
     rewriter.eraseOp(forallOp);
     return success();
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_codegen_out_buffer.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_codegen_out_buffer.mlir
new file mode 100644
index 0000000000000..e4ee818b08799
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_codegen_out_buffer.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --pre-sparsification-rewrite \
+// RUN:             --sparse-reinterpret-map \
+// RUN:             --sparsification="parallelization-strategy=dense-outer-loop" \
+// RUN:             --sparse-gpu-codegen | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+
+// CHECK-LABEL: func.func @tensor_add
+// CHECK:         %[[TENSOR_EMPTY:.*]] = tensor.empty()
+// CHECK:         %[[OUT_BUF:.*]] = bufferization.to_buffer %[[TENSOR_EMPTY]]
+// CHECK:         %[[GPU_OUT_BUF:.*]], %[[T0:.*]] = gpu.alloc async [{{.*}}] ()
+// CHECK:         gpu.memcpy async [%[[T0]]] %[[GPU_OUT_BUF]], %[[OUT_BUF]]
+// CHECK:         %[[T1:.*]] = gpu.launch_func async @sparse_kernels::@kernel0 blocks
+// CHECK:         %[[M0:.*]] = gpu.memcpy async [%[[T1]]] %[[OUT_BUF]], %[[GPU_OUT_BUF]]
+// CHECK:         gpu.dealloc async [%[[M0]]] %[[GPU_OUT_BUF]]
+
+func.func @tensor_add(%arg0: tensor<32x32xf32, #CSR>,
+                      %arg1: tensor<32x32xf32, #CSR>) -> tensor<32x32xf32> {
+  %empty = tensor.empty() : tensor<32x32xf32>
+  %res = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>,
+      affine_map<(d0, d1) -> (d0, d1)>
+    ],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%arg0, %arg1 : tensor<32x32xf32, #CSR>, tensor<32x32xf32, #CSR>)
+    outs(%empty : tensor<32x32xf32>) {
+  ^bb0(%in1: f32, %in2: f32, %out: f32):
+    %sum = arith.addf %in1, %in2 : f32
+    linalg.yield %sum : f32
+  } -> tensor<32x32xf32>
+  return %res : tensor<32x32xf32>
+}
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
index b12bad685b49b..e2c341151ea13 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
@@ -11,43 +11,56 @@
 // CHECK:       gpu.func @kernel1
 // CHECK:       gpu.func @kernel0
 //
-// CHECK-LABEL: func.func @matmuls
-// CHECK:       gpu.alloc async
-// CHECK:       gpu.memcpy async
+// CHECK-LABEL: func.func @matmuls(
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<1024x8xf64>,
+// CHECK-SAME:    %[[ARG1:.*]]: tensor<8x1024xf64, #sparse>,
+// CHECK-SAME:    %[[ARG2:.*]]: tensor<1024x1024xf64, #sparse>)
+// CHECK-SAME:    -> tensor<1024x1024xf64> {
+// CHECK:       %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
+// CHECK:       %[[OUT_BUF0:.*]] = bufferization.to_buffer %[[ZERO]]
 // CHECK:       gpu.alloc async
 // CHECK:       gpu.memcpy async
 // CHECK:       gpu.alloc async
 // CHECK:       gpu.memcpy async
 // CHECK:       gpu.alloc async
 // CHECK:       gpu.memcpy async
+// CHECK:       %[[GPU_OUT_BUF0:.*]], %[[T0:.*]] = gpu.alloc async
+// CHECK:       gpu.memcpy async [%[[T0]]] %[[GPU_OUT_BUF0]], %[[OUT_BUF0]]
 // CHECK:       gpu.alloc async
 // CHECK:       gpu.memcpy async
 // CHECK:       %[[T1:.*]] = gpu.launch_func async @sparse_kernels::@kernel1 blocks
-// CHECK:       gpu.memcpy async [%[[T1]]]
-// CHECK:       gpu.dealloc async
 // CHECK:       gpu.dealloc async
 // CHECK:       gpu.dealloc async
 // CHECK:       gpu.dealloc async
+// CHECK:       %[[T2:.*]] = gpu.memcpy async [%[[T1]]] %[[OUT_BUF0]], %[[GPU_OUT_BUF0]]
+// CHECK:       gpu.dealloc async [%[[T2]]] %[[GPU_OUT_BUF0]]
 // CHECK:       gpu.dealloc async
 // CHECK:       gpu.wait
+// CHECK:       %[[OUT_BUF1:.*]] = bufferization.to_buffer %[[ZERO]]
 // CHECK:       gpu.alloc async
 // CHECK:       gpu.memcpy async
 // CHECK:       gpu.alloc async
 // CHECK:       gpu.memcpy async
 // CHECK:       gpu.alloc async
 // CHECK:       gpu.memcpy async
-// CHECK:       gpu.alloc async
-// CHECK:       gpu.memcpy async
+// CHECK:       %[[GPU_OUT_BUF1:.*]], %[[T4:.*]] = gpu.alloc async
+// CHECK:       gpu.memcpy async [%[[T4]]] %[[GPU_OUT_BUF1]], %[[OUT_BUF1]]
 // CHECK:       gpu.alloc async
 // CHECK:       gpu.memcpy async
 // CHECK:       %[[T0:.*]] = gpu.launch_func async @sparse_kernels::@kernel0 blocks
 // CHECK:       gpu.memcpy async [%[[T0]]]
 // CHECK:       gpu.dealloc async
+// CHECK:       gpu.wait async
 // CHECK:       gpu.dealloc async
+// CHECK:       gpu.wait async
 // CHECK:       gpu.dealloc async
-// CHECK:       gpu.dealloc async
+// CHECK:       %[[T5:.*]] = gpu.memcpy async [%[[T0]]] %[[OUT_BUF1]], %[[GPU_OUT_BUF1]]
+// CHECK:       gpu.dealloc async [%[[T5]]] %[[GPU_OUT_BUF1]]
+// CHECK:       gpu.wait async
 // CHECK:       gpu.dealloc async
 // CHECK:       gpu.wait
+// CHECK:       %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[OUT_BUF1]]
+// CHECK:       return %[[OUT_TENSOR]]
 //
 func.func @matmuls(%A: tensor<1024x8xf64>,
                    %B: tensor<8x1024xf64, #CSR>,
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
index 2c236d48ae78e..c5498aec0f522 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
@@ -47,7 +47,11 @@
 // CHECK:       }
 //
 //
-// CHECK-LABEL: func.func @matmul
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<?x?xf64, #sparse>,
+// CHECK-SAME:    %[[ARG1:.*]]: tensor<?x?xf64>,
+// CHECK-SAME:    %[[ARG2:.*]]: tensor<?x?xf64>) -> tensor<?x?xf64> {
+// CHECK:       %[[OUT_BUF:.*]] = bufferization.to_buffer %[[ARG2]]
 // CHECK:       gpu.wait async
 // CHECK:       gpu.alloc async
 // CHECK:       %[[S0:.*]] = gpu.memcpy async
@@ -58,24 +62,26 @@
 // CHECK:       gpu.alloc async
 // CHECK:       %[[S2:.*]] = gpu.memcpy async
 // CHECK:       gpu.wait async
-// CHECK:       gpu.alloc async
-// CHECK:       %[[S3:.*]] = gpu.memcpy async
+// CHECK:       %[[GPU_OUT_BUF:.*]], %[[T0:.*]] = gpu.alloc async
+// CHECK:       %[[S3:.*]] = gpu.memcpy async [%[[T0]]] %[[GPU_OUT_BUF]], %[[OUT_BUF]]
 // CHECK:       gpu.wait async
 // CHECK:       gpu.alloc async
 // CHECK:       %[[S4:.*]] = gpu.memcpy async
 // CHECK:       gpu.wait [%[[S0]], %[[S1]], %[[S2]], %[[S3]], %[[S4]]
 // CHECK:       %[[T0:.*]] = gpu.launch_func async @sparse_kernels::@kernel0 blocks
-// CHECK:       %[[M0:.*]] = gpu.memcpy async [%[[T0]]]
+// CHECK:       %[[M0:.*]] = gpu.wait async
 // CHECK:       %[[M1:.*]] = gpu.dealloc async [%[[M0]]]
 // CHECK:       %[[M2:.*]] = gpu.wait async
 // CHECK:       %[[M3:.*]] = gpu.dealloc async [%[[M2]]]
 // CHECK:       %[[M4:.*]] = gpu.wait async
 // CHECK:       %[[M5:.*]] = gpu.dealloc async [%[[M4]]]
-// CHECK:       %[[M6:.*]] = gpu.wait async
+// CHECK:       %[[M6:.*]] = gpu.memcpy async [%[[T0]]] %[[OUT_BUF]], %[[GPU_OUT_BUF]]
 // CHECK:       %[[M7:.*]] = gpu.dealloc async [%[[M6]]]
 // CHECK:       %[[M8:.*]] = gpu.wait async
 // CHECK:       %[[M9:.*]] = gpu.dealloc async [%[[M8]]]
 // CHECK:       gpu.wait [%[[M1]], %[[M3]], %[[M5]], %[[M7]], %[[M9]]
+// CHECK:       %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[OUT_BUF]]
+// CHECK:       return %[[OUT_TENSOR]]
 //
 func.func @matmul(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xf64>, %C_in: tensor<?x?xf64>) -> tensor<?x?xf64> {
   %C_out = linalg.matmul

>From 787889fb136199b38bc6e9a7adf508d4f186bdd3 Mon Sep 17 00:00:00 2001
From: Vito Secona <secona00 at gmail.com>
Date: Thu, 2 Apr 2026 10:45:20 +0700
Subject: [PATCH 2/2] document the changes to buffer copy back

---
 .../SparseTensor/Transforms/SparseGPUCodegen.cpp  | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 377a9191f0a11..f15a52fa010a7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -247,6 +247,10 @@ static Value genParametersIn(OpBuilder &builder, Location loc,
 /// Finalizes the outlined arguments. The output buffer is copied depending
 /// on the kernel token and then deallocated. All other buffers are simply
 /// deallocated. Then we wait for all operations to complete.
+///
+/// `copyBack` must have exactly one entry per element of `buffers`; entry i
+/// is true when buffers[i] may be written and needs a device-to-host copy.
+/// It is indexed unconditionally, so it must not be shorter than `buffers`.
 static void genParametersOut(OpBuilder &builder, Location loc, Value out,
                              Value kernelToken, SmallVectorImpl<Value> &scalars,
                              SmallVectorImpl<Value> &buffers,
@@ -254,9 +258,15 @@ static void genParametersOut(OpBuilder &builder, Location loc, Value out,
                              SmallVectorImpl<Value> &tokens,
                              ArrayRef<bool> copyBack) {
   unsigned base = scalars.size();
+
+  // `args` stores scalars followed by buffers. `base` is the index of the first
+  // buffer. `bufIdx` maps the current buffer to its exact 1:1 counterpart in
+  // the `copyBack` mask.
   for (unsigned i = base, e = args.size(); i < e; i++) {
     unsigned bufIdx = i - base;
     Value firstToken;
+
+    // Checks if the current buffer needs a device-to-host copy.
     if (copyBack[bufIdx]) {
       if (out && bufIdx == 0) {
         genHostUnregisterMemref(builder, loc, out);
@@ -1205,6 +1215,9 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
     SmallVector<Value> constants;
     SmallVector<Value> scalars;
     SmallVector<Value> buffers;
+    // A boolean mask aligned 1:1 with the `buffers` array, marking buffers
+    // that have a writing user (anywhere, not only inside the loop). If true,
+    // the buffer is "copied back" from device to host after the kernel.
     SmallVector<bool> copyBack;
     for (Value val : invariants) {
       Type tp = val.getType();
@@ -1215,6 +1228,8 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
       else if (isa<MemRefType>(tp)) {
         buffers.push_back(val);
 
+        // Conservatively mark the buffer for device-to-host copy-back if any
+        // user is a `memref.store` or reports a write memory effect on it.
         bool isWrite = false;
         for (Operation *user : val.getUsers()) {
           if (isa<memref::StoreOp>(user)) {



More information about the Mlir-commits mailing list