[Mlir-commits] [mlir] [mlir][memref] Add a new `ReifyResultShapes` pass (PR #145927)
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Jul 1 06:31:37 PDT 2025
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/145927
>From 465c6606d2f30eb9a8654e07b9da7d6d64e65b97 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Thu, 26 Jun 2025 18:48:11 +0200
Subject: [PATCH 1/5] [mlir][memref] Add a new InferStaticShapes pass for
ReifyRankedShapedTypeOpInterface
---
.../mlir/Dialect/MemRef/Transforms/Passes.td | 13 ++
.../Dialect/MemRef/Transforms/Transforms.h | 4 +
.../ResolveShapedTypeResultDims.cpp | 126 ++++++++++++++++++
.../Dialect/Tensor/infer-static-shapes.mlir | 18 +++
4 files changed, 161 insertions(+)
create mode 100644 mlir/test/Dialect/Tensor/infer-static-shapes.mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a8d135caa74f0..2406b47538ddc 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -182,6 +182,19 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
];
}
+def InferStaticShapesPass : Pass<"infer-static-shapes"> {
+ let summary = "Resolve memref.dim of result values";
+ let description = [{
+ The pass resolves memref.dim of result of operations that
+ implement the `InferShapedTypeOpInterface` or
+ `ReifyRankedShapedTypeOpInterface` in terms of shapes of its
+ operands.
+ }];
+ let dependentDialects = [
+ "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
+ ];
+}
+
def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> {
let summary = "Expand memref operations into easier to analyze constructs";
let description = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index c2b8cb05be922..b069d5f284597 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -57,6 +57,10 @@ void populateResolveRankedShapedTypeResultDimsPatterns(
/// terms of shapes of its input operands.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
+/// Appends patterns that allow making ReifyRankedShapedTypeOpInterface ops
+/// shapes more static.
+void populateReifyToInferStaticShapePatterns(RewritePatternSet &patterns);
+
/// Appends patterns for expanding memref operations that modify the metadata
/// (sizes, offset, strides) of a memref into easier to analyze constructs.
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 89a3895d06ba5..919b3fbc95479 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,13 +20,22 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/InterleavedRange.h"
+
+#define DEBUG_TYPE "resolve-shaped-type"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
+#define GEN_PASS_DEF_INFERSTATICSHAPESPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
@@ -105,6 +114,99 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
}
};
+struct ReifyToInferStaticShapePattern
+ : public OpInterfaceRewritePattern<ReifyRankedShapedTypeOpInterface> {
+ using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+ LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op,
+ PatternRewriter &rewriter) const override {
+ LLVM_DEBUG(
+ { DBGS() << "ReifyToInferStaticShapePattern on " << op << "\n"; });
+
+ bool rewriteToMoreStatic = false;
+ ReifiedRankedShapedTypeDims reifiedResultShapes;
+ if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+ reifiedResultShapes.empty()) {
+ LLVM_DEBUG({ DBGS() << "reifyResultShapes failed\n"; });
+ return failure();
+ }
+
+ SmallVector<Type> newTypes;
+ for (auto [t, reifiedShape] :
+ llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+ ShapedType st = dyn_cast<ShapedType>(t);
+ if (!st)
+ continue;
+
+ SmallVector<int64_t> newShape;
+ for (const auto &[s, ofr] :
+ llvm::zip_equal(st.getShape(), reifiedShape)) {
+ std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+ // Reification does not add static information, just use existing shape.
+ if (!maybeCst.has_value()) {
+ newShape.push_back(s);
+ continue;
+ }
+ int64_t cst = *maybeCst;
+ assert((ShapedType::isDynamic(s) || s == cst) &&
+ "constants must agree!");
+ newShape.push_back(cst);
+ }
+
+ if (newShape == st.getShape()) {
+ newTypes.push_back(t);
+ continue;
+ }
+
+ rewriteToMoreStatic = true;
+ Type newType = st.cloneWith(newShape, st.getElementType());
+ newTypes.push_back(newType);
+ }
+
+ LLVM_DEBUG({
+ DBGS() << "--oldTypes: " << llvm::interleaved_array(op->getResultTypes())
+ << " \n";
+ DBGS() << "--newTypes: " << llvm::interleaved_array(newTypes) << " \n";
+ });
+ if (!rewriteToMoreStatic) {
+ LLVM_DEBUG({ DBGS() << "not more static\n"; });
+ return failure();
+ }
+
+ // We now have newTypes that need to be turned to tensor::CastOp.
+ Location loc = op->getLoc();
+ SmallVector<Value> newResults;
+ Operation *newOp = rewriter.clone(*op);
+ for (auto [nt, oldVal] : llvm::zip(newTypes, op->getResults())) {
+ Type ot = oldVal.getType();
+ OpResult newResult = newOp->getResult(oldVal.getResultNumber());
+ if (ot == nt) {
+ newResults.push_back(newResult);
+ continue;
+ }
+ newResult.setType(nt);
+ if (isa<RankedTensorType>(nt)) {
+ newResults.push_back(
+ rewriter.create<tensor::CastOp>(loc, ot, newResult));
+ } else if (isa<MemRefType>(nt)) {
+ newResults.push_back(
+ rewriter.create<memref::CastOp>(loc, ot, newResult));
+ } else {
+ llvm_unreachable("expected RankedTensorType or MemRefType");
+ }
+ }
+
+ LLVM_DEBUG({
+ op->getParentOp()->dump();
+ DBGS() << "replace op " << *op << "\n";
+ DBGS() << "with newResults " << llvm::interleaved_array(newResults)
+ << "\n\n\n\n";
+ });
+ rewriter.replaceAllOpUsesWith(op, newResults);
+ return success();
+ }
+};
+
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
@@ -175,6 +277,11 @@ struct ResolveShapedTypeResultDimsPass final
void runOnOperation() override;
};
+struct InferStaticShapesPass final
+ : public memref::impl::InferStaticShapesPassBase<InferStaticShapesPass> {
+ void runOnOperation() override;
+};
+
} // namespace
void memref::populateResolveRankedShapedTypeResultDimsPatterns(
@@ -192,6 +299,11 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
patterns.getContext());
}
+void memref::populateReifyToInferStaticShapePatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ReifyToInferStaticShapePattern>(patterns.getContext());
+}
+
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -206,3 +318,17 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
+
+void InferStaticShapesPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ patterns.add<ReifyToInferStaticShapePattern>(&getContext());
+ FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+ SmallVector<Operation *> opsToSimplify;
+ getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
+ opsToSimplify.push_back(op);
+ });
+ (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns,
+ GreedyRewriteConfig().setStrictness(
+ GreedyRewriteStrictness::ExistingOps));
+}
diff --git a/mlir/test/Dialect/Tensor/infer-static-shapes.mlir b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
new file mode 100644
index 0000000000000..1712ce7df38b1
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -infer-static-shapes -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
+ -> tensor<1x?x64xf32> {
+ %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+ %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1]
+ : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+// CHECK: tensor.pad
+// CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+ %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+ ^bb0(%a: index, %b: index, %c: index):
+ tensor.yield %cst : f32
+ } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+ return %padded : tensor<1x?x64xf32>
+}
>From 1f4fba78948dec85d558236f15678c10cd289fcc Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Thu, 26 Jun 2025 19:42:35 +0000
Subject: [PATCH 2/5] rename transform to reify-result-shapes
---
.../mlir/Dialect/MemRef/Transforms/Passes.td | 39 ++++-
.../Dialect/MemRef/Transforms/Transforms.h | 16 +-
.../Dialect/MemRef/Transforms/CMakeLists.txt | 1 +
.../MemRef/Transforms/ReifyResultShapes.cpp | 144 ++++++++++++++++++
.../ResolveShapedTypeResultDims.cpp | 126 ---------------
.../Dialect/Tensor/infer-static-shapes.mlir | 18 ---
mlir/test/Dialect/Tensor/reify-shapes.mlir | 31 ++++
7 files changed, 221 insertions(+), 154 deletions(-)
create mode 100644 mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
delete mode 100644 mlir/test/Dialect/Tensor/infer-static-shapes.mlir
create mode 100644 mlir/test/Dialect/Tensor/reify-shapes.mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 2406b47538ddc..4645d49cab2be 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -182,13 +182,40 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
];
}
-def InferStaticShapesPass : Pass<"infer-static-shapes"> {
- let summary = "Resolve memref.dim of result values";
+def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
+ let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations";
let description = [{
- The pass resolves memref.dim of result of operations that
- implement the `InferShapedTypeOpInterface` or
- `ReifyRankedShapedTypeOpInterface` in terms of shapes of its
- operands.
+ This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface`
+ operation with ranked `memref` and `tensor` results. Replacing the
+ operations with their reified versions, and inserting casts when results
+ shapes are updated.
+
+ Example:
+ ```mlir
+ #map = affine_map<(d0) -> (-d0 + 256)>
+ func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+ %0 = affine.apply #map(%arg1)
+ %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+ %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %arg0 : f32
+ } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+ return %padded : tensor<1x?x64xf32>
+ }
+
+ // mlir-opt --reify-result-shapes
+ #map = affine_map<()[s0] -> (-s0 + 256)>
+ func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+ %0 = affine.apply #map()[%arg1]
+ %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+ %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %arg0 : f32
+ } : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+ %cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+ return %cast : tensor<1x?x64xf32>
+ }
+ ```
}];
let dependentDialects = [
"affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index b069d5f284597..5f9f09d7992ca 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -23,6 +23,7 @@ class RewritePatternSet;
class RewriterBase;
class Value;
class ValueRange;
+class ReifyRankedShapedTypeOpInterface;
namespace arith {
class WideIntEmulationConverter;
@@ -57,10 +58,6 @@ void populateResolveRankedShapedTypeResultDimsPatterns(
/// terms of shapes of its input operands.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
-/// Appends patterns that allow making ReifyRankedShapedTypeOpInterface ops
-/// shapes more static.
-void populateReifyToInferStaticShapePatterns(RewritePatternSet &patterns);
-
/// Appends patterns for expanding memref operations that modify the metadata
/// (sizes, offset, strides) of a memref into easier to analyze constructs.
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
@@ -213,6 +210,17 @@ memref::AllocaOp allocToAlloca(
RewriterBase &rewriter, memref::AllocOp alloc,
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
+/// Reifies the results of `op`, potentially replacing `op` with a reified
+/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
+/// otherwise it always succeeds. Users of this transform should always expect
+/// it to modify the IR, even when it fails. If any of the result types changes,
+/// the transform will insert cast operations to the old type to keep the IR
+/// consistent.
+///
+/// Note: This transform only works on ranked `memref` or `tensor` results,
+/// other types are ignored.
+LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
+ ReifyRankedShapedTypeOpInterface op);
} // namespace memref
} // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 637f5ec1c9f9b..9049faccadef3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
IndependenceTransforms.cpp
MultiBuffer.cpp
NormalizeMemRefs.cpp
+ ReifyResultShapes.cpp
ResolveShapedTypeResultDims.cpp
RuntimeOpVerification.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
new file mode 100644
index 0000000000000..dcb601577f88f
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -0,0 +1,144 @@
+//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface`
+// operations with ranked `memref` and `tensor` results.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/Support/InterleavedRange.h"
+
+#define DEBUG_TYPE "reify-result-shapes"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+LogicalResult
+mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
+ ReifyRankedShapedTypeOpInterface op) {
+ LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
+ // Get the reified out shapes.
+ ReifiedRankedShapedTypeDims reifiedResultShapes;
+ if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+ reifiedResultShapes.empty()) {
+ return op.emitError() << "failed to get the reified shapes";
+ }
+
+ bool modified = false;
+ // Compute the new output types.
+ SmallVector<Type> outTypes;
+ for (const auto &[oldTy, reifiedShape] :
+ llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+ // Skip if it's not a memref or tensor type.
+ if (!isa<RankedTensorType, MemRefType>(oldTy)) {
+ outTypes.push_back(oldTy);
+ continue;
+ }
+
+ ShapedType shapedTy = dyn_cast<ShapedType>(oldTy);
+
+ SmallVector<int64_t> shape = llvm::to_vector(shapedTy.getShape());
+ for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
+ std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+ // If the reified dim is dynamic set it appropriately.
+ if (!maybeCst.has_value()) {
+ dim = ShapedType::kDynamic;
+ continue;
+ }
+ // Set the static dim.
+ dim = *maybeCst;
+ }
+
+ // If the shape didn't change continue.
+ if (shape == shapedTy.getShape()) {
+ outTypes.push_back(oldTy);
+ continue;
+ }
+ modified = true;
+ outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
+ }
+
+ // Return if we don't need to update.
+ if (!modified) {
+ LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
+ return success();
+ }
+
+ LLVM_DEBUG({
+ DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
+ << " \n";
+ DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n";
+ });
+
+ // We now have outTypes that need to be turned to cast ops.
+ Location loc = op->getLoc();
+ SmallVector<Value> newResults;
+ Operation *newOp = rewriter.clone(*op);
+ for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
+ OpResult newRes = newOp->getResult(oldRes.getResultNumber());
+ Type oldTy = oldRes.getType();
+ // Continue if the type remained invariant or is not shaped.
+ if (oldTy == reifiedTy || !isa<MemRefType, RankedTensorType>(oldTy)) {
+ newResults.push_back(newRes);
+ continue;
+ }
+
+ // Update the type.
+ newRes.setType(reifiedTy);
+ if (isa<RankedTensorType>(reifiedTy)) {
+ newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes));
+ } else {
+ assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
+ newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes));
+ }
+ }
+
+ LLVM_DEBUG({
+ DBGS() << "- reified results " << llvm::interleaved_array(newResults)
+ << "\n";
+ });
+ rewriter.replaceOp(op, newResults);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ReifyResultShapesPass final
+ : public memref::impl::ReifyResultShapesPassBase<ReifyResultShapesPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void ReifyResultShapesPass::runOnOperation() {
+ SmallVector<ReifyRankedShapedTypeOpInterface> ops;
+ getOperation()->walk(
+ [&](ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); });
+ IRRewriter rewriter(&getContext());
+ for (ReifyRankedShapedTypeOpInterface op : ops) {
+ rewriter.setInsertionPoint(op);
+ if (failed(memref::reifyOpResultShapes(rewriter, op)))
+ return signalPassFailure();
+ }
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 919b3fbc95479..89a3895d06ba5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,22 +20,13 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Value.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/InterleavedRange.h"
-
-#define DEBUG_TYPE "resolve-shaped-type"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
-#define GEN_PASS_DEF_INFERSTATICSHAPESPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
@@ -114,99 +105,6 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
}
};
-struct ReifyToInferStaticShapePattern
- : public OpInterfaceRewritePattern<ReifyRankedShapedTypeOpInterface> {
- using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-
- LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op,
- PatternRewriter &rewriter) const override {
- LLVM_DEBUG(
- { DBGS() << "ReifyToInferStaticShapePattern on " << op << "\n"; });
-
- bool rewriteToMoreStatic = false;
- ReifiedRankedShapedTypeDims reifiedResultShapes;
- if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
- reifiedResultShapes.empty()) {
- LLVM_DEBUG({ DBGS() << "reifyResultShapes failed\n"; });
- return failure();
- }
-
- SmallVector<Type> newTypes;
- for (auto [t, reifiedShape] :
- llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
- ShapedType st = dyn_cast<ShapedType>(t);
- if (!st)
- continue;
-
- SmallVector<int64_t> newShape;
- for (const auto &[s, ofr] :
- llvm::zip_equal(st.getShape(), reifiedShape)) {
- std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
- // Reification does not add static information, just use existing shape.
- if (!maybeCst.has_value()) {
- newShape.push_back(s);
- continue;
- }
- int64_t cst = *maybeCst;
- assert((ShapedType::isDynamic(s) || s == cst) &&
- "constants must agree!");
- newShape.push_back(cst);
- }
-
- if (newShape == st.getShape()) {
- newTypes.push_back(t);
- continue;
- }
-
- rewriteToMoreStatic = true;
- Type newType = st.cloneWith(newShape, st.getElementType());
- newTypes.push_back(newType);
- }
-
- LLVM_DEBUG({
- DBGS() << "--oldTypes: " << llvm::interleaved_array(op->getResultTypes())
- << " \n";
- DBGS() << "--newTypes: " << llvm::interleaved_array(newTypes) << " \n";
- });
- if (!rewriteToMoreStatic) {
- LLVM_DEBUG({ DBGS() << "not more static\n"; });
- return failure();
- }
-
- // We now have newTypes that need to be turned to tensor::CastOp.
- Location loc = op->getLoc();
- SmallVector<Value> newResults;
- Operation *newOp = rewriter.clone(*op);
- for (auto [nt, oldVal] : llvm::zip(newTypes, op->getResults())) {
- Type ot = oldVal.getType();
- OpResult newResult = newOp->getResult(oldVal.getResultNumber());
- if (ot == nt) {
- newResults.push_back(newResult);
- continue;
- }
- newResult.setType(nt);
- if (isa<RankedTensorType>(nt)) {
- newResults.push_back(
- rewriter.create<tensor::CastOp>(loc, ot, newResult));
- } else if (isa<MemRefType>(nt)) {
- newResults.push_back(
- rewriter.create<memref::CastOp>(loc, ot, newResult));
- } else {
- llvm_unreachable("expected RankedTensorType or MemRefType");
- }
- }
-
- LLVM_DEBUG({
- op->getParentOp()->dump();
- DBGS() << "replace op " << *op << "\n";
- DBGS() << "with newResults " << llvm::interleaved_array(newResults)
- << "\n\n\n\n";
- });
- rewriter.replaceAllOpUsesWith(op, newResults);
- return success();
- }
-};
-
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
@@ -277,11 +175,6 @@ struct ResolveShapedTypeResultDimsPass final
void runOnOperation() override;
};
-struct InferStaticShapesPass final
- : public memref::impl::InferStaticShapesPassBase<InferStaticShapesPass> {
- void runOnOperation() override;
-};
-
} // namespace
void memref::populateResolveRankedShapedTypeResultDimsPatterns(
@@ -299,11 +192,6 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
patterns.getContext());
}
-void memref::populateReifyToInferStaticShapePatterns(
- RewritePatternSet &patterns) {
- patterns.add<ReifyToInferStaticShapePattern>(patterns.getContext());
-}
-
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -318,17 +206,3 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
-
-void InferStaticShapesPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- patterns.add<ReifyToInferStaticShapePattern>(&getContext());
- FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-
- SmallVector<Operation *> opsToSimplify;
- getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
- opsToSimplify.push_back(op);
- });
- (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns,
- GreedyRewriteConfig().setStrictness(
- GreedyRewriteStrictness::ExistingOps));
-}
diff --git a/mlir/test/Dialect/Tensor/infer-static-shapes.mlir b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
deleted file mode 100644
index 1712ce7df38b1..0000000000000
--- a/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
+++ /dev/null
@@ -1,18 +0,0 @@
-// RUN: mlir-opt -infer-static-shapes -split-input-file %s | FileCheck %s
-
-// CHECK-LABEL: func.func @pad_reification
-func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
- -> tensor<1x?x64xf32> {
- %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
- %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1]
- : tensor<64x?x64xf32> to tensor<1x?x64xf32>
-
-// CHECK: tensor.pad
-// CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
- %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
- ^bb0(%a: index, %b: index, %c: index):
- tensor.yield %cst : f32
- } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
-
- return %padded : tensor<1x?x64xf32>
-}
diff --git a/mlir/test/Dialect/Tensor/reify-shapes.mlir b/mlir/test/Dialect/Tensor/reify-shapes.mlir
new file mode 100644
index 0000000000000..5569d90f8b731
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/reify-shapes.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -reify-result-shapes %s | FileCheck %s
+
+// The test below checks concat op reification. In the first case, no cast is inserted while on the second a cast gets inserted.
+// CHECK-LABEL: func.func @concat_reification
+func.func @concat_reification(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>)
+ -> (tensor<4x11x3xf32>, tensor<?x?x?xf32>) {
+ // CHECK: %[[RES0:.*]] = tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+ %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+ // CHECK: %[[V0:.*]] = tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<4x7x?xf32>
+ // CHECK: %[[RES1:.*]] = tensor.cast %[[V0]] : tensor<4x7x?xf32> to tensor<?x?x?xf32>
+ %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ // CHECK: return %[[RES0]], %[[RES1]] : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+ return %1, %2 : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+ %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+ %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1]
+ : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+ // CHECK: tensor.pad
+ // CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+ // CHECK: tensor.cast %{{.*}} : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+ %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+ ^bb0(%a: index, %b: index, %c: index):
+ tensor.yield %cst : f32
+ } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+ return %padded : tensor<1x?x64xf32>
+}
>From ba51026882343871ebb06ced928b657c82c3a2f7 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Fri, 27 Jun 2025 11:33:21 +0200
Subject: [PATCH 3/5] Don't fail the pass if we can't make shapes more static.
Also, only allow tensor::PadOp and tensor::ConcatOp for now, as more extensive testing showed that
other ops (at least tensor::ExtractSliceOp and tensor::InsertSliceOp) are not ready yet.
---
.../MemRef/Transforms/ReifyResultShapes.cpp | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
index dcb601577f88f..b00a0f2103d43 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -40,7 +40,7 @@ mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
ReifiedRankedShapedTypeDims reifiedResultShapes;
if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
reifiedResultShapes.empty()) {
- return op.emitError() << "failed to get the reified shapes";
+ return op->emitWarning() << "failed to get the reified shapes";
}
bool modified = false;
@@ -133,12 +133,16 @@ struct ReifyResultShapesPass final
void ReifyResultShapesPass::runOnOperation() {
SmallVector<ReifyRankedShapedTypeOpInterface> ops;
- getOperation()->walk(
- [&](ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); });
+ getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
+ // Some ops have rigid type checkers and need to update their operands.
+ // Only admit the ones that are explicitly supported for now.
+ if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
+ return;
+ ops.push_back(op);
+ });
IRRewriter rewriter(&getContext());
for (ReifyRankedShapedTypeOpInterface op : ops) {
rewriter.setInsertionPoint(op);
- if (failed(memref::reifyOpResultShapes(rewriter, op)))
- return signalPassFailure();
+ (void)memref::reifyOpResultShapes(rewriter, op);
}
}
>From 08a68236424bd424d82cbee3cdbaf8cf00923dcc Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Mon, 30 Jun 2025 15:36:31 +0200
Subject: [PATCH 4/5] Update pass documentation
---
.../mlir/Dialect/MemRef/Transforms/Passes.td | 40 ++++++++++++++-----
.../MemRef/Transforms/ReifyResultShapes.cpp | 7 +++-
2 files changed, 36 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 4645d49cab2be..f3e40aaa29075 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -183,19 +183,38 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
}
def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
- let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations";
+ let summary ="Reifies the results of `tensor::PadOp` and `tensor::ConcatOp`.";
let description = [{
- This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface`
- operation with ranked `memref` and `tensor` results. Replacing the
- operations with their reified versions, and inserting casts when results
- shapes are updated.
+ This pass reifies the shapes of a subset of `ReifyRankedShapedTypeOpInterface`
+ ops with `tensor` results.
+
+ The pass currently only supports result shape type reification for
+ `tensor::PadOp` and `tensor::ConcatOp`.
+
+ It addresses a representation gap where implicit op semantics are needed to
+ infer static result types from dynamic operands, but it does so by using
+ `ReifyRankedShapedTypeOpInterface` as the source of truth rather than the op
+ itself. As a consequence, this approach does not generalize to other ops
+ today.
+
+ TODO: in the future, we should consider coupling this information with op
+ "transfer functions" (e.g. `IndexingMapOpInterface`) to provide a source of
+ truth that can work across result shape inference, canonicalization and op
+ verifiers.
+
+ The pass replaces operations with their reified versions when more
+ static information can be derived, and inserts casts when result shapes
+ are updated.
Example:
```mlir
#map = affine_map<(d0) -> (-d0 + 256)>
- func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+ func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
+ -> tensor<1x?x64xf32>
+ {
%0 = affine.apply #map(%arg1)
- %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+ %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
+ : tensor<64x?x64xf32> to tensor<1x?x64xf32>
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
^bb0(%arg3: index, %arg4: index, %arg5: index):
tensor.yield %arg0 : f32
@@ -205,9 +224,12 @@ def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
// mlir-opt --reify-result-shapes
#map = affine_map<()[s0] -> (-s0 + 256)>
- func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+ func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
+ -> tensor<1x?x64xf32>
+ {
%0 = affine.apply #map()[%arg1]
- %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+ %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
+ : tensor<64x?x64xf32> to tensor<1x?x64xf32>
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
^bb0(%arg3: index, %arg4: index, %arg5: index):
tensor.yield %arg0 : f32
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
index b00a0f2103d43..0a8aacb9d15a1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/Support/InterleavedRange.h"
@@ -134,8 +135,10 @@ struct ReifyResultShapesPass final
void ReifyResultShapesPass::runOnOperation() {
SmallVector<ReifyRankedShapedTypeOpInterface> ops;
getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
- // Some ops have rigid type checkers and need to update their operands.
- // Only admit the ones that are explicitly supported for now.
+ // Handle ops that are not DPS and that do not carry an tied operand shapes.
+ // For now, limit to tensor::PadOp and tensor::ConcatOp.
+ if (isa<DestinationStyleOpInterface>(op.getOperation()))
+ return;
if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
return;
ops.push_back(op);
>From b45abc7dcb48f84fad4d49b491fbd200eac71ab1 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Tue, 1 Jul 2025 15:30:26 +0200
Subject: [PATCH 5/5] Review comments
---
.../Dialect/MemRef/Transforms/Transforms.h | 12 -----------
.../MemRef/Transforms/ReifyResultShapes.cpp | 20 +++++++++++++------
2 files changed, 14 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 5f9f09d7992ca..33e3d94f02b1c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -209,18 +209,6 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
memref::AllocaOp allocToAlloca(
RewriterBase &rewriter, memref::AllocOp alloc,
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
-
-/// Reifies the results of `op`, potentially replacing `op` with a reified
-/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
-/// otherwise it always succeeds. Users of this transform should always expect
-/// it to modify the IR, even when it fails. If any of the result types changes,
-/// the transform will insert cast operations to the old type to keep the IR
-/// consistent.
-///
-/// Note: This transform only works on ranked `memref` or `tensor` results,
-/// other types are ignored.
-LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
- ReifyRankedShapedTypeOpInterface op);
} // namespace memref
} // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
index 0a8aacb9d15a1..e6b9e2f7e8213 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -33,9 +33,14 @@ namespace memref {
using namespace mlir;
-LogicalResult
-mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
- ReifyRankedShapedTypeOpInterface op) {
+/// Reifies the results of `op`, potentially replacing `op` with a reified
+/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
+/// otherwise it always succeeds. Users of this transform should always expect
+/// it to modify the IR, even when it fails. If any of the result types changes,
+/// the transform will insert cast operations to the old type to keep the IR
+/// consistent.
+static LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
+ ReifyRankedShapedTypeOpInterface op) {
LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
// Get the reified out shapes.
ReifiedRankedShapedTypeDims reifiedResultShapes;
@@ -93,6 +98,11 @@ mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
// We now have outTypes that need to be turned to cast ops.
Location loc = op->getLoc();
SmallVector<Value> newResults;
+ // TODO: `mlir::reifyResultShapes` and op verifiers may currently disagree.
+ // This is a confluence problem that will need to be addressed.
+ // For now, PadOp and ConcatOp are known to be fine.
+ assert((isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation())) &&
+ "incorrect op");
Operation *newOp = rewriter.clone(*op);
for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
OpResult newRes = newOp->getResult(oldRes.getResultNumber());
@@ -137,8 +147,6 @@ void ReifyResultShapesPass::runOnOperation() {
getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
// Handle ops that are not DPS and that do not carry an tied operand shapes.
// For now, limit to tensor::PadOp and tensor::ConcatOp.
- if (isa<DestinationStyleOpInterface>(op.getOperation()))
- return;
if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
return;
ops.push_back(op);
@@ -146,6 +154,6 @@ void ReifyResultShapesPass::runOnOperation() {
IRRewriter rewriter(&getContext());
for (ReifyRankedShapedTypeOpInterface op : ops) {
rewriter.setInsertionPoint(op);
- (void)memref::reifyOpResultShapes(rewriter, op);
+ (void)reifyOpResultShapes(rewriter, op);
}
}
More information about the Mlir-commits
mailing list