[Mlir-commits] [mlir] [MLIR][IndexingMapOpInterface]: Validate maps and operands before composing loop ranges (PR #173434)

Stefan Weigl-Bosker llvmlistbot at llvm.org
Sat Jan 3 08:51:06 PST 2026


https://github.com/sweiglbosker updated https://github.com/llvm/llvm-project/pull/173434

>From 8d84f9f67b45d8ce499708f315aef867b3bad5c2 Mon Sep 17 00:00:00 2001
From: Stefan Weigl-Bosker <stefan at s00.xyz>
Date: Tue, 23 Dec 2025 19:04:52 -0500
Subject: [PATCH 1/6] [MLIR][Interface]: Verify index map ranks before
 composing loop bounds

---
 .../lib/Interfaces/IndexingMapOpInterface.cpp | 26 +++++++-------
 mlir/test/Dialect/Linalg/invalid.mlir         | 35 +++++++++++++++++++
 2 files changed, 48 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index f3c12aed8df84..463276011acdb 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -32,35 +32,35 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
            << "(" << str << ")";
   }
 
-  SmallVector<int64_t> endLoopRangeValues = getStaticLoopRanges();
+  SmallVector<int64_t> allShapesSizes;
 
-  // Set this flag if this op has user defined maps. This is required to guard
-  // the below error condition which assume default indexing maps.
   for (OpOperand &opOperand : getOperation()->getOpOperands()) {
     AffineMap indexingMap = getMatchingIndexingMap(&opOperand);
+    SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
+    int64_t rank = shape.size();
 
     // Symbols disallowed.
     if (indexingMap.getNumSymbols() != 0)
       return getOperation()->emitOpError("unexpected symbols in indexing_map #")
              << opOperand.getOperandNumber();
 
-    // Domain must be consistent.
-    if (indexingMap.getNumDims() != endLoopRangeValues.size())
-      return getOperation()->emitOpError("expected indexing_map #")
-             << opOperand.getOperandNumber() << " to have "
-             << endLoopRangeValues.size()
-             << " dim(s) to match the number of loops";
-
-    SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
-    int64_t rank = shape.size();
-
+    // Result rank must match operand rank.
     if (indexingMap.getNumResults() != rank)
       return getOperation()->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
+
+    llvm::append_range(allShapesSizes, shape);
   }
 
+  SmallVector<int64_t> endLoopRangeValues = invertedMap.compose(allShapesSizes);
+
+  if (invertedMap.getNumResults() != endLoopRangeValues.size())
+    return getOperation()->emitOpError("expected each indexing_map to have ")
+           << endLoopRangeValues.size()
+           << " dim(s) to match the number of loops";
+
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t> startLoopRangeValues(endLoopRangeValues.size(), 0);
   // Verify only static cases since we can't get exact dimension sizes and
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 1f554e6c45da7..af9112b7c1f74 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -165,6 +165,41 @@ func.func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off
   }
 }
 
+// -----
+
+func.func @generic_index_rank0(%arg0: tensor<f32>) -> tensor<f32> {
+// expected-error @+1 {{op expected operand rank (0) to match the result rank of indexing_map #0 (1)}}
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0) -> (d0)>,
+      affine_map<(d0) -> (d0)>
+    ],
+    iterator_types = ["parallel"]}
+      ins(%arg0 : tensor<f32>)
+     outs(%arg0 : tensor<f32>) {
+  ^bb(%0: f32, %1: f32):
+    linalg.yield %1 : f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @generic_index_domain_error(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+// expected-error @+1 {{op expected operand rank (1) to match the result rank of indexing_map #1 (2)}}
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0) -> (d0)>,
+      affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<4xf32>)
+     outs(%arg0 : tensor<4xf32>) {
+  ^bb(%0: f32):
+    linalg.yield %0 : f32
+  } -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 ///////////////////////////// Region tests /////////////////////////////////////
 ////////////////////////////////////////////////////////////////////////////////

>From 54c5f040b03b842d7cebe699ba0a0b817aef5506 Mon Sep 17 00:00:00 2001
From: Stefan Weigl-Bosker <stefan at s00.xyz>
Date: Fri, 26 Dec 2025 21:04:07 -0500
Subject: [PATCH 2/6] Check there are no symbols before inverting map

---
 .../lib/Interfaces/IndexingMapOpInterface.cpp | 22 +++++++++----------
 mlir/test/Dialect/Linalg/invalid.mlir         | 18 +++++++++++++++
 2 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index 463276011acdb..9df8832c9e39e 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -23,15 +23,6 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
            << ") to be equal to the number of input/output operands ("
            << getOperation()->getNumOperands() << ")";
 
-  AffineMap invertedMap = getShapesToLoopsMap();
-  if (!invertedMap) {
-    std::string str;
-    llvm::raw_string_ostream os(str);
-    getLoopsToShapesMap().print(os);
-    return this->emitOpError("invalid indexing maps are non-invertible: ")
-           << "(" << str << ")";
-  }
-
   SmallVector<int64_t> allShapesSizes;
 
   for (OpOperand &opOperand : getOperation()->getOpOperands()) {
@@ -41,12 +32,12 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
 
     // Symbols disallowed.
     if (indexingMap.getNumSymbols() != 0)
-      return getOperation()->emitOpError("unexpected symbols in indexing_map #")
+      return this->emitOpError("unexpected symbols in indexing_map #")
              << opOperand.getOperandNumber();
 
     // Result rank must match operand rank.
     if (indexingMap.getNumResults() != rank)
-      return getOperation()->emitOpError("expected operand rank (")
+      return this->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
@@ -54,6 +45,15 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
     llvm::append_range(allShapesSizes, shape);
   }
 
+  AffineMap invertedMap = getShapesToLoopsMap();
+  if (!invertedMap) {
+    std::string str;
+    llvm::raw_string_ostream os(str);
+    getLoopsToShapesMap().print(os);
+    return this->emitOpError("invalid indexing maps are non-invertible: ")
+           << "(" << str << ")";
+  }
+
   SmallVector<int64_t> endLoopRangeValues = invertedMap.compose(allShapesSizes);
 
   if (invertedMap.getNumResults() != endLoopRangeValues.size())
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index af9112b7c1f74..2406764b65919 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -200,6 +200,24 @@ func.func @generic_index_domain_error(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   return %0 : tensor<4xf32>
 }
 
+// -----
+
+
+#map_with_symbol = affine_map<(d0)[s0] -> (d0 + s0)>
+
+func.func @generic_indexing_map_with_symbol(%arg0: tensor<8xf32>) -> tensor<8xf32> {
+  // expected-error @+1 {{unexpected symbols in indexing_map #0}}
+  %0 = linalg.generic {
+    indexing_maps = [#map_with_symbol, #map_with_symbol],
+    iterator_types = ["parallel"]
+  } ins(%arg0 : tensor<8xf32>)
+    outs(%arg0 : tensor<8xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 ///////////////////////////// Region tests /////////////////////////////////////
 ////////////////////////////////////////////////////////////////////////////////

>From 7071ef437edcc765705397861f30421469db4ec9 Mon Sep 17 00:00:00 2001
From: Stefan Weigl-Bosker <stefan at s00.xyz>
Date: Sat, 27 Dec 2025 11:52:39 -0500
Subject: [PATCH 3/6] fixes

---
 mlir/lib/Interfaces/IndexingMapOpInterface.cpp | 5 +++--
 mlir/test/Dialect/Linalg/invalid.mlir          | 1 -
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index 9df8832c9e39e..1e2a6aa28d409 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -37,8 +37,9 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
 
     // Result rank must match operand rank.
     if (indexingMap.getNumResults() != rank)
-      return this->emitOpError("expected operand rank (")
-             << rank << ") to match the result rank of indexing_map #"
+      return this->emitOpError("expected operand #")
+             << opOperand.getOperandNumber() << " rank (" << rank
+             << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 2406764b65919..8a22bdd4cb952 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -202,7 +202,6 @@ func.func @generic_index_domain_error(%arg0: tensor<4xf32>) -> tensor<4xf32> {
 
 // -----
 
-
 #map_with_symbol = affine_map<(d0)[s0] -> (d0 + s0)>
 
 func.func @generic_indexing_map_with_symbol(%arg0: tensor<8xf32>) -> tensor<8xf32> {

>From 96fc911fb6960e9906c11315db8b7c986ece7f26 Mon Sep 17 00:00:00 2001
From: Stefan Weigl-Bosker <stefan at s00.xyz>
Date: Sat, 27 Dec 2025 12:27:25 -0500
Subject: [PATCH 4/6] revert error message change

---
 mlir/lib/Interfaces/IndexingMapOpInterface.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index 1e2a6aa28d409..9df8832c9e39e 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -37,9 +37,8 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
 
     // Result rank must match operand rank.
     if (indexingMap.getNumResults() != rank)
-      return this->emitOpError("expected operand #")
-             << opOperand.getOperandNumber() << " rank (" << rank
-             << ") to match the result rank of indexing_map #"
+      return this->emitOpError("expected operand rank (")
+             << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
 

>From 5b4a74de13b4ce243351ca771f87f55e64060427 Mon Sep 17 00:00:00 2001
From: Stefan Weigl-Bosker <stefan at s00.xyz>
Date: Tue, 30 Dec 2025 15:35:54 -0500
Subject: [PATCH 5/6] fix diagnostic

---
 .../lib/Interfaces/IndexingMapOpInterface.cpp |  6 +--
 mlir/test/Dialect/Linalg/invalid.mlir         | 18 ++++----
 mlir/test/Dialect/Linalg/named-ops-fail.mlir  | 42 +++++++++----------
 3 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index 9df8832c9e39e..0c9e158da2708 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -37,9 +37,9 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
 
     // Result rank must match operand rank.
     if (indexingMap.getNumResults() != rank)
-      return this->emitOpError("expected operand rank (")
-             << rank << ") to match the result rank of indexing_map #"
-             << opOperand.getOperandNumber() << " ("
+      return this->emitOpError("expected operand #")
+             << opOperand.getOperandNumber() << " rank (" << rank
+             << ") to match the result rank of indexing_map ("
              << indexingMap.getNumResults() << ")";
 
     llvm::append_range(allShapesSizes, shape);
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 8a22bdd4cb952..f0241e0c6398e 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -109,7 +109,7 @@ func.func @generic_wrong_iterator(%arg0: memref<1xi32>) {
 // -----
 
 func.func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
-  // expected-error @+1 {{expected operand rank (1) to match the result rank of indexing_map #0 (2)}}
+  // expected-error @+1 {{expected operand #0 rank (1) to match the result rank of indexing_map (2)}}
   linalg.generic {
     indexing_maps =  [ affine_map<() -> (0, 0)> ],
     iterator_types = []}
@@ -123,7 +123,7 @@ func.func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i
 
 func.func @generic_scalar_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   %cst = arith.constant 0.0 : f32
-  // expected-error @+1 {{expected operand rank (0) to match the result rank of indexing_map #0 (1)}}
+  // expected-error @+1 {{expected operand #0 rank (0) to match the result rank of indexing_map (1)}}
   linalg.generic {
     indexing_maps =  [ affine_map<() -> (0)>, affine_map<() -> (0, 0)> ],
     iterator_types = []}
@@ -168,7 +168,7 @@ func.func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off
 // -----
 
 func.func @generic_index_rank0(%arg0: tensor<f32>) -> tensor<f32> {
-// expected-error @+1 {{op expected operand rank (0) to match the result rank of indexing_map #0 (1)}}
+// expected-error @+1 {{expected operand #0 rank (0) to match the result rank of indexing_map (1)}}
   %0 = linalg.generic {
     indexing_maps = [
       affine_map<(d0) -> (d0)>,
@@ -186,7 +186,7 @@ func.func @generic_index_rank0(%arg0: tensor<f32>) -> tensor<f32> {
 // -----
 
 func.func @generic_index_domain_error(%arg0: tensor<4xf32>) -> tensor<4xf32> {
-// expected-error @+1 {{op expected operand rank (1) to match the result rank of indexing_map #1 (2)}}
+// expected-error @+1 {{expected operand #1 rank (1) to match the result rank of indexing_map (2)}}
   %0 = linalg.generic {
     indexing_maps = [
       affine_map<(d0) -> (d0)>,
@@ -348,7 +348,7 @@ func.func @generic(%arg0: memref<?x?xf32>) {
 // // -----
 
 func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?xf32>) {
-  // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #1 (3)}}
+  // expected-error @+1 {{expected operand #1 rank (2) to match the result rank of indexing_map (3)}}
   linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?xf32>)
                      outs(%c3 : memref<?x?x?xf32>)
   return
@@ -441,7 +441,7 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
 // -----
 
 func.func @invalid_scalar_input_matmul(%arg0: f32, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
-  // expected-error @+1 {{'linalg.matmul' op expected operand rank (0) to match the result rank of indexing_map #0 (2)}}
+  // expected-error @+1 {{'linalg.matmul' op expected operand #0 rank (0) to match the result rank of indexing_map (2)}}
   linalg.matmul ins(%arg0, %arg1 : f32, memref<3x4xf32>)
                 outs(%arg2 : memref<2x4xf32>)
   return
@@ -553,7 +553,7 @@ func.func @invalid_bcast_b(%arg0: memref<3x5xf32>, %arg1: memref<7xf32>, %arg2:
 // -----
 
 func.func @invalid_bcast_a_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
-  // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #0 (1)}}
+  // expected-error @+1 {{'linalg.matmul' op expected operand #0 rank (2) to match the result rank of indexing_map (1)}}
   linalg.matmul indexing_maps = [
                        affine_map<(d0, d1, d2) -> (d2)>,
                        affine_map<(d0, d1, d2) -> (d2, d1)>,
@@ -566,7 +566,7 @@ func.func @invalid_bcast_a_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5
 // -----
 
 func.func @invalid_bcast_b_rank_mismatch(%arg0: memref<3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
-  // expected-error @+1 {{'linalg.matmul' op expected operand rank (2) to match the result rank of indexing_map #1 (1)}}
+  // expected-error @+1 {{'linalg.matmul' op expected operand #1 rank (2) to match the result rank of indexing_map (1)}}
   linalg.matmul indexing_maps = [
                        affine_map<(d0, d1, d2) -> (d0, d2)>,
                        affine_map<(d0, d1, d2) -> (d2)>,
@@ -1214,7 +1214,7 @@ func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
 func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
                  %B: tensor<16x16x8x1xf32>,
                  %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
-    // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #2 (4)}}
+    // expected-error @+1 {{expected operand #2 rank (2) to match the result rank of indexing_map (4)}}
     %res = linalg.mmt4d
                      ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
                      outs(%C_in: tensor<8x8xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 552a0abaa797c..321a218cc1da2 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -9,7 +9,7 @@ func.func @add_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %
 // -----
 
 func.func @add_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.add ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
   return
 }
@@ -25,7 +25,7 @@ func.func @sub_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %
 // -----
 
 func.func @sub_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.sub ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
   return
 }
@@ -41,7 +41,7 @@ func.func @mul_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %
 // -----
 
 func.func @mul_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.mul ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
   return
 }
@@ -57,7 +57,7 @@ func.func @div_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %
 // -----
 
 func.func @div_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.div ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
   return
 }
@@ -73,7 +73,7 @@ func.func @divu_type_cast(%arg0: memref<4x8x16xi32>, %arg1: memref<4x8x16xi16>,
 // -----
 
 func.func @divu_broadcast(%arg0: memref<8x16xi32>, %arg1: memref<4x8x16xi32>, %arg2: memref<4x8x16xi32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.div_unsigned ins(%arg0, %arg1 : memref<8x16xi32>, memref<4x8x16xi32>) outs(%arg2: memref<4x8x16xi32>)
   return
 }
@@ -89,7 +89,7 @@ func.func @exp_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
 // -----
 
 func.func @exp_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.exp ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -105,7 +105,7 @@ func.func @log_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
 // -----
 
 func.func @log_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.log ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -121,7 +121,7 @@ func.func @abs_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
 // -----
 
 func.func @abs_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.abs ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -137,7 +137,7 @@ func.func @ceil_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
 // -----
 
 func.func @ceil_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.ceil ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -153,7 +153,7 @@ func.func @floor_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
 // -----
 
 func.func @floor_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.floor ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -169,7 +169,7 @@ func.func @negf_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
 // -----
 
 func.func @negf_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.negf ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -185,7 +185,7 @@ func.func @reciprocal_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf3
 // -----
 
 func.func @reciprocal_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.reciprocal ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -201,7 +201,7 @@ func.func @round_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
 // -----
 
 func.func @round_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.round ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -217,7 +217,7 @@ func.func @sqrt_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
 // -----
 
 func.func @sqrt_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.sqrt ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -233,7 +233,7 @@ func.func @rsqrt_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
 // -----
 
 func.func @rsqrt_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.rsqrt ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -249,7 +249,7 @@ func.func @square_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>)
 // -----
 
 func.func @square_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.square ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -265,7 +265,7 @@ func.func @tanh_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
 // -----
 
 func.func @tanh_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.tanh ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -281,7 +281,7 @@ func.func @erf_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
 // -----
 
 func.func @erf_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.erf ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
@@ -297,7 +297,7 @@ func.func @max_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %
 // -----
 
 func.func @max_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.max ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
   return
 }
@@ -313,7 +313,7 @@ func.func @min_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %
 // -----
 
 func.func @min_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.min ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
   return
 }
@@ -329,7 +329,7 @@ func.func @powf_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>,
 // -----
 
 func.func @powf_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
-  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  // CHECK: op expected operand #0 rank (2) to match the result rank of indexing_map (3)
   linalg.powf ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
   return
 }

>From f76b14a114d5e51adb9626b92780b520473310d3 Mon Sep 17 00:00:00 2001
From: Stefan Weigl-Bosker <stefan at s00.xyz>
Date: Sat, 3 Jan 2026 11:50:49 -0500
Subject: [PATCH 6/6] remove unnecessary check

---
 mlir/lib/Interfaces/IndexingMapOpInterface.cpp | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index 0c9e158da2708..2ef36a21a1ac0 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -56,11 +56,6 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
 
   SmallVector<int64_t> endLoopRangeValues = invertedMap.compose(allShapesSizes);
 
-  if (invertedMap.getNumResults() != endLoopRangeValues.size())
-    return getOperation()->emitOpError("expected each indexing_map to have ")
-           << endLoopRangeValues.size()
-           << " dim(s) to match the number of loops";
-
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t> startLoopRangeValues(endLoopRangeValues.size(), 0);
   // Verify only static cases since we can't get exact dimension sizes and



More information about the Mlir-commits mailing list