[Mlir-commits] [mlir] [MLIR][Interface]: Verify index map ranks before composing loop bounds (PR #173434)

Stefan Weigl-Bosker llvmlistbot at llvm.org
Tue Dec 23 16:20:17 PST 2025


https://github.com/sweiglbosker updated https://github.com/llvm/llvm-project/pull/173434

From b233eceea45966b61e53fefc5df07fe2516db20e Mon Sep 17 00:00:00 2001
From: Stefan Weigl-Bosker <stefan at s00.xyz>
Date: Tue, 23 Dec 2025 19:04:52 -0500
Subject: [PATCH] [MLIR][Interface]: Verify index map ranks before composing
 loop bounds

---
 .../lib/Interfaces/IndexingMapOpInterface.cpp | 26 +++++++-------
 mlir/test/Dialect/Linalg/invalid.mlir         | 35 +++++++++++++++++++
 2 files changed, 48 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index f3c12aed8df84..463276011acdb 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -32,35 +32,35 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
            << "(" << str << ")";
   }
 
-  SmallVector<int64_t> endLoopRangeValues = getStaticLoopRanges();
+  SmallVector<int64_t> allShapesSizes;
 
-  // Set this flag if this op has user defined maps. This is required to guard
-  // the below error condition which assume default indexing maps.
   for (OpOperand &opOperand : getOperation()->getOpOperands()) {
     AffineMap indexingMap = getMatchingIndexingMap(&opOperand);
+    SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
+    int64_t rank = shape.size();
 
     // Symbols disallowed.
     if (indexingMap.getNumSymbols() != 0)
       return getOperation()->emitOpError("unexpected symbols in indexing_map #")
              << opOperand.getOperandNumber();
 
-    // Domain must be consistent.
-    if (indexingMap.getNumDims() != endLoopRangeValues.size())
-      return getOperation()->emitOpError("expected indexing_map #")
-             << opOperand.getOperandNumber() << " to have "
-             << endLoopRangeValues.size()
-             << " dim(s) to match the number of loops";
-
-    SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
-    int64_t rank = shape.size();
-
+    // Result rank must match operand rank.
     if (indexingMap.getNumResults() != rank)
       return getOperation()->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
              << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
+
+    llvm::append_range(allShapesSizes, shape);
   }
 
+  SmallVector<int64_t> endLoopRangeValues = invertedMap.compose(allShapesSizes);
+
+  if (invertedMap.getNumResults() != endLoopRangeValues.size())
+    return getOperation()->emitOpError("expected each indexing_map to have ")
+           << endLoopRangeValues.size()
+           << " dim(s) to match the number of loops";
+
   // Check if given shapes match to inferred shapes.
   SmallVector<int64_t> startLoopRangeValues(endLoopRangeValues.size(), 0);
   // Verify only static cases since we can't get exact dimension sizes and
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 1f554e6c45da7..af9112b7c1f74 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -165,6 +165,41 @@ func.func @generic_singular_maps(%arg0: memref<?xf32, affine_map<(i)[off]->(off
   }
 }
 
+// -----
+
+func.func @generic_index_rank0(%arg0: tensor<f32>) -> tensor<f32> { // Negative test: rank-0 operand used with 1-result indexing maps.
+// expected-error @+1 {{op expected operand rank (0) to match the result rank of indexing_map #0 (1)}}
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0) -> (d0)>, // map #0: one result, but %arg0 (tensor<f32>) has rank 0
+      affine_map<(d0) -> (d0)>
+    ],
+    iterator_types = ["parallel"]}
+      ins(%arg0 : tensor<f32>)
+     outs(%arg0 : tensor<f32>) {
+  ^bb(%0: f32, %1: f32):
+    linalg.yield %1 : f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @generic_index_domain_error(%arg0: tensor<4xf32>) -> tensor<4xf32> { // Negative test: the two indexing maps disagree on dim count (1 vs 2).
+// expected-error @+1 {{op expected operand rank (1) to match the result rank of indexing_map #1 (2)}}
+  %0 = linalg.generic {
+    indexing_maps = [
+      affine_map<(d0) -> (d0)>,
+      affine_map<(d0, d1) -> (d0, d1)>], // map #1: two results, but %arg0 (tensor<4xf32>) has rank 1
+    iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<4xf32>)
+     outs(%arg0 : tensor<4xf32>) {
+  ^bb(%0: f32):
+    linalg.yield %0 : f32
+  } -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 ///////////////////////////// Region tests /////////////////////////////////////
 ////////////////////////////////////////////////////////////////////////////////



More information about the Mlir-commits mailing list