[Mlir-commits] [mlir] [mlir][sparse] Fix crash in sparsification when disjunctive op has sparse operand (PR #184599)
llvmlistbot at llvm.org
Wed Mar 4 04:00:13 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-sparse
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
When `buildTensorExp` processes a disjunctive binary op (add, sub, or, xor), it returned `hasSparseDep = disjSpVals = xSpVals && ySpVals`, which is true only when *both* operands reference a sparse tensor. For an expression like `math.exp(arith.subf(sparse_in, dense_in))`, the inner `arith.subf` has `xSpVals=true, ySpVals=false` → `disjSpVals=false`, which incorrectly satisfies the `!hasSparseDep` condition used to create a `kDenseOp` wrapper node.
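The pre-fix flag arithmetic can be reproduced with a few `constexpr` helpers. This is a toy model built only from the description above: `oldDisjunctiveDep` and `wrapsInDenseOp` are hypothetical names, not part of the real `Merger` API, and the `static_assert` lines (C++11 or later) serve as the usage example.

```cpp
// Toy model of the pre-fix flag computation in buildTensorExp.
// Hypothetical helpers: the real code computes these as local bools.

// Pre-fix rule for disjunctive ops (add, sub, or, xor): the result reports a
// sparse dependency only if BOTH operands are sparse.
constexpr bool oldDisjunctiveDep(bool xSpVals, bool ySpVals) {
  return xSpVals && ySpVals;
}

// An unsparsifiable unary op such as math.exp is wrapped in kDenseOp when
// its operand reports no sparse dependency (the `!hasSparseDep` check).
constexpr bool wrapsInDenseOp(bool operandHasSparseDep) {
  return !operandHasSparseDep;
}

// math.exp(arith.subf(sparse_in, dense_in)) under the pre-fix rule.
constexpr bool subfDep = oldDisjunctiveDep(/*xSpVals=*/true, /*ySpVals=*/false);
static_assert(!subfDep, "subf reports no sparse dependency...");
static_assert(wrapsInDenseOp(subfDep), "...so math.exp is wrongly wrapped in kDenseOp");
```

Replacing `&&` with `||` makes `subfDep` true, so the wrapper check no longer fires; that is the one-operator change the patch applies to each disjunctive case.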
The spurious `kDenseOp` inherits the disjunctive lattice of its child (`kSubF`), producing an "arg1-only" lattice point with an undefined level type at the loop being generated. `optimizeSet` does not eliminate this point because the XOR with the base point contains a sparse bit (from `sparse_in`). When code generation reaches this lattice point, the synthetic-tensor iterator is used uninitialized (null cursor), causing a crash in `TrivialIterator::derefImpl`.
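A toy bitmask model of the two-tensor lattice makes the `optimizeSet` argument concrete. Assumptions: each lattice point is a bitmask of the tensors present (bit 0 for `sparse_in`, bit 1 for `dense_in`), and `keptByOptimize` is a hypothetical simplification of the redundancy rule as the description states it (a non-base point is dropped only when its XOR with the base point contains no sparse bit). It is not a transcription of `Merger::optimizeSet`.

```cpp
// Toy lattice model (not the real Merger): a lattice point is a bitmask of
// the tensors that are present at that point.
constexpr unsigned kSparseIn = 1u << 0;        // sparse_in (compressed level)
constexpr unsigned kDenseIn = 1u << 1;         // dense_in (dense level)
constexpr unsigned kSparseTensors = kSparseIn; // mask of the sparse tensors

// Disjunctive lattice of (x op y): {x and y, x only, y only}.
constexpr unsigned kBase = kSparseIn | kDenseIn; // both operands present
constexpr unsigned kArg0Only = kSparseIn;        // only sparse_in present
constexpr unsigned kArg1Only = kDenseIn;         // only dense_in present

// Simplified redundancy test, as described in the PR text: a non-base point
// survives when its XOR with the base point contains a sparse bit.
constexpr bool keptByOptimize(unsigned point, unsigned base) {
  return point == base || ((point ^ base) & kSparseTensors) != 0;
}

static_assert(!keptByOptimize(kArg0Only, kBase),
              "arg0-only: XOR is {dense_in}, so the point is dropped");
static_assert(keptByOptimize(kArg1Only, kBase),
              "arg1-only: XOR is {sparse_in}, so the point survives");
```

The surviving `kArg1Only` point corresponds to the "arg1-only" lattice point that, per the description, reaches code generation with an uninitialized synthetic-tensor iterator.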
Fix: use `conjSpVals = xSpVals || ySpVals` (any-sparse) for all disjunctive ops in `buildTensorExp`. Since the return value's `hasSparseDep` field is only consumed internally within `buildTensorExp` for the `kDenseOp` creation check, this is safe. With the fix, any expression containing a sparse-tensor operand prevents `kDenseOp` wrapping, causing `initTensorExp` to return failure and `matchAndRewrite` to gracefully leave the `linalg.generic` unmodified.
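A toy expression-tree model shows the fixed rule end to end. Assumptions: `Kind`, `Node`, `Built`, and `build` are hypothetical; nodes refer to their operands by integer index, loosely mirroring `ExprId`; only `SubF` and `Exp` are modeled; and `ok == false` stands in for `buildTensorExp` returning `std::nullopt`, which is what makes `initTensorExp` fail. The `constexpr` switch requires C++14.

```cpp
// Toy expression kinds; the real Merger uses TensorExp::Kind.
enum class Kind { SparseTensor, DenseTensor, SubF, Exp };

// A node refers to its operands by index, loosely mirroring ExprId.
struct Node {
  Kind kind;
  int e0; // first operand index, or -1
  int e1; // second operand index, or -1
};

// ok == false models buildTensorExp returning std::nullopt.
struct Built {
  bool ok;
  bool hasSparseDep;
};

constexpr Built build(const Node *nodes, int id) {
  const Node n = nodes[id];
  switch (n.kind) {
  case Kind::SparseTensor:
    return {true, true};
  case Kind::DenseTensor:
    return {true, false};
  case Kind::SubF: {
    // Post-fix rule: a disjunctive op has a sparse dependency if ANY operand does.
    const Built x = build(nodes, n.e0);
    const Built y = build(nodes, n.e1);
    return {x.ok && y.ok, x.hasSparseDep || y.hasSparseDep};
  }
  case Kind::Exp: {
    // Unsparsifiable op: buildable only as a kDenseOp wrapper, which
    // requires an operand with no sparse dependency.
    const Built x = build(nodes, n.e0);
    return {x.ok && !x.hasSparseDep, false};
  }
  }
  return {false, false};
}

constexpr Node kNodes[] = {
    {Kind::SparseTensor, -1, -1}, // 0: sparse_in
    {Kind::DenseTensor, -1, -1},  // 1: dense_in
    {Kind::SubF, 0, 1},           // 2: subf(sparse_in, dense_in)
    {Kind::Exp, 2, -1},           // 3: exp(subf(sparse_in, dense_in))
    {Kind::SubF, 1, 1},           // 4: subf(dense_in, dense_in)
    {Kind::Exp, 4, -1},           // 5: exp(subf(dense_in, dense_in))
};

static_assert(build(kNodes, 2).hasSparseDep, "subf now reports its sparse operand");
static_assert(!build(kNodes, 3).ok, "so exp is not wrapped and the build fails");
static_assert(build(kNodes, 5).ok, "a purely dense operand still gets kDenseOp");
```

Failing the build for node 3, rather than wrapping it, corresponds to the graceful outcome described above: the `linalg.generic` is left unmodified instead of reaching code generation.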
Fixes #114855
---
Full diff: https://github.com/llvm/llvm-project/pull/184599.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+8-9)
- (modified) mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir (+32)
``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 5847fecc45404..2837361a3a3f5 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1402,7 +1402,6 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// is sparse. For a disjunctive operation, it yields a "sparse" result if
// all operands are sparse.
bool conjSpVals = xSpVals || ySpVals;
- bool disjSpVals = xSpVals && ySpVals;
if (x.has_value() && y.has_value()) {
const ExprId e0 = *x;
const ExprId e1 = *y;
@@ -1421,23 +1420,23 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
return {addExp(TensorExp::Kind::kDivU, e0, e1), conjSpVals};
if (isa<arith::AddFOp>(def))
- return {addExp(TensorExp::Kind::kAddF, e0, e1), disjSpVals};
+ return {addExp(TensorExp::Kind::kAddF, e0, e1), conjSpVals};
if (isa<complex::AddOp>(def))
- return {addExp(TensorExp::Kind::kAddC, e0, e1), disjSpVals};
+ return {addExp(TensorExp::Kind::kAddC, e0, e1), conjSpVals};
if (isa<arith::AddIOp>(def))
- return {addExp(TensorExp::Kind::kAddI, e0, e1), disjSpVals};
+ return {addExp(TensorExp::Kind::kAddI, e0, e1), conjSpVals};
if (isa<arith::SubFOp>(def))
- return {addExp(TensorExp::Kind::kSubF, e0, e1), disjSpVals};
+ return {addExp(TensorExp::Kind::kSubF, e0, e1), conjSpVals};
if (isa<complex::SubOp>(def))
- return {addExp(TensorExp::Kind::kSubC, e0, e1), disjSpVals};
+ return {addExp(TensorExp::Kind::kSubC, e0, e1), conjSpVals};
if (isa<arith::SubIOp>(def))
- return {addExp(TensorExp::Kind::kSubI, e0, e1), disjSpVals};
+ return {addExp(TensorExp::Kind::kSubI, e0, e1), conjSpVals};
if (isa<arith::AndIOp>(def))
return {addExp(TensorExp::Kind::kAndI, e0, e1), conjSpVals};
if (isa<arith::OrIOp>(def))
- return {addExp(TensorExp::Kind::kOrI, e0, e1), disjSpVals};
+ return {addExp(TensorExp::Kind::kOrI, e0, e1), conjSpVals};
if (isa<arith::XOrIOp>(def))
- return {addExp(TensorExp::Kind::kXorI, e0, e1), disjSpVals};
+ return {addExp(TensorExp::Kind::kXorI, e0, e1), conjSpVals};
if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
return {addExp(TensorExp::Kind::kShrS, e0, e1), conjSpVals};
if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
diff --git a/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
index ea29f1d677eff..14abad77a9177 100644
--- a/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
+++ b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
@@ -93,3 +93,35 @@ func.func @dense_op_with_sp_dep(%169: tensor<2x10x8xf32>,
} -> tensor<2x10x100xf32>
return %179 : tensor<2x10x100xf32>
}
+
+//
+// This kernel cannot be sparsified: the unsparsifiable op (math.exp) takes
+// a result of arith.subf whose first operand is a sparse tensor. Even though
+// arith.subf is a disjunctive op (result is 0 only when both LHS and RHS are 0),
+// it still carries a sparse tensor dependency so kDenseOp wrapping is invalid.
+// Regression test for github.com/llvm/llvm-project/issues/114855.
+//
+#sparse3d = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense) }>
+
+// CHECK-LABEL: func @dense_op_with_sparse_out_and_sp_dep
+// CHECK: linalg.generic {{.*}}
+func.func @dense_op_with_sparse_out_and_sp_dep(
+ %arg0: tensor<2x3x4xf32, #sparse3d>,
+ %arg1: tensor<2x4xf32>) -> tensor<2x3x4xf32, #sparse3d> {
+ %0 = tensor.empty() : tensor<2x3x4xf32, #sparse3d>
+ %1 = linalg.generic {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+ ],
+ iterator_types = ["parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1 : tensor<2x3x4xf32, #sparse3d>, tensor<2x4xf32>)
+ outs(%0 : tensor<2x3x4xf32, #sparse3d>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.subf %in, %in_0 : f32
+ %3 = math.exp %2 : f32
+ linalg.yield %3 : f32
+ } -> tensor<2x3x4xf32, #sparse3d>
+ return %1 : tensor<2x3x4xf32, #sparse3d>
+}
``````````
</details>
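The regression kernel's dense reference semantics show why it cannot be sparsified. Assumptions: `referenceKernel` and `countZeros` are hypothetical helpers, both inputs are stored densely in row-major order (the sparse encoding of `%arg0` is ignored), and `%arg1` is indexed by `(d0, d2)` as its indexing map specifies. Wherever `%arg0` holds an implicit zero, the output is `exp(0 - arg1)`, which is nonzero barring floating-point underflow, so the sparsity of `%arg0` lets no iteration be skipped.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Dense reference for @dense_op_with_sparse_out_and_sp_dep:
//   out[d0][d1][d2] = exp(arg0[d0][d1][d2] - arg1[d0][d2])
// Shapes follow the test: arg0 and out are 2x3x4, arg1 is 2x4.
constexpr std::size_t D0 = 2, D1 = 3, D2 = 4;

std::vector<float> referenceKernel(const std::vector<float> &arg0,
                                   const std::vector<float> &arg1) {
  assert(arg0.size() == D0 * D1 * D2 && arg1.size() == D0 * D2);
  std::vector<float> out(D0 * D1 * D2);
  for (std::size_t i = 0; i < D0; ++i)
    for (std::size_t j = 0; j < D1; ++j)
      for (std::size_t k = 0; k < D2; ++k)
        // arith.subf followed by math.exp, one element at a time.
        out[(i * D1 + j) * D2 + k] =
            std::exp(arg0[(i * D1 + j) * D2 + k] - arg1[i * D2 + k]);
  return out;
}

// Counts output entries that are exactly zero, i.e. the only entries a
// sparse kernel could leave out.
std::size_t countZeros(const std::vector<float> &v) {
  std::size_t n = 0;
  for (float x : v)
    n += (x == 0.0f) ? 1 : 0;
  return n;
}
```

With an all-zero `arg0` (an empty sparse tensor) and an all-zero `arg1`, every output entry is `exp(0) = 1`, so `countZeros` returns 0 even though `arg0` stores nothing.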
https://github.com/llvm/llvm-project/pull/184599