[Mlir-commits] [mlir] [mlir][linalg] Avoid asserts in IndexingMapOpInterface (PR #179072)

Samarth Narang llvmlistbot at llvm.org
Sat Jan 31 17:13:27 PST 2026


https://github.com/snarang181 updated https://github.com/llvm/llvm-project/pull/179072

>From 84eb812906871f074f4b859a12d2974ca4e68ea1 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Sat, 31 Jan 2026 17:17:28 -0500
Subject: [PATCH 1/3] [mlir][linalg] Avoid asserts in IndexingMapOpInterface

---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td       | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 9f1e88a040f5f..365d5344953b6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -385,9 +385,9 @@ def LinalgStructuredInterface
           return 0;
         // Tensor and Memref container types have a rank.
         if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
-          // Failsafe.
-          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
-                 "expected a ranked tensor or memref in LinalgInterface::getRank");
+          // Only ranked tensors and memrefs have well-defined ranks.
+          if (!(isa<MemRefType>(t) || isa<RankedTensorType>(t)))
+            return 0;
           return shapedType.getRank();
         }
         return 0;
@@ -703,9 +703,9 @@ def LinalgStructuredInterface
         if (isa<VectorType>(t))
           return {};
         if (auto shapedType = ::llvm::dyn_cast<ShapedType>(t)) {
-          // Failsafe.
-          assert((isa<MemRefType>(t) || isa<RankedTensorType>(t)) &&
-                 "expected a ranked tensor or memref in LinalgInterface::getRank");
+          // Only ranked tensors and memrefs have well-defined shapes.
+          if (!(isa<MemRefType>(t) || isa<RankedTensorType>(t)))
+            return {};
           return shapedType.getShape();
         }
         return {};

>From 2ec78d322320cd95b2a13302c85b86baa67ce756 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Sat, 31 Jan 2026 17:41:48 -0500
Subject: [PATCH 2/3] Add test case

---
 mlir/test/Dialect/Linalg/invalid.mlir | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 355d801f8732c..d97b5ee31f97b 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -217,6 +217,25 @@ func.func @generic_indexing_map_with_symbol(%arg0: tensor<8xf32>) -> tensor<8xf3
   return %0 : tensor<8xf32>
 }
 
+// -----
+
+// Unranked tensor inputs must be diagnosed.
+func.func @generic_unranked_input_tensor(%in: tensor<*xf32>) {
+  %out = tensor.empty() : tensor<16x16xf32>
+  // expected-error @+1 {{expected operand #0 rank (0) to match the result rank of indexing_map (2)}}
+  %r = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+      ins(%in : tensor<*xf32>)
+     outs(%out : tensor<16x16xf32>) {
+      ^bb0(%a: f32, %b: f32):
+        linalg.yield %a : f32
+     } -> tensor<16x16xf32>
+  return
+}
+
+
 ////////////////////////////////////////////////////////////////////////////////
 ///////////////////////////// Region tests /////////////////////////////////////
 ////////////////////////////////////////////////////////////////////////////////

>From cde2e79661df67d57b8fb5a0ec7c254147d545c1 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Sat, 31 Jan 2026 20:13:09 -0500
Subject: [PATCH 3/3] Emit a type-based diagnostic for operands without a rank

---
 .../lib/Interfaces/IndexingMapOpInterface.cpp | 29 ++++++++++++++++++-
 mlir/test/Dialect/Linalg/invalid.mlir         |  2 +-
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index 2ef36a21a1ac0..115408da5af6f 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -35,12 +35,39 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
       return this->emitOpError("unexpected symbols in indexing_map #")
              << opOperand.getOperandNumber();
 
+    Type operandTy = opOperand.get().getType();
+
+    // Ranked container types: ranked tensor or memref.
+    auto rankedTensorTy = dyn_cast<RankedTensorType>(operandTy);
+    auto memrefTy = dyn_cast<MemRefType>(operandTy);
+
+    // Unranked tensors have no static rank, so any reported rank is meaningless.
+    bool unrankedTensor = isa<UnrankedTensorType>(operandTy);
+
+    // Other shaped types, except vectors, have no usable rank either.
+    bool shapedButNotSupported = isa<ShapedType>(operandTy) &&
+                                 !rankedTensorTy && !memrefTy &&
+                                 !unrankedTensor && !isa<VectorType>(operandTy);
+
     // Result rank must match operand rank.
-    if (indexingMap.getNumResults() != rank)
+    if (indexingMap.getNumResults() != rank) {
+      // If this operand does not have a meaningful rank,
+      // emit a type-based diagnostic instead of a rank-based one.
+      if (unrankedTensor || shapedButNotSupported) {
+        return this->emitOpError("expected operand #")
+               << opOperand.getOperandNumber()
+               << " to be a ranked tensor or memref type to match the result "
+                  "rank of "
+                  "indexing_map ("
+               << indexingMap.getNumResults() << "), but got " << operandTy;
+      }
+
+      // Ranked tensors/memrefs with a mismatched rank, and scalars/vectors.
       return this->emitOpError("expected operand #")
              << opOperand.getOperandNumber() << " rank (" << rank
              << ") to match the result rank of indexing_map ("
              << indexingMap.getNumResults() << ")";
+    }
 
     llvm::append_range(allShapesSizes, shape);
   }
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index d97b5ee31f97b..031a9ad403a4a 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -222,7 +222,7 @@ func.func @generic_indexing_map_with_symbol(%arg0: tensor<8xf32>) -> tensor<8xf3
 // Unranked tensor inputs must be diagnosed.
 func.func @generic_unranked_input_tensor(%in: tensor<*xf32>) {
   %out = tensor.empty() : tensor<16x16xf32>
-  // expected-error @+1 {{expected operand #0 rank (0) to match the result rank of indexing_map (2)}}
+  // expected-error @+1 {{'linalg.generic' op expected operand #0 to be a ranked tensor or memref type to match the result rank of indexing_map (2), but got 'tensor<*xf32>'}}
   %r = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],



More information about the Mlir-commits mailing list