[Mlir-commits] [mlir] 0f59753 - [mlir][sparse] Fix crash in sparsification when unary/binary present block captures sparse tensor argument (#184597)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 5 03:50:18 PST 2026
Author: Mehdi Amini
Date: 2026-03-05T12:50:13+01:00
New Revision: 0f59753a422d6dcbe28ab07b6f01efe131375fbe
URL: https://github.com/llvm/llvm-project/commit/0f59753a422d6dcbe28ab07b6f01efe131375fbe
DIFF: https://github.com/llvm/llvm-project/commit/0f59753a422d6dcbe28ab07b6f01efe131375fbe.diff
LOG: [mlir][sparse] Fix crash in sparsification when unary/binary present block captures sparse tensor argument (#184597)
`relinkBranch` in Sparsification.cpp assumed that any block argument
of the outer `linalg.generic` op encountered inside an inlined
semi-ring branch must refer to a dense tensor, and asserted accordingly.
However, the `present` block of a `sparse_tensor.unary` (or of a similar
semi-ring op) is permitted to capture sparse tensor operands directly,
because `isAdmissibleBranchExp` accepts any `BlockArgument` as
admissible.
The fix removes the incorrect assertion and extends the load generation
to handle sparse tensors using `genSubscript`, which already knows how
to return the value buffer and current value position via the loop
emitter. The `kSparseIterator` strategy (where `genSubscript` returns a
value of `TensorType`) is also handled: in that case a
`sparse_tensor.extract_value` op is emitted instead of a `memref.load`.
Fixes #91183
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0a5f5595bba56..6004ab26f4663 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -588,17 +588,25 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
Value e) {
if (auto arg = dyn_cast<BlockArgument>(e)) {
- // Direct arguments of the original linalg op must be converted
- // into dense tensor loads. Note that we should not encounter
- // anything else. This needs to be verified by semi-ring ops.
+ // Direct arguments of the original linalg op must be converted into
+ // tensor element loads. This handles both dense tensor loads (using
+ // current loop coordinates) and sparse tensor loads (using the current
+ // value position tracked by the loop emitter).
linalg::GenericOp op = env.op();
if (arg.getOwner()->getParentOp() == op) {
const TensorId tid = env.makeTensorId(arg.getArgNumber());
OpOperand *t = &op->getOpOperand(tid);
- assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
SmallVector<Value> args;
Value ptr = genSubscript(env, rewriter, t, args);
- return memref::LoadOp::create(rewriter, op.getLoc(), ptr, args);
+ Location loc = op.getLoc();
+ if (llvm::isa<TensorType>(ptr.getType())) {
+ // kSparseIterator strategy: extract value at the iterator position.
+ assert(env.options().sparseEmitStrategy ==
+ SparseEmitStrategy::kSparseIterator);
+ return ExtractValOp::create(rewriter, loc, ptr,
+ llvm::getSingleElement(args));
+ }
+ return memref::LoadOp::create(rewriter, loc, ptr, args);
}
} else if (Operation *def = e.getDefiningOp()) {
// Handle index computation.
diff --git a/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir b/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
index 0c73d2fe8a079..328f217f8754e 100644
--- a/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
@@ -71,3 +71,74 @@ func.func @sparse_sampled_dd(%argA: tensor<8x8xf64>,
} -> tensor<8x8xf64, #SM>
return %result : tensor<8x8xf64, #SM>
}
+
+//
+// Variant where the present block directly captures %s, the block argument of the sparse operand.
+// This previously crashed with an assertion failure in relinkBranch.
+//
+
+#trait_sddmm_scaled = {
+ indexing_maps = [
+ affine_map<(i,j,k) -> (i,k)>, // A
+ affine_map<(i,j,k) -> (k,j)>, // B
+ affine_map<(i,j,k) -> (i,j)> // S
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ doc = "S(i,j) += S(i,j) * SUM_k A(i,k) B(k,j)"
+}
+
+// CHECK-LABEL: func.func @sparse_unary_captures_sparse_arg(
+// CHECK-SAME: %[[VAL_0:.*0]]: tensor<8x8xf64>,
+// CHECK-SAME: %[[VAL_1:.*1]]: tensor<8x8xf64>,
+// CHECK-SAME: %[[VAL_2:.*2]]: tensor<8x8xf64, #sparse{{[0-9]*}}>) -> tensor<8x8xf64, #sparse{{[0-9]*}}> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = bufferization.to_buffer %[[VAL_0]] : tensor<8x8xf64> to memref<8x8xf64>
+// CHECK-DAG: %[[VAL_7:.*]] = bufferization.to_buffer %[[VAL_1]] : tensor<8x8xf64> to memref<8x8xf64>
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xf64>
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<8x8xf64, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: scf.for %[[VAL_13:.*]] = {{.*}} to {{.*}} step %[[VAL_5]] {
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<?xf64>
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<8x8xf64>
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]], %[[VAL_14]]] : memref<8x8xf64>
+// CHECK: %[[VAL_18:.*]] = arith.mulf %[[VAL_16]], %[[VAL_17]] : f64
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<?xf64>
+// CHECK: %[[VAL_20:.*]] = arith.mulf %[[VAL_19]], %[[VAL_18]] : f64
+// CHECK: %[[VAL_21:.*]] = arith.addf %[[VAL_15]], %[[VAL_20]] : f64
+// CHECK: memref.store %[[VAL_21]], %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<?xf64>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: return
+// CHECK: }
+func.func @sparse_unary_captures_sparse_arg(%argA: tensor<8x8xf64>,
+ %argB: tensor<8x8xf64>,
+ %argS: tensor<8x8xf64, #SM>) -> tensor<8x8xf64, #SM> {
+ %f0 = arith.constant 0.0 : f64
+ %result = linalg.generic #trait_sddmm_scaled
+ ins(%argA, %argB: tensor<8x8xf64>, tensor<8x8xf64>) outs(%argS: tensor<8x8xf64, #SM>) {
+ ^bb(%a: f64, %b: f64, %s: f64):
+ // The present block reads %s (the block argument of the sparse operand %argS) directly, not via its own argument %p.
+ // Regression test: relinking that captured argument used to hit an assertion failure in relinkBranch.
+ %u = sparse_tensor.unary %s : f64 to f64
+ present={
+ ^bb0(%p: f64):
+ %mul1 = arith.mulf %a, %b : f64
+ %mul2 = arith.mulf %s, %mul1 : f64
+ sparse_tensor.yield %mul2 : f64
+ }
+ absent={}
+ %r = sparse_tensor.reduce %s, %u, %f0 : f64 {
+ ^bb0(%p: f64, %q: f64):
+ %add = arith.addf %p, %q : f64
+ sparse_tensor.yield %add : f64
+ }
+ linalg.yield %r : f64
+ } -> tensor<8x8xf64, #SM>
+ return %result : tensor<8x8xf64, #SM>
+}
More information about the Mlir-commits
mailing list