[Mlir-commits] [mlir] [mlir][Interfaces] Add interface methods to allow reifying single result/single dim of result. (PR #162924)
llvmlistbot at llvm.org
Fri Oct 24 13:05:12 PDT 2025
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/162924
>From aa22ea5b8f02bbe2bf5b30f8c36b2e4eb96482ae Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Fri, 10 Oct 2025 14:12:59 -0700
Subject: [PATCH 1/2] [mlir][Interfaces] Add interface methods to allow
reifying single result/single dim of result.
The current implementation of `reifyResultShapes` forces every
implementation to return all dimensions of all results. This is
wasteful when only the dimensions of one result, or a single dimension
of one result, are needed. It also causes problems when patterns are
used to resolve `tensor.dim` and `memref.dim` operations: the extra
operations that get created make the pattern rewriter loop
indefinitely, and it only stops when it hits its iteration limit. Some
of the test cases added here demonstrate this by hitting that limit
when run with `--resolve-shaped-type-result-dims` and
`--resolve-ranked-shaped-type-result-dims`. To resolve this, the
interface should allow creating just the operations that are needed.
This change is the first step in that direction.
The original implementation was designed with the restriction in mind
that it might not always be possible to compute the dimensions of a
single result, or one dimension of a single result, in isolation. To
account for such cases, two additional interface methods are added:
- `reifyShapeOfResult` (which allows reifying the dimensions of
just one result) has a default implementation that calls
`reifyResultShapes` and returns the dimensions of a single result.
- `reifyDimOfResult` (which allows reifying a single dimension of a
single result) has a default implementation that calls
`reifyShapeOfResult` and returns the value for that dimension of the
result (which, in the default case, would in turn call
`reifyResultShapes`).
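The fallback chain between the three methods can be sketched without
any MLIR dependency. The block below is only an illustrative model, not
the interface itself: `Dim`, `Shape`, `ReifiableOp`, `AllShapesOp`,
`SingleDimOp`, `kShapes` and `resolveDim` are hypothetical stand-ins,
`std::optional` stands in for `FailureOr`, and an empty string stands in
for an empty `OpFoldResult()`. The `dimsComputed` counter is an
assumption added to make the cost difference visible.

```cpp
#include <optional>
#include <string>
#include <vector>

// Stand-ins for the MLIR types (illustrative assumptions, not the real API):
//   Dim              ~ OpFoldResult; an empty `expr` models OpFoldResult().
//   Shape            ~ SmallVector<OpFoldResult>.
//   std::optional<T> ~ FailureOr<T>; std::nullopt models failure().
struct Dim {
  std::string expr;
};
using Shape = std::vector<Dim>;
using Shapes = std::vector<Shape>;

struct ReifiableOp {
  int dimsComputed = 0; // Counts how many dimension values were materialized.
  virtual ~ReifiableOp() = default;

  // Default: the op does not know how to reify all result shapes.
  virtual std::optional<Shapes> reifyResultShapes() { return std::nullopt; }

  // Default: reify every result, then keep only the requested one.
  virtual std::optional<Shape> reifyShapeOfResult(int resultIndex) {
    std::optional<Shapes> all = reifyResultShapes();
    if (!all || resultIndex < 0 || resultIndex >= static_cast<int>(all->size()))
      return std::nullopt;
    return (*all)[resultIndex];
  }

  // Default: reify one result's shape, then keep only the requested dimension.
  virtual std::optional<Dim> reifyDimOfResult(int resultIndex, int dim) {
    std::optional<Shape> shape = reifyShapeOfResult(resultIndex);
    if (!shape || dim < 0 || dim >= static_cast<int>(shape->size()))
      return std::nullopt;
    return (*shape)[dim];
  }
};

// Results shaped like the test op: (tensor<?x5xf32>, tensor<2x3x?xf32>).
static const Shapes kShapes = {
    Shape{Dim{"dim(arg1, 0)"}, Dim{"5"}},
    Shape{Dim{"2"}, Dim{"3"}, Dim{"dim(arg0, 2)"}}};

// Implements only the coarse hook, so one query materializes all five dims.
struct AllShapesOp : ReifiableOp {
  std::optional<Shapes> reifyResultShapes() override {
    for (const Shape &shape : kShapes)
      dimsComputed += static_cast<int>(shape.size());
    return kShapes;
  }
};

// Implements only the fine-grained hook, so one query materializes one dim.
struct SingleDimOp : ReifiableOp {
  std::optional<Dim> reifyDimOfResult(int resultIndex, int dim) override {
    if (resultIndex < 0 || resultIndex >= static_cast<int>(kShapes.size()))
      return std::nullopt;
    const Shape &shape = kShapes[resultIndex];
    if (dim < 0 || dim >= static_cast<int>(shape.size()))
      return std::nullopt;
    ++dimsComputed;
    return shape[dim];
  }
};

// Caller side, loosely mirroring a dim-resolution pattern: give up on failure
// or on an empty (unreifiable) dimension.
std::optional<std::string> resolveDim(ReifiableOp &op, int resultIndex,
                                      int dim) {
  std::optional<Dim> replacement = op.reifyDimOfResult(resultIndex, dim);
  if (!replacement || replacement->expr.empty())
    return std::nullopt;
  return replacement->expr;
}
```

In this model, resolving one dimension through `AllShapesOp` walks the
whole chain and materializes all five dimension values, while
`SingleDimOp` materializes only the one that was asked for. That extra
work is the kind of overhead the new methods are meant to let
implementations avoid.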
While this change sets up the interface, ideally most operations will
implement `reifyDimOfResult` when possible. This is possible for almost
all in-tree operations. Subsequent commits will convert them
incrementally.
Some of the tests added here, which check that the default
implementations of the above methods work as expected, also end up
hitting the pattern rewriter limit when using
`--resolve-shaped-type-result-dims`/
`--resolve-ranked-shaped-type-result-dims`. For testing purposes, an
`error-on-pattern-iteration-limit` option is added to these passes.
Setting it to false ignores the error returned by the pattern
application (the option defaults to true to maintain the current
behavior).
Changes required downstream to integrate this change:
1. In .td files, operation definitions that implement the
`ReifyRankedShapedTypeOpInterface` as
```
def <op-name> : Op<..., [...,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>
```
should be changed to
```
def <op-name> : Op<..., [...,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
"reifyResultShapes"]>]>
```
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
.../Bufferization/IR/BufferizationOps.td | 6 +-
.../mlir/Dialect/Linalg/IR/LinalgOps.td | 3 +-
.../Dialect/Linalg/IR/LinalgRelayoutOps.td | 3 +-
.../mlir/Dialect/MemRef/IR/MemRefOps.td | 3 +-
.../mlir/Dialect/MemRef/Transforms/Passes.td | 10 ++
.../mlir/Dialect/Tensor/IR/TensorOps.td | 22 ++-
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 3 +-
.../mlir/Interfaces/InferTypeOpInterface.h | 4 +
.../mlir/Interfaces/InferTypeOpInterface.td | 64 ++++++++-
.../ResolveShapedTypeResultDims.cpp | 33 +++--
.../IR/TensorInferTypeOpInterfaceImpl.cpp | 3 +
mlir/lib/Interfaces/InferTypeOpInterface.cpp | 16 +++
.../resolve-shaped-type-result-dims.mlir | 127 +++++++++++++++++-
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 101 +++++++++++++-
mlir/test/lib/Dialect/Test/TestOps.h | 1 +
mlir/test/lib/Dialect/Test/TestOps.td | 90 ++++++++++++-
16 files changed, 451 insertions(+), 38 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 6724d4c483101..a9b2b9f39519d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -28,7 +28,8 @@ class Bufferization_Op<string mnemonic, list<Trait> traits = []>
def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
[AttrSizedOperandSegments, BufferizableOpInterface,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>]> {
let summary = "allocate buffer for a tensor";
let description = [{
@@ -219,7 +220,8 @@ def Bufferization_MaterializeInDestinationOp
: Bufferization_Op<"materialize_in_destination",
[AllElementTypesMatch<["source", "dest"]>,
BufferizableOpInterface, DestinationStyleOpInterface,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
DeclareOpInterfaceMethods<SubsetOpInterface,
["operatesOnEquivalentSubset", "operatesOnDisjointSubset"]>,
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 7ff44c2e1d2ed..2754ee3b4f586 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -94,7 +94,8 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
def Linalg_SoftmaxOp : Linalg_Op<"softmax",
[DestinationStyleOpInterface,
PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>,
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 6504ca8664d49..238fa42cae427 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -35,7 +35,8 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
ConditionallySpeculatable, NoMemoryEffect,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
TypesMatchWith<"result type matches type of dest",
"dest", "result",
"$_self">])> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index b39207fc30dd7..9d44d05b9fc86 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1778,7 +1778,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>]> {
let summary = "operation to produce a memref with a higher rank.";
let description = [{
The `memref.expand_shape` op produces a new view with a higher rank whose
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index f3e40aaa29075..c403386bd214a 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -164,6 +164,11 @@ def ResolveRankedShapeTypeResultDimsPass
implement the `ReifyRankedShapedTypeOpInterface` in terms of
shapes of its operands.
}];
+ let options = [
+ Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+ /*default=*/"true",
+ "Throw an error when pattern rewriter hits iteration limit">,
+ ];
let dependentDialects = [
"memref::MemRefDialect", "tensor::TensorDialect"
];
@@ -177,6 +182,11 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
`ReifyRankedShapedTypeOpInterface` in terms of shapes of its
operands.
}];
+ let options = [
+ Option<"errorOnPatternIterationLimit", "error-on-pattern-iteration-limit", "bool",
+ /*default=*/"true",
+ "Throw an error when pattern rewriter hits iteration limit">,
+ ];
let dependentDialects = [
"affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..3e93e58575e65 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -131,7 +131,9 @@ def Tensor_CastOp : Tensor_Op<"cast", [
def Tensor_ConcatOp : Tensor_Op<"concat",
[Pure,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
+ ]> {
let summary = "tensor concatenation operation";
let description = [{
The "concat" operation constructs a tensor out of a variadic list of input
@@ -261,7 +263,8 @@ def Tensor_DimOp : Tensor_Op<"dim", [
def Tensor_EmptyOp : Tensor_Op<"empty",
[Pure,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>]> {
let summary = "empty tensor operation";
let description = [{
@@ -358,7 +361,8 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
AttrSizedOperandSegments,
Pure,
OffsetSizeAndStrideOpInterface
@@ -740,7 +744,8 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
def Tensor_GenerateOp : Tensor_Op<"generate", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
RecursiveMemoryEffects,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
let summary = "Creates a dynamically sized tensor from elements";
let description = [{
@@ -835,7 +840,8 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
AttrSizedOperandSegments,
DestinationStyleOpInterface,
Pure,
@@ -1256,7 +1262,8 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
def Tensor_PadOp : Tensor_Op<"pad", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, [
+ "reifyResultShapes"]>,
AttrSizedOperandSegments,
Pure,
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
@@ -1764,7 +1771,8 @@ def Tensor_ScatterOp : Tensor_Op<"scatter", [
def Tensor_SplatOp : Tensor_Op<"splat", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>,
Pure,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6e1759119a621..a5c28dffc632d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2218,7 +2218,8 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
// Operator: transpose
//===----------------------------------------------------------------------===//
def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
- [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface ,
+ ["reifyResultShapes"]>,
AllElementTypesMatch<["input1", "output"]>]> {
let summary = "Transpose operator.";
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
index 4fcbeff9df560..1bfb66e681d8d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h
@@ -33,6 +33,10 @@ using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<OpFoldResult>>;
LogicalResult
reifyResultShapes(OpBuilder &b, Operation *op,
ReifiedRankedShapedTypeDims &reifiedReturnShapes);
+FailureOr<SmallVector<OpFoldResult>>
+reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex);
+FailureOr<OpFoldResult> reifyDimOfResult(OpBuilder &b, Operation *op,
+ int resultIndex, int dim);
/// Adaptor class to abstract the differences between whether value is from
/// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 1a2c05fc16ed5..c949656325b2d 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -361,20 +361,76 @@ def ReifyRankedShapedTypeOpInterface :
let methods = [
InterfaceMethod<
/*desc=*/[{
- Reify the shape of the result of an operation (typically in terms of the
- shape of its operands).
+ Reify the shapes of all the results of an operation (typically in terms
+ of the shape of its operands).
`reifiedReturnShapes` is populated with one vector per op result. Each
of those vectors contains an OpFoldResult for each dimension of the
shaped type. The given builder may be used to insert ops that compute
result shapes.
- If the shape of a particular result cannot be computed it must be empty.
+ If the shape of a particular result cannot be computed in terms of
+ its operands it must be left empty. If any dimension of the result cannot
+ be computed it must be set to OpFoldResult().
}],
/*retTy=*/"::llvm::LogicalResult",
/*methodName=*/"reifyResultShapes",
/*args=*/(ins "::mlir::OpBuilder &":$builder,
- "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes)
+ "::mlir::ReifiedRankedShapedTypeDims &":$reifiedReturnShapes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{ return ::mlir::failure(); }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Reify the shape of a single result of an operation (typically in terms
+ of the shape of its operands).
+
+ Returns the shape of a single result of the operation as a
+ `SmallVector<OpFoldResult>`, one per dimension of the shaped type. The
+ given builder may be used to insert ops that compute result shapes.
+
+ If any dimension of the result cannot be computed it must be set to
+ OpFoldResult().
+ }],
+ /*retTy=*/"::llvm::FailureOr<::llvm::SmallVector<::mlir::OpFoldResult>>",
+ /*methodName=*/"reifyShapeOfResult",
+ /*args=*/(ins "::mlir::OpBuilder &":$builder,
+ "int":$resultIndex),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ ReifiedRankedShapedTypeDims reifiedShapes;
+ if (failed(cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyResultShapes(builder, reifiedShapes)))
+ return failure();
+ if (resultIndex < 0 || resultIndex >= (int)(reifiedShapes.size()))
+ return $_op.emitOpError("invalid result index");
+ return reifiedShapes[resultIndex];
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Reify the shape of a dimension of a given result of an operation
+ (typically in terms of the shape of its operands).
+
+ Returns the shape of a specific dimension of a result of the operation as
+ an OpFoldResult. The given builder may be used to insert ops that compute
+ the shapes.
+
+ If the dimension of the result cannot be computed the method must return
+ `failure()`.
+ }],
+ /*retTy=*/"::llvm::FailureOr<::mlir::OpFoldResult>",
+ /*methodName=*/"reifyDimOfResult",
+ /*args=*/(ins "::mlir::OpBuilder &":$builder,
+ "int":$resultIndex, "int":$dim),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto shapes = cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyShapeOfResult(builder, resultIndex);
+ if (failed(shapes))
+ return failure();
+ if (dim < 0 || dim >= (int)((*shapes).size()))
+ return $_op.emitOpError("invalid dimension");
+ return (*shapes)[dim];
+ }]
>
];
}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 6a81a15f30e47..c498c8a60bf6e 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -90,17 +90,16 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
if (!dimIndex)
return failure();
- ReifiedRankedShapedTypeDims reifiedResultShapes;
- if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
- reifiedResultShapes)))
+ FailureOr<OpFoldResult> replacement = reifyDimOfResult(
+ rewriter, dimValue.getOwner(), dimValue.getResultNumber(), *dimIndex);
+ if (failed(replacement))
return failure();
- unsigned resultNumber = dimValue.getResultNumber();
- // Do not apply pattern if the IR is invalid (dim out of bounds).
- if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
- return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
- Value replacement = getValueOrCreateConstantIndexOp(
- rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
- rewriter.replaceOp(dimOp, replacement);
+ // Check if the OpFoldResult is empty (unreifiable dimension).
+ if (!replacement.value())
+ return failure();
+ Value replacementVal = getValueOrCreateConstantIndexOp(
+ rewriter, dimOp.getLoc(), replacement.value());
+ rewriter.replaceOp(dimOp, replacementVal);
return success();
}
};
@@ -166,12 +165,14 @@ namespace {
struct ResolveRankedShapeTypeResultDimsPass final
: public memref::impl::ResolveRankedShapeTypeResultDimsPassBase<
ResolveRankedShapeTypeResultDimsPass> {
+ using Base::Base;
void runOnOperation() override;
};
struct ResolveShapedTypeResultDimsPass final
: public memref::impl::ResolveShapedTypeResultDimsPassBase<
ResolveShapedTypeResultDimsPass> {
+ using Base::Base;
void runOnOperation() override;
};
@@ -195,14 +196,22 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+ if (errorOnPatternIterationLimit && failed(result)) {
+ getOperation()->emitOpError(
+ "dim operation resolution hit pattern iteration limit");
return signalPassFailure();
+ }
}
void ResolveShapedTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
- if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ auto result = applyPatternsGreedily(getOperation(), std::move(patterns));
+ if (errorOnPatternIterationLimit && failed(result)) {
+ getOperation()->emitOpError(
+ "dim operation resolution hit pattern iteration limit");
return signalPassFailure();
+ }
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 4ec13e189f621..686f6eed1f8c7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -77,6 +77,9 @@ namespace {
struct ReifyExpandShapeOp
: public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
ExpandShapeOp> {
+ using Base =
+ ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+ ExpandShapeOp>;
LogicalResult
reifyResultShapes(Operation *op, OpBuilder &b,
ReifiedRankedShapedTypeDims &reifyResultShapes) const {
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 9f4f672fb9f4d..c31e0ae7470e2 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -58,6 +58,22 @@ mlir::reifyResultShapes(OpBuilder &b, Operation *op,
return status;
}
+FailureOr<SmallVector<OpFoldResult>>
+mlir::reifyShapeOfResult(OpBuilder &b, Operation *op, int resultIndex) {
+ auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+ if (!reifiableOp)
+ return failure();
+ return reifiableOp.reifyShapeOfResult(b, resultIndex);
+}
+
+FailureOr<OpFoldResult> mlir::reifyDimOfResult(OpBuilder &b, Operation *op,
+ int resultIndex, int dim) {
+ auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+ if (!reifiableOp)
+ return failure();
+ return reifiableOp.reifyDimOfResult(b, resultIndex, dim);
+}
+
bool ShapeAdaptor::hasRank() const {
if (val.isNull())
return false;
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index 4fa7406f21042..ee9991cf78b45 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s
func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-> (index, index, index, index, index) {
@@ -27,12 +27,14 @@ func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
// -----
-func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+// Test result shape reification for an operation that implements only
+// `reifyResultShapes` method of the `InferShapedTypeOpInterface`.
+func.func @reify_shaped_type_using_reify_result_shapes(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-> (index, index, index, index, index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
- %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
+ %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1)
: (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
%1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
%2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
@@ -41,7 +43,7 @@ func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf3
%5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
return %1, %2, %3, %4, %5 : index, index, index, index, index
}
-// CHECK-LABEL: func @result_shape_per_dim(
+// CHECK-LABEL: func @reify_shaped_type_using_reify_result_shapes(
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
@@ -51,3 +53,120 @@ func.func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf3
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
+
+// -----
+
+// Test result shape reification for an operation that implements only
+// `reifyShapeOfResult` method of the `InferShapedTypeOpInterface`.
+func.func @reify_shaped_type_using_reify_shape_of_result(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+ -> (index, index, index, index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1)
+ : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+ %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
+ %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
+ %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32>
+ %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32>
+ %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
+ return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @reify_shaped_type_using_reify_shape_of_result(
+// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
+// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
+
+// -----
+
+// Test result shape reification for an operation that implements only
+// `reifyDimOfResult` method of the `InferShapedTypeOpInterface`.
+func.func @reify_shaped_type_using_reify_dim_of_result(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+ -> (index, index, index, index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %0:2 = "test.reify_shaped_type_using_reify_result_shapes"(%arg0, %arg1)
+ : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+ %1 = tensor.dim %0#0, %c0 : tensor<?x5xf32>
+ %2 = tensor.dim %0#0, %c1 : tensor<?x5xf32>
+ %3 = tensor.dim %0#1, %c0 : tensor<2x3x?xf32>
+ %4 = tensor.dim %0#1, %c1 : tensor<2x3x?xf32>
+ %5 = tensor.dim %0#1, %c2 : tensor<2x3x?xf32>
+ return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
+// CHECK-LABEL: func @reify_shaped_type_using_reify_dim_of_result(
+// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
+// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG_1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG_0]], %[[C2]]
+// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
+
+// -----
+
+func.func @test_unreifiable_result_shapes(%arg0 : tensor<?x?xf32>)
+ -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = "test.unreifiable_result_shapes"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+ %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ return %d0, %d1 : index, index
+}
+// CHECK-LABEL: func @test_unreifiable_result_shapes(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_result_shapes"(%[[ARG0]])
+// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]]
+// CHECK: return %[[D0]], %[[D1]]
+// -----
+
+func.func @test_unreifiable_result_shape(%arg0 : tensor<?x?xf32>)
+ -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = "test.unreifiable_result_shape"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+ %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ return %d0, %d1 : index, index
+}
+// CHECK-LABEL: func @test_unreifiable_result_shape(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_result_shape"(%[[ARG0]])
+// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]]
+// CHECK: return %[[D0]], %[[D1]]
+
+// -----
+
+func.func @test_unreifiable_dim_of_result_shape(%arg0 : tensor<?x?xf32>)
+ -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = "test.unreifiable_dim_of_result_shape"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+ %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ return %d0, %d1 : index, index
+}
+// CHECK-LABEL: func @test_unreifiable_dim_of_result_shape(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[OP:.+]] = "test.unreifiable_dim_of_result_shape"(%[[ARG0]])
+// CHECK: %[[D1:.+]] = tensor.dim %[[OP]], %[[C1]]
+// CHECK: return %[[D0]], %[[D1]]
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index b211e243f234c..c7e87d3b8fe36 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -320,10 +320,10 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
}
//===----------------------------------------------------------------------===//
-// OpWithResultShapePerDimInterfaceOp
+// ReifyShapedTypeUsingReifyResultShapesOp
//===----------------------------------------------------------------------===//
-LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
+LogicalResult ReifyShapedTypeUsingReifyResultShapesOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
Location loc = getLoc();
shapes.reserve(getNumOperands());
@@ -344,6 +344,103 @@ LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
return success();
}
+//===----------------------------------------------------------------------===//
+// ReifyShapedTypeUsingReifyShapeOfResultOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReifyShapedTypeUsingReifyShapeOfResultOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+ReifyShapedTypeUsingReifyShapeOfResultOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ Location loc = getLoc();
+ Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex);
+ SmallVector<OpFoldResult> shape =
+ tensor::getMixedSizes(builder, loc, sourceOperand);
+ return shape;
+}
+
+//===----------------------------------------------------------------------===//
+// ReifyShapedTypeUsingReifyDimOfResultOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReifyShapedTypeUsingReifyDimOfResultOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+ReifyShapedTypeUsingReifyDimOfResultOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ return failure();
+}
+
+FailureOr<OpFoldResult>
+ReifyShapedTypeUsingReifyDimOfResultOp::reifyDimOfResult(OpBuilder &builder,
+ int resultIndex,
+ int dim) {
+ Location loc = getLoc();
+ Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex);
+ OpFoldResult shape = tensor::getMixedSize(builder, loc, sourceOperand, dim);
+ return shape;
+}
+
+//===----------------------------------------------------------------------===//
+// UnreifiableResultShapesOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreifiableResultShapesOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ Location loc = getLoc();
+ shapes.resize(1);
+ shapes[0] = {tensor::getMixedSize(builder, loc, getOperand(), 0),
+ OpFoldResult()};
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// UnreifiableResultShapeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreifiableResultShapeOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+UnreifiableResultShapeOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ SmallVector<OpFoldResult> shape = {
+ tensor::getMixedSize(builder, getLoc(), getOperand(), 0), OpFoldResult()};
+ return shape;
+}
+
+//===----------------------------------------------------------------------===//
+// UnreifiableDimOfResultShapeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreifiableDimOfResultShapeOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+UnreifiableDimOfResultShapeOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ return failure();
+}
+
+FailureOr<OpFoldResult>
+UnreifiableDimOfResultShapeOp::reifyDimOfResult(OpBuilder &builder,
+ int resultIndex, int dim) {
+ if (dim == 0)
+ return tensor::getMixedSize(builder, getLoc(), getOperand(), 0);
+ return failure();
+}
+
//===----------------------------------------------------------------------===//
// SideEffectOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index 4201ade9795e7..679274346fb13 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -42,6 +42,7 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
namespace test {
class TestDialect;
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 05a33cf1afd94..9a5fc7bc717da 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -914,13 +914,97 @@ def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
-def OpWithResultShapePerDimInterfaceOp :
- TEST_Op<"op_with_result_shape_per_dim_interface",
- [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+def ReifyShapedTypeUsingReifyResultShapesOp :
+ TEST_Op<"reify_shaped_type_using_reify_result_shapes",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>]> {
+ let description = [{
+ Test that when resolving a single dimension of a result for an operation
+ that implements neither `reifyShapeOfResult` nor `reifyDimOfResult`, the resolution
+ calls into the implementation of `reifyResultShapes` to get the required value.
+ The op semantics is that the first result has the same shape as the second operand
+ and the second result has the same shape as the first operand.
+ }];
+ let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
+ let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def ReifyShapedTypeUsingReifyShapeOfResultOp :
+ TEST_Op<"reify_shaped_type_using_reify_shape_of_result",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult"]>]> {
+ let description = [{
+ Test that when resolving a single dimension of a result for an operation
+ that doesn't implement `reifyDimOfResult` but implements `reifyShapeOfResult`, the
+ latter is used to get the required value. `reifyResultShapes` is implemented as a failure
+ (which is also the default implementation) to ensure it is not called.
+ The op semantics is that the first result has the same shape as the second operand
+ and the second result has the same shape as the first operand.
+ }];
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
+def ReifyShapedTypeUsingReifyDimOfResultOp :
+ TEST_Op<"reify_shaped_type_using_reify_dim_of_result",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult", "reifyDimOfResult"]>]> {
+ let description = [{
+ Test that when resolving a single dimension of a result for an operation
+ that implements `reifyDimOfResult`, that method is used to get the required value.
+ `reifyResultShapes` and `reifyShapeOfResult` are implemented as failures
+ to ensure they are not called. The op semantics is that the first result has
+ the same shape as the second operand and the second result has the same shape
+ as the first operand.
+ }];
+ let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
+ let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def UnreifiableResultShapesOp : TEST_Op<"unreifiable_result_shapes",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>]> {
+ let description = [{
+ Test handling of case where some dimension of the result cannot be
+ reified. This tests the path when `reifyResultShapes` is implemented.
+
+ Expected that dim 0 of `result` is reifiable as dim 0 of `operand`, but
+ dim 1 of `result` is not reifiable.
+ }];
+ let arguments = (ins 2DTensorOf<[AnyType]>:$operand);
+ let results = (outs 2DTensorOf<[AnyType]>:$result);
+}
+
+def UnreifiableResultShapeOp : TEST_Op<"unreifiable_result_shape",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult"]>]> {
+ let description = [{
+ Test handling of case where some dimension of the result cannot be
+ reified. This tests the path when `reifyShapeOfResult` is implemented but
+ `reifyDimOfResult` is not, with `reifyResultShapes` implemented as a failure.
+
+ Expected that dim 0 of `result` is reifiable as dim 0 of `operand`, but
+ dim 1 of `result` is not reifiable.
+ }];
+ let arguments = (ins 2DTensorOf<[AnyType]>:$operand);
+ let results = (outs 2DTensorOf<[AnyType]>:$result);
+}
+
+def UnreifiableDimOfResultShapeOp : TEST_Op<"unreifiable_dim_of_result_shape",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult", "reifyDimOfResult"]>]> {
+ let description = [{
+ Test handling of case where some dimension of the result cannot be
+ reified. This tests the path when `reifyDimOfResult` is implemented,
+ and `reifyShapeOfResult` and `reifyResultShapes` are implemented as failures.
+
+ Expected that dim 0 of `result` is reifiable as dim 0 of `operand`, but
+ dim 1 of `result` is not reifiable.
+ }];
+ let arguments = (ins 2DTensorOf<[AnyType]>:$operand);
+ let results = (outs 2DTensorOf<[AnyType]>:$result);
+}
+
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
>From 66148fe3844ce6fac43979dd4895372ec138a251 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Fri, 24 Oct 2025 12:56:43 -0700
Subject: [PATCH 2/2] Address comments.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
mlir/include/mlir/Interfaces/InferTypeOpInterface.td | 4 ++--
.../resolve-shaped-type-result-dims.mlir | 8 ++++++++
2 files changed, 10 insertions(+), 2 deletions(-)
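A note on the fallback chain touched by the two `InferTypeOpInterface.td` hunks in this patch: the default `reifyDimOfResult` calls `reifyShapeOfResult`, whose default in turn calls `reifyResultShapes`, with a bounds check at each step. The following is a minimal standalone C++ sketch of that chain, not MLIR code. `ToyReifyInterface`, `Dim`, `Shape`, `OnlyResultShapes`, `OnlyDimOfResult` and `selfCheck` are hypothetical names introduced for illustration; a dimension is modeled as `std::optional<int64_t>`, and both failure and a null `OpFoldResult` are collapsed into `std::nullopt` (the real interface keeps them distinct and emits an op error on a bad index).

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Toy stand-ins, NOT MLIR types. std::nullopt means "could not reify".
using Dim = std::optional<int64_t>;
using Shape = std::vector<Dim>;

struct ToyReifyInterface {
  virtual ~ToyReifyInterface() = default;

  // Coarsest entry point: all dims of all results. Default: failure.
  virtual std::optional<std::vector<Shape>> reifyResultShapes() {
    return std::nullopt;
  }

  // Default: reify everything, then select one result.
  virtual std::optional<Shape> reifyShapeOfResult(int resultIndex) {
    auto shapes = reifyResultShapes();
    if (!shapes)
      return std::nullopt;
    if (resultIndex < 0 || resultIndex >= static_cast<int>(shapes->size()))
      return std::nullopt; // the real default emits "invalid result index"
    return (*shapes)[resultIndex];
  }

  // Default: reify one result's shape, then select one dimension.
  virtual Dim reifyDimOfResult(int resultIndex, int dim) {
    auto shape = reifyShapeOfResult(resultIndex);
    if (!shape)
      return std::nullopt;
    if (dim < 0 || dim >= static_cast<int>(shape->size()))
      return std::nullopt; // the real default emits "invalid dimension"
    return (*shape)[dim];
  }
};

// Overrides only the coarsest method; finer queries reach it via defaults.
struct OnlyResultShapes : ToyReifyInterface {
  int calls = 0; // counts how often the coarse method is reached
  std::optional<std::vector<Shape>> reifyResultShapes() override {
    ++calls;
    return std::vector<Shape>{{Dim{4}, Dim{5}}, {Dim{2}, Dim{3}, Dim{7}}};
  }
};

// Overrides only the finest method; the coarser defaults still fail.
struct OnlyDimOfResult : ToyReifyInterface {
  Dim reifyDimOfResult(int resultIndex, int dim) override {
    if (resultIndex == 0 && dim == 0)
      return 42;
    return std::nullopt; // every other dimension is unreifiable here
  }
};

// Usage example exercising both paths.
inline void selfCheck() {
  OnlyResultShapes a;
  assert(a.reifyDimOfResult(1, 2) == Dim{7}); // via both defaults
  assert(a.calls == 1);
  OnlyDimOfResult b;
  assert(b.reifyDimOfResult(0, 0) == Dim{42}); // direct override
  assert(!b.reifyShapeOfResult(0).has_value()); // coarse default fails
}
```

In this model an op that overrides only the finest method never touches the coarser ones, which mirrors what the `ReifyShapedTypeUsingReifyDimOfResultOp` test op checks by making the other two methods return failure.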
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index c949656325b2d..67568f731f597 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -401,7 +401,7 @@ def ReifyRankedShapedTypeOpInterface :
ReifiedRankedShapedTypeDims reifiedShapes;
if (failed(cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyResultShapes(builder, reifiedShapes)))
return failure();
- if (resultIndex < 0 || resultIndex >= (int)(reifiedShapes.size()))
+ if (resultIndex < 0 || resultIndex >= static_cast<int>(reifiedShapes.size()))
return $_op.emitOpError("invalid result index");
return reifiedShapes[resultIndex];
}]
@@ -427,7 +427,7 @@ def ReifyRankedShapedTypeOpInterface :
auto shapes = cast<ReifyRankedShapedTypeOpInterface>($_op.getOperation()).reifyShapeOfResult(builder, resultIndex);
if (failed(shapes))
return failure();
- if (dim < 0 || dim >= (int)((*shapes).size()))
+ if (dim < 0 || dim >= static_cast<int>((*shapes).size()))
return $_op.emitOpError("invalid dimension");
return (*shapes)[dim];
}]
diff --git a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
index ee9991cf78b45..624e0990a4bb3 100644
--- a/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Interfaces/InferShapedTypeOpInterface/resolve-shaped-type-result-dims.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(resolve-shaped-type-result-dims{error-on-pattern-iteration-limit=false}))" -split-input-file | FileCheck %s
+// See @test_unreifiable_result_shapes below for why `error-on-pattern-iteration-limit` is set to false.
func.func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-> (index, index, index, index, index) {
@@ -114,6 +115,13 @@ func.func @reify_shaped_type_using_reify_dim_of_result(%arg0 : tensor<2x3x?xf32>
// -----
+// This test also indicates a problem with the approach of just using `reifyResultShapes`
+// without being specific about the {result, dim} that needs to be resolved. The
+// `reifyResultShapes` implementation introduces `dim` operations that are effectively
+// dead, but they send pattern application into an infinite loop (which eventually
+// bails out on hitting the iteration limit). This is the pitfall of this legacy
+// mechanism.
+
func.func @test_unreifiable_result_shapes(%arg0 : tensor<?x?xf32>)
-> (index, index) {
%c0 = arith.constant 0 : index