[Mlir-commits] [mlir] [mlir][Interface] Allow scalar operands and require ranked shaped operands in IndexingMapOpInterface (PR #179072)

Samarth Narang llvmlistbot at llvm.org
Tue Feb 17 06:13:41 PST 2026


https://github.com/snarang181 updated https://github.com/llvm/llvm-project/pull/179072

>From 603087db84f9bbf150320fa18f504ef8886aa771 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Wed, 4 Feb 2026 16:42:01 -0500
Subject: [PATCH 1/2] accept scalar operands, reject unranked shaped types

---
 .../lib/Interfaces/IndexingMapOpInterface.cpp | 49 +++++++++++++++++--
 mlir/test/Dialect/Linalg/invalid.mlir         | 18 +++++++
 2 files changed, 64 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index 2ef36a21a1ac0..c8ce69a92a3a8 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -14,6 +14,36 @@ namespace mlir {
 #include "mlir/Interfaces/IndexingMapOpInterface.cpp.inc"
 } // namespace mlir
 
+static LogicalResult verifyIndexingMapOperandType(Operation *op, Type t,
+                                                  unsigned operandNumber) {
+  // Non-shaped types are scalars (rank 0); verifyImpl checks the map rank.
+  if (!isa<ShapedType>(t))
+    return success();
+
+  // Vectors (also ShapedTypes) are always ranked, so accept them as-is.
+  if (isa<VectorType>(t))
+    return success();
+
+  // MemRefs: must be ranked.
+  if (isa<UnrankedMemRefType>(t))
+    return op->emitOpError("operand #")
+           << operandNumber << " must be a ranked memref, but got " << t;
+  if (isa<MemRefType>(t))
+    return success();
+
+  // Tensors: must be ranked.
+  if (isa<UnrankedTensorType>(t))
+    return op->emitOpError("operand #")
+           << operandNumber << " must be a ranked tensor, but got " << t;
+  if (isa<RankedTensorType>(t))
+    return success();
+
+  // Any other shaped type is not supported by this interface.
+  return op->emitOpError("operand #")
+         << operandNumber
+         << " must be ranked tensor/memref, vector, or scalar, but got " << t;
+}
+
 LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
   // All input/output operands must be indexed.
   if (static_cast<int64_t>(getIndexingMapsArray().size()) !=
@@ -26,14 +56,27 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
   SmallVector<int64_t> allShapesSizes;
 
   for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+    Type ty = opOperand.get().getType();
+    if (failed(verifyIndexingMapOperandType(getOperation(), ty,
+                                            opOperand.getOperandNumber())))
+      return failure();
     AffineMap indexingMap = getMatchingIndexingMap(&opOperand);
-    SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
-    int64_t rank = shape.size();
-
     // Symbols disallowed.
     if (indexingMap.getNumSymbols() != 0)
       return this->emitOpError("unexpected symbols in indexing_map #")
              << opOperand.getOperandNumber();
+    // Handle scalars.
+    if (!isa<ShapedType>(ty) && !isa<VectorType>(ty)) {
+      int64_t rank = 0;
+      if (indexingMap.getNumResults() != rank)
+        return this->emitOpError("expected operand #")
+               << opOperand.getOperandNumber() << " rank (" << rank
+               << ") to match the result rank of indexing_map ("
+               << indexingMap.getNumResults() << ")";
+      continue;
+    }
+    SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
+    int64_t rank = shape.size();
 
     // Result rank must match operand rank.
     if (indexingMap.getNumResults() != rank)
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 355d801f8732c..cc33205eb1486 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -321,6 +321,24 @@ func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->
 
 // -----
 
+// An unranked tensor operand has no rank to match its indexing map; reject it.
+func.func @generic_unranked_input_tensor(%in: tensor<*xf32>) {
+  %out = tensor.empty() : tensor<16x16xf32>
+  // expected-error @+1 {{'linalg.generic' op operand #0 must be a ranked tensor, but got 'tensor<*xf32>'}}
+  %r = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+      ins(%in : tensor<*xf32>)
+     outs(%out : tensor<16x16xf32>) {
+      ^bb0(%a: f32, %b: f32):
+        linalg.yield %a : f32
+     } -> tensor<16x16xf32>
+  return
+}
+
+// -----
+
 func.func @generic(%arg0: memref<?x?xf32>) {
   // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32}}
   linalg.generic  {

>From c603e598be809951d61a3b18751584a2eaaba1f1 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Tue, 17 Feb 2026 09:13:02 -0500
Subject: [PATCH 2/2] Cosmetic changes: brace multi-line if bodies

---
 mlir/lib/Interfaces/IndexingMapOpInterface.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index c8ce69a92a3a8..a2af9095f63f4 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -25,16 +25,18 @@ static LogicalResult verifyIndexingMapOperandType(Operation *op, Type t,
     return success();
 
   // MemRefs: must be ranked.
-  if (isa<UnrankedMemRefType>(t))
+  if (isa<UnrankedMemRefType>(t)) {
     return op->emitOpError("operand #")
            << operandNumber << " must be a ranked memref, but got " << t;
+  }
   if (isa<MemRefType>(t))
     return success();
 
   // Tensors: must be ranked.
-  if (isa<UnrankedTensorType>(t))
+  if (isa<UnrankedTensorType>(t)) {
     return op->emitOpError("operand #")
            << operandNumber << " must be a ranked tensor, but got " << t;
+  }
   if (isa<RankedTensorType>(t))
     return success();
 



More information about the Mlir-commits mailing list