[Mlir-commits] [mlir] [mlir][memref] Add a new InferStaticShapes pass for ReifyRankedShaped… (PR #145927)
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jun 26 09:53:46 PDT 2025
https://github.com/nicolasvasilache created https://github.com/llvm/llvm-project/pull/145927
…TypeOpInterface
>From 475e9198a2cdf9d67d251b80608a17731bb52246 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Thu, 26 Jun 2025 18:48:11 +0200
Subject: [PATCH] [mlir][memref] Add a new InferStaticShapes pass for
ReifyRankedShapedTypeOpInterface
---
.../mlir/Dialect/MemRef/Transforms/Passes.td | 13 +++
.../Dialect/MemRef/Transforms/Transforms.h | 4 +
.../ResolveShapedTypeResultDims.cpp | 106 ++++++++++++++++++
.../Dialect/Tensor/infer-static-shapes.mlir | 17 +++
4 files changed, 140 insertions(+)
create mode 100644 mlir/test/Dialect/Tensor/infer-static-shapes.mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a8d135caa74f0..2406b47538ddc 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -182,6 +182,19 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
];
}
+def InferStaticShapesPass : Pass<"infer-static-shapes"> {
+  let summary = "Refine dynamic result shapes of ops using their reified shapes";
+  let description = [{
+    The pass reifies the result shapes of operations that implement the
+    `ReifyRankedShapedTypeOpInterface`. When a dynamic dimension of a ranked
+    tensor or memref result reifies to a constant, the operation is replaced
+    by a clone with the more static result type, followed by a `tensor.cast`
+    (resp. `memref.cast`) back to the original type so that existing users
+    are unaffected.
+  }];
+  let dependentDialects = [
+    "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
+  ];
+}
+
def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> {
let summary = "Expand memref operations into easier to analyze constructs";
let description = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index c2b8cb05be922..b069d5f284597 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -57,6 +57,10 @@ void populateResolveRankedShapedTypeResultDimsPatterns(
/// terms of shapes of its input operands.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
+/// Appends patterns that refine the result types of ops implementing
+/// `ReifyRankedShapedTypeOpInterface` to more static types whenever their
+/// reified result shapes fold to constants. A refined op is replaced by a
+/// clone with the more static result types, followed by `tensor.cast` /
+/// `memref.cast` ops producing the replacement values.
+void populateReifyToInferStaticShapePatterns(RewritePatternSet &patterns);
+
/// Appends patterns for expanding memref operations that modify the metadata
/// (sizes, offset, strides) of a memref into easier to analyze constructs.
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 89a3895d06ba5..b00fc925b2f12 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,13 +20,18 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/ErrorHandling.h"
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
+#define GEN_PASS_DEF_INFERSTATICSHAPESPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
@@ -105,6 +110,83 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
}
};
+/// Refines the result types of an op implementing
+/// `ReifyRankedShapedTypeOpInterface` when its reified result shapes fold to
+/// constants that are more static than the declared result types.
+///
+/// The op is cloned, the refined (more static) types are set on the clone's
+/// results, and every refined result is cast back to its original type with
+/// `tensor.cast` / `memref.cast`, so existing users keep seeing the types they
+/// were verified against. E.g., when the reified size of %0 folds to 256:
+///
+/// ```
+/// %0 = tensor.pad %t low[0] high[%h] {...} : tensor<?xf32> to tensor<?xf32>
+/// ```
+///
+/// is rewritten into
+///
+/// ```
+/// %1 = tensor.pad %t low[0] high[%h] {...} : tensor<?xf32> to tensor<256xf32>
+/// %0 = tensor.cast %1 : tensor<256xf32> to tensor<?xf32>
+/// ```
+///
+/// Only ranked tensor and memref results are refined; all other results
+/// (shaped or not) keep their original type.
+struct ReifyToInferStaticShapePattern
+    : public OpInterfaceRewritePattern<ReifyRankedShapedTypeOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    // Only ranked tensors and memrefs are refined: a cast back to the original
+    // type exists for them.
+    auto isRefinable = [](Type t) {
+      return isa<RankedTensorType, MemRefType>(t);
+    };
+
+    // Cheap early exit before reifying: nothing can become more static if no
+    // refinable result has a dynamic dimension.
+    if (llvm::none_of(op->getResultTypes(), [&](Type t) {
+          return isRefinable(t) && !cast<ShapedType>(t).hasStaticShape();
+        }))
+      return failure();
+
+    // NOTE(review): reification may materialize IR (e.g. index computations)
+    // even when the pattern subsequently fails below.
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+    if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+        reifiedResultShapes.empty())
+      return failure();
+
+    // `reifiedResultShapes` holds one entry per *shaped* result only, so it is
+    // indexed separately from the op results. `newTypes` gets exactly one
+    // entry per op result so it stays aligned with `op->getResults()`.
+    SmallVector<Type> newTypes;
+    newTypes.reserve(op->getNumResults());
+    size_t reifiedIdx = 0;
+    bool rewriteToMoreStatic = false;
+    for (Type t : op->getResultTypes()) {
+      if (!isa<ShapedType>(t)) {
+        newTypes.push_back(t);
+        continue;
+      }
+      // Defensive: a malformed reification cannot be matched to the results.
+      if (reifiedIdx >= reifiedResultShapes.size())
+        return failure();
+      ArrayRef<OpFoldResult> reifiedShape = reifiedResultShapes[reifiedIdx++];
+
+      auto st = cast<ShapedType>(t);
+      // Unranked / non-tensor-or-memref shaped results and incomplete
+      // reifications carry nothing usable: keep the original type.
+      if (!isRefinable(t) ||
+          static_cast<int64_t>(reifiedShape.size()) != st.getRank()) {
+        newTypes.push_back(t);
+        continue;
+      }
+
+      SmallVector<int64_t> newShape;
+      newShape.reserve(st.getRank());
+      for (auto [s, ofr] : llvm::zip_equal(st.getShape(), reifiedShape)) {
+        std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+        // Keep the existing size when reification does not add static
+        // information (non-constant) or yields an invalid (negative) size.
+        if (!maybeCst.has_value() || *maybeCst < 0) {
+          newShape.push_back(s);
+          continue;
+        }
+        assert((ShapedType::isDynamic(s) || s == *maybeCst) &&
+               "constants must agree!");
+        newShape.push_back(*maybeCst);
+      }
+
+      if (ArrayRef<int64_t>(newShape) == st.getShape()) {
+        newTypes.push_back(t);
+        continue;
+      }
+      rewriteToMoreStatic = true;
+      newTypes.push_back(st.cloneWith(newShape, st.getElementType()));
+    }
+
+    if (!rewriteToMoreStatic)
+      return failure();
+
+    // Clone the op and give the clone's results the refined types.
+    Operation *newOp = rewriter.clone(*op.getOperation());
+    rewriter.modifyOpInPlace(newOp, [&] {
+      for (auto [newResult, newType] :
+           llvm::zip_equal(newOp->getResults(), newTypes))
+        newResult.setType(newType);
+    });
+
+    // Cast every refined result back to its ORIGINAL type: replacement values
+    // must have the same types as the results they replace.
+    Location loc = op->getLoc();
+    SmallVector<Value> newResults;
+    newResults.reserve(op->getNumResults());
+    for (auto [oldResult, newResult] :
+         llvm::zip_equal(op->getResults(), newOp->getResults())) {
+      Type oldType = oldResult.getType();
+      if (newResult.getType() == oldType) {
+        newResults.push_back(newResult);
+        continue;
+      }
+      if (isa<RankedTensorType>(oldType)) {
+        newResults.push_back(
+            rewriter.create<tensor::CastOp>(loc, oldType, newResult));
+      } else if (isa<MemRefType>(oldType)) {
+        newResults.push_back(
+            rewriter.create<memref::CastOp>(loc, oldType, newResult));
+      } else {
+        llvm_unreachable("expected RankedTensorType or MemRefType");
+      }
+    }
+
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
+
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
@@ -175,6 +257,11 @@ struct ResolveShapedTypeResultDimsPass final
void runOnOperation() override;
};
+/// Pass object for `-infer-static-shapes`; all the work happens in
+/// `runOnOperation`.
+struct InferStaticShapesPass final
+    : memref::impl::InferStaticShapesPassBase<InferStaticShapesPass> {
+  void runOnOperation() override;
+};
+
} // namespace
void memref::populateResolveRankedShapedTypeResultDimsPatterns(
@@ -192,6 +279,11 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
patterns.getContext());
}
+/// Registers `ReifyToInferStaticShapePattern` into `patterns`, using the
+/// pattern set's own context.
+void memref::populateReifyToInferStaticShapePatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ReifyToInferStaticShapePattern>(context);
+}
+
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -206,3 +298,17 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
+
+/// Collects every op implementing `ReifyRankedShapedTypeOpInterface` under the
+/// pass root and greedily applies the static-shape inference patterns to those
+/// ops and to any ops created while rewriting them.
+void InferStaticShapesPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  // Without this the pattern set is empty and the pass is a no-op.
+  memref::populateReifyToInferStaticShapePatterns(patterns);
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+  SmallVector<Operation *> opsToSimplify;
+  getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
+    opsToSimplify.push_back(op.getOperation());
+  });
+
+  // Best effort: failing to converge is not treated as a pass failure.
+  (void)applyOpPatternsGreedily(
+      opsToSimplify, frozenPatterns,
+      GreedyRewriteConfig().setStrictness(
+          GreedyRewriteStrictness::ExistingAndNewOps));
+}
diff --git a/mlir/test/Dialect/Tensor/infer-static-shapes.mlir b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
new file mode 100644
index 0000000000000..85ee97b906c09
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -infer-static-shapes -split-input-file %s | FileCheck %s
+
+// The pad result size folds to 256 (= %idx + (256 - %idx)): the pad is given
+// a static result type and cast back to the original dynamic type.
+// CHECK-LABEL: func.func @pad_reification
+//       CHECK:   %[[PADDED:.*]] = tensor.pad
+//       CHECK:   : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+//       CHECK:   %[[CAST:.*]] = tensor.cast %[[PADDED]] : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+//       CHECK:   return %[[CAST]] : tensor<1x?x64xf32>
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
+    -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1]
+    : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+  ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}
More information about the Mlir-commits
mailing list