[Mlir-commits] [mlir] [mlir][sparse] Fix crash in sparsification when unary/binary present block captures sparse tensor argument (PR #184597)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 4 03:57:21 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

`relinkBranch` in Sparsification.cpp assumed that any block argument of the outer `linalg.generic` op encountered inside an inlined semi-ring branch must be a dense tensor, and asserted accordingly. However, the `present` block of a `sparse_tensor.unary` op (or of similar semi-ring ops) may capture sparse tensor operands directly, because `isAdmissibleBranchExp` accepts any `BlockArgument` as admissible.

The fix removes the incorrect assertion and extends the load generation to handle sparse tensors using `genSubscript`, which already knows how to return the value buffer and current value position via the loop emitter. The `kSparseIterator` strategy (where `genSubscript` returns a value of tensor type rather than a memref) is also handled by emitting a `sparse_tensor.extract_value` op.

Fixes #91183

---
Full diff: https://github.com/llvm/llvm-project/pull/184597.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+13-5) 
- (modified) mlir/test/Dialect/SparseTensor/spy_sddmm.mlir (+71) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0a5f5595bba56..6004ab26f4663 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -588,17 +588,25 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                           Value e) {
   if (auto arg = dyn_cast<BlockArgument>(e)) {
-    // Direct arguments of the original linalg op must be converted
-    // into dense tensor loads. Note that we should not encounter
-    // anything else. This needs to be verified by semi-ring ops.
+    // Direct arguments of the original linalg op must be converted into
+    // tensor element loads. This handles both dense tensor loads (using
+    // current loop coordinates) and sparse tensor loads (using the current
+    // value position tracked by the loop emitter).
     linalg::GenericOp op = env.op();
     if (arg.getOwner()->getParentOp() == op) {
       const TensorId tid = env.makeTensorId(arg.getArgNumber());
       OpOperand *t = &op->getOpOperand(tid);
-      assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
       SmallVector<Value> args;
       Value ptr = genSubscript(env, rewriter, t, args);
-      return memref::LoadOp::create(rewriter, op.getLoc(), ptr, args);
+      Location loc = op.getLoc();
+      if (llvm::isa<TensorType>(ptr.getType())) {
+        // kSparseIterator strategy: extract value at the iterator position.
+        assert(env.options().sparseEmitStrategy ==
+               SparseEmitStrategy::kSparseIterator);
+        return ExtractValOp::create(rewriter, loc, ptr,
+                                    llvm::getSingleElement(args));
+      }
+      return memref::LoadOp::create(rewriter, loc, ptr, args);
     }
   } else if (Operation *def = e.getDefiningOp()) {
     // Handle index computation.
diff --git a/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir b/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
index 0c73d2fe8a079..328f217f8754e 100644
--- a/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
@@ -71,3 +71,74 @@ func.func @sparse_sampled_dd(%argA: tensor<8x8xf64>,
   } -> tensor<8x8xf64, #SM>
   return %result : tensor<8x8xf64, #SM>
 }
+
+//
+// Variant where the present block directly captures the sparse input (%s).
+// This previously crashed with an assertion failure in relinkBranch.
+//
+
+#trait_sddmm_scaled = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,k)>,  // A
+    affine_map<(i,j,k) -> (k,j)>,  // B
+    affine_map<(i,j,k) -> (i,j)>   // S
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "S(i,j) += S(i,j) * SUM_k A(i,k) B(k,j)"
+}
+
+// CHECK-LABEL: func.func @sparse_unary_captures_sparse_arg(
+// CHECK-SAME:    %[[VAL_0:.*0]]: tensor<8x8xf64>,
+// CHECK-SAME:    %[[VAL_1:.*1]]: tensor<8x8xf64>,
+// CHECK-SAME:    %[[VAL_2:.*2]]: tensor<8x8xf64, #sparse{{[0-9]*}}>) -> tensor<8x8xf64, #sparse{{[0-9]*}}> {
+// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[VAL_6:.*]] = bufferization.to_buffer %[[VAL_0]] : tensor<8x8xf64> to memref<8x8xf64>
+// CHECK-DAG:     %[[VAL_7:.*]] = bufferization.to_buffer %[[VAL_1]] : tensor<8x8xf64> to memref<8x8xf64>
+// CHECK-DAG:     %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG:     %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK:         scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK:             scf.for %[[VAL_13:.*]] = {{.*}} to {{.*}} step %[[VAL_5]] {
+// CHECK:               %[[VAL_14:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK:               %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<?xf64>
+// CHECK:               %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<8x8xf64>
+// CHECK:               %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]], %[[VAL_14]]] : memref<8x8xf64>
+// CHECK:               %[[VAL_18:.*]] = arith.mulf %[[VAL_16]], %[[VAL_17]] : f64
+// CHECK:               %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<?xf64>
+// CHECK:               %[[VAL_20:.*]] = arith.mulf %[[VAL_19]], %[[VAL_18]] : f64
+// CHECK:               %[[VAL_21:.*]] = arith.addf %[[VAL_15]], %[[VAL_20]] : f64
+// CHECK:               memref.store %[[VAL_21]], %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<?xf64>
+// CHECK:             } {"Emitted from" = "linalg.generic"}
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:         } {"Emitted from" = "linalg.generic"}
+// CHECK:         return
+// CHECK:       }
+func.func @sparse_unary_captures_sparse_arg(%argA: tensor<8x8xf64>,
+                                            %argB: tensor<8x8xf64>,
+                                            %argS: tensor<8x8xf64, #SM>) -> tensor<8x8xf64, #SM> {
+  %f0 = arith.constant 0.0 : f64
+  %result = linalg.generic #trait_sddmm_scaled
+    ins(%argA, %argB: tensor<8x8xf64>, tensor<8x8xf64>) outs(%argS: tensor<8x8xf64, #SM>) {
+      ^bb(%a: f64, %b: f64, %s: f64):
+        // The present block captures %s (a sparse tensor block arg) directly.
+        // This used to crash in relinkBranch with an assertion failure.
+        %u = sparse_tensor.unary %s : f64 to f64
+          present={
+            ^bb0(%p: f64):
+              %mul1 = arith.mulf %a, %b : f64
+              %mul2 = arith.mulf %s, %mul1 : f64
+              sparse_tensor.yield %mul2 : f64
+          }
+          absent={}
+        %r = sparse_tensor.reduce %s, %u, %f0 : f64 {
+          ^bb0(%p: f64, %q: f64):
+            %add = arith.addf %p, %q : f64
+            sparse_tensor.yield %add : f64
+        }
+        linalg.yield %r : f64
+  } -> tensor<8x8xf64, #SM>
+  return %result : tensor<8x8xf64, #SM>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/184597


More information about the Mlir-commits mailing list