[Mlir-commits] [mlir] [mlir][SparseTensor] add `numSymbols` information to simplify affine expressions (PR #191649)

Vito Secona llvmlistbot at llvm.org
Sat Apr 11 10:41:26 PDT 2026


https://github.com/secona created https://github.com/llvm/llvm-project/pull/191649

Previously, the `translateShape` function hard-coded the `numSymbols` argument of `simplifyAffineExpr` to 0. As a result, simplifying the translated affine expressions failed (crashing `mlir-opt`) whenever the sparse tensor encoding's map declared symbols.

This PR fixes the issue by reading the number of symbols from the encoding's `dimToLvl` map and passing it to `simplifyAffineExpr` during translation. A regression test has also been added to ensure that encodings with symbols remain supported.

Closes #191209

>From b09d624b72b00fa428d9c08af88bbf9c35becb95 Mon Sep 17 00:00:00 2001
From: Vito Secona <secona00 at gmail.com>
Date: Sun, 12 Apr 2026 00:28:41 +0700
Subject: [PATCH] add numSymbols information when simplifying affine expr

---
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  9 +++++--
 .../SparseTensor/encoding_with_symbols.mlir   | 26 +++++++++++++++++++
 2 files changed, 33 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/encoding_with_symbols.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b77a536861d2a..eab2d14797257 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -525,10 +525,15 @@ SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
     }
   };
 
+  // The number of symbols information is included inside the `dimToLvl` map
+  // during parsing. Here, we're extracting it to be used when simplifying the
+  // affine expression.
+  unsigned numSymbols = getDimToLvl().getNumSymbols();
+
   for (AffineExpr exp : transMap.getResults()) {
     // Do constant propagation on the affine map.
-    AffineExpr evalExp =
-        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
+    AffineExpr evalExp = simplifyAffineExpr(exp.replaceDims(dimRep),
+                                            srcShape.size(), numSymbols);
     // use llvm namespace here to avoid ambiguity
     if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
       ret.push_back(c.getValue() + 1);
diff --git a/mlir/test/Dialect/SparseTensor/encoding_with_symbols.mlir b/mlir/test/Dialect/SparseTensor/encoding_with_symbols.mlir
new file mode 100644
index 0000000000000..7cd68ee00dd09
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/encoding_with_symbols.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -sparsification-and-bufferization | FileCheck %s
+
+// Regression test for #191209: mlir-opt must not crash when lowering sparse tensor encodings whose map uses symbols.
+
+// CHECK-DAG: #[[$SPARSE_0:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed) }>
+// CHECK-DAG: #[[$SPARSE_1:.*]] = #sparse_tensor.encoding<{ map = [s0](d0, d1) -> (d0 * (s0 * 3) : dense, d0 : dense, d1 : compressed) }>
+
+#Sparse = #sparse_tensor.encoding<{
+  map = [c](i, j) -> (c * 3 * i : dense, i : dense, j : compressed)
+}>
+
+// CHECK-LABEL: func.func @tensor_add(
+// CHECK-SAME:      %{{.*}}: memref<?xindex>, %{{.*}}: memref<?xindex>, %{{.*}}: memref<?xf32>,
+// CHECK-SAME:      %{{.*}}: !sparse_tensor.storage_specifier<#[[$SPARSE_0]]>) -> memref<8x8xf32> {
+func.func @tensor_add(%arg0: tensor<8x8xf32, #Sparse>) -> tensor<8x8xf32> {
+  %result_out = tensor.empty() : tensor<8x8xf32>
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32>
+  // CHECK: %[[RES:.*]] = linalg.add ins(%{{.*}}, %{{.*}} : tensor<8x8xf32, #[[$SPARSE_1]]>, tensor<8x8xf32, #[[$SPARSE_1]]>)
+  %result = linalg.add
+    ins(%arg0, %arg0 : tensor<8x8xf32, #Sparse>, tensor<8x8xf32, #Sparse>)
+    outs(%result_out : tensor<8x8xf32>) -> tensor<8x8xf32>
+
+  // CHECK: return %{{.*}} : memref<8x8xf32>
+  return %result : tensor<8x8xf32>
+}



More information about the Mlir-commits mailing list