[Mlir-commits] [mlir] [mlir][Interface] Allow scalar operands and require ranked shaped operands in IndexingMapOpInterface (PR #179072)
Samarth Narang
llvmlistbot at llvm.org
Tue Feb 17 06:13:41 PST 2026
https://github.com/snarang181 updated https://github.com/llvm/llvm-project/pull/179072
>From 603087db84f9bbf150320fa18f504ef8886aa771 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Wed, 4 Feb 2026 16:42:01 -0500
Subject: [PATCH 1/2] accept scalar operands, reject unranked shaped types
---
.../lib/Interfaces/IndexingMapOpInterface.cpp | 49 +++++++++++++++++--
mlir/test/Dialect/Linalg/invalid.mlir | 18 +++++++
2 files changed, 64 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index 2ef36a21a1ac0..c8ce69a92a3a8 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -14,6 +14,36 @@ namespace mlir {
#include "mlir/Interfaces/IndexingMapOpInterface.cpp.inc"
} // namespace mlir
+static LogicalResult verifyIndexingMapOperandType(Operation *op, Type t,
+ unsigned operandNumber) {
+ // Scalars are allowed (treated as rank-0). verifyImpl checks the rank.
+ if (!isa<ShapedType>(t) && !isa<VectorType>(t))
+ return success();
+
+ // Vectors are allowed.
+ if (isa<VectorType>(t))
+ return success();
+
+ // MemRefs: must be ranked.
+ if (isa<UnrankedMemRefType>(t))
+ return op->emitOpError("operand #")
+ << operandNumber << " must be a ranked memref, but got " << t;
+ if (isa<MemRefType>(t))
+ return success();
+
+ // Tensors: must be ranked.
+ if (isa<UnrankedTensorType>(t))
+ return op->emitOpError("operand #")
+ << operandNumber << " must be a ranked tensor, but got " << t;
+ if (isa<RankedTensorType>(t))
+ return success();
+
+ // Any other shaped type is not supported by this interface.
+ return op->emitOpError("operand #")
+ << operandNumber
+ << " must be ranked tensor/memref, vector, or scalar, but got " << t;
+}
+
LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
// All input/output operands must be indexed.
if (static_cast<int64_t>(getIndexingMapsArray().size()) !=
@@ -26,14 +56,27 @@ LogicalResult mlir::IndexingMapOpInterface::verifyImpl() {
SmallVector<int64_t> allShapesSizes;
for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+ Type ty = opOperand.get().getType();
+ if (failed(verifyIndexingMapOperandType(getOperation(), ty,
+ opOperand.getOperandNumber())))
+ return failure();
AffineMap indexingMap = getMatchingIndexingMap(&opOperand);
- SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
- int64_t rank = shape.size();
-
// Symbols disallowed.
if (indexingMap.getNumSymbols() != 0)
return this->emitOpError("unexpected symbols in indexing_map #")
<< opOperand.getOperandNumber();
+ // Handle scalars.
+ if (!isa<ShapedType>(ty) && !isa<VectorType>(ty)) {
+ int64_t rank = 0;
+ if (indexingMap.getNumResults() != rank)
+ return this->emitOpError("expected operand #")
+ << opOperand.getOperandNumber() << " rank (" << rank
+ << ") to match the result rank of indexing_map ("
+ << indexingMap.getNumResults() << ")";
+ continue;
+ }
+ SmallVector<int64_t> shape = getStaticOperandShape(&opOperand);
+ int64_t rank = shape.size();
// Result rank must match operand rank.
if (indexingMap.getNumResults() != rank)
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 355d801f8732c..cc33205eb1486 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -321,6 +321,24 @@ func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->
// -----
+// Unranked tensor inputs must be diagnosed.
+func.func @generic_unranked_input_tensor(%in: tensor<*xf32>) {
+ %out = tensor.empty() : tensor<16x16xf32>
+ // expected-error @+1 {{'linalg.generic' op operand #0 must be a ranked tensor, but got 'tensor<*xf32>'}}
+ %r = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%in : tensor<*xf32>)
+ outs(%out : tensor<16x16xf32>) {
+ ^bb0(%a: f32, %b: f32):
+ linalg.yield %a : f32
+ } -> tensor<16x16xf32>
+ return
+}
+
+// -----
+
func.func @generic(%arg0: memref<?x?xf32>) {
// expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32}}
linalg.generic {
>From c603e598be809951d61a3b18751584a2eaaba1f1 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Tue, 17 Feb 2026 09:13:02 -0500
Subject: [PATCH 2/2] Cosmetic changes
---
mlir/lib/Interfaces/IndexingMapOpInterface.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
index c8ce69a92a3a8..a2af9095f63f4 100644
--- a/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
+++ b/mlir/lib/Interfaces/IndexingMapOpInterface.cpp
@@ -25,16 +25,18 @@ static LogicalResult verifyIndexingMapOperandType(Operation *op, Type t,
return success();
// MemRefs: must be ranked.
- if (isa<UnrankedMemRefType>(t))
+ if (isa<UnrankedMemRefType>(t)) {
return op->emitOpError("operand #")
<< operandNumber << " must be a ranked memref, but got " << t;
+ }
if (isa<MemRefType>(t))
return success();
// Tensors: must be ranked.
- if (isa<UnrankedTensorType>(t))
+ if (isa<UnrankedTensorType>(t)) {
return op->emitOpError("operand #")
<< operandNumber << " must be a ranked tensor, but got " << t;
+ }
if (isa<RankedTensorType>(t))
return success();
More information about the Mlir-commits mailing list