[Mlir-commits] [mlir] [mlir][memref] Add a new InferStaticShapes pass for ReifyRankedShaped… (PR #145927)

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jun 26 09:53:46 PDT 2025


https://github.com/nicolasvasilache created https://github.com/llvm/llvm-project/pull/145927

…TypeOpInterface

>From 475e9198a2cdf9d67d251b80608a17731bb52246 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Thu, 26 Jun 2025 18:48:11 +0200
Subject: [PATCH] [mlir][memref] Add a new InferStaticShapes pass for
 ReifyRankedShapedTypeOpInterface

---
 .../mlir/Dialect/MemRef/Transforms/Passes.td  |  13 +++
 .../Dialect/MemRef/Transforms/Transforms.h    |   4 +
 .../ResolveShapedTypeResultDims.cpp           | 106 ++++++++++++++++++
 .../Dialect/Tensor/infer-static-shapes.mlir   |  17 +++
 4 files changed, 140 insertions(+)
 create mode 100644 mlir/test/Dialect/Tensor/infer-static-shapes.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a8d135caa74f0..2406b47538ddc 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -182,6 +182,19 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
   ];
 }
 
+def InferStaticShapesPass : Pass<"infer-static-shapes"> {
+  let summary = "Make op result shapes more static";
+  let description = [{
+    The pass rewrites operations that implement the
+    `ReifyRankedShapedTypeOpInterface` so that their results get the most
+    static shape implied by shape reification, inserting casts back to the
+    original result types for existing users.
+  }];
+  let dependentDialects = [
+    "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
+  ];
+}
+
 def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> {
   let summary = "Expand memref operations into easier to analyze constructs";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index c2b8cb05be922..b069d5f284597 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -57,6 +57,10 @@ void populateResolveRankedShapedTypeResultDimsPatterns(
 /// terms of shapes of its input operands.
 void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
 
+/// Appends patterns that allow making ReifyRankedShapedTypeOpInterface ops
+/// shapes more static.
+void populateReifyToInferStaticShapePatterns(RewritePatternSet &patterns);
+
 /// Appends patterns for expanding memref operations that modify the metadata
 /// (sizes, offset, strides) of a memref into easier to analyze constructs.
 void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 89a3895d06ba5..b00fc925b2f12 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,13 +20,18 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/ErrorHandling.h"
 
 namespace mlir {
 namespace memref {
 #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
+#define GEN_PASS_DEF_INFERSTATICSHAPESPASS
 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
 } // namespace memref
 } // namespace mlir
@@ -105,6 +110,83 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
   }
 };
 
+/// Refine op results to more static types when reification yields constants.
+struct ReifyToInferStaticShapePattern
+    : public OpInterfaceRewritePattern<ReifyRankedShapedTypeOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    bool rewriteToMoreStatic = false;
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+    if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+        reifiedResultShapes.empty())
+      return failure();
+
+    SmallVector<Type> newTypes;
+    for (auto [t, reifiedShape] :
+         llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+      ShapedType st = dyn_cast<ShapedType>(t);
+      if (!st) {
+        // Keep `newTypes` aligned with `op->getResults()`.
+        newTypes.push_back(t);
+        continue;
+      }
+
+      SmallVector<int64_t> newShape;
+      for (const auto &[s, ofr] :
+           llvm::zip_equal(st.getShape(), reifiedShape)) {
+        std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+        // Reification does not add static information, just use existing shape.
+        if (!maybeCst.has_value()) {
+          newShape.push_back(s);
+          continue;
+        }
+        int64_t cst = *maybeCst;
+        assert((ShapedType::isDynamic(s) || s == cst) &&
+               "constants must agree!");
+        newShape.push_back(cst);
+      }
+
+      if (newShape == st.getShape()) {
+        newTypes.push_back(t);
+        continue;
+      }
+
+      rewriteToMoreStatic = true;
+      Type newType = st.cloneWith(newShape, st.getElementType());
+      newTypes.push_back(newType);
+    }
+
+    if (!rewriteToMoreStatic)
+      return failure();
+
+    // Cast refined results back to their original types to keep users valid.
+    Location loc = op->getLoc();
+    SmallVector<Value> newResults;
+    Operation *newOp = rewriter.clone(*op);
+    for (auto [nt, oldVal] : llvm::zip(newTypes, op->getResults())) {
+      Type ot = oldVal.getType();
+      OpResult newResult = newOp->getResult(oldVal.getResultNumber());
+      if (ot == nt) {
+        newResults.push_back(newResult);
+        continue;
+      }
+      newResult.setType(nt);
+      if (isa<RankedTensorType>(ot)) {
+        newResults.push_back(rewriter.create<tensor::CastOp>(loc, ot, newResult));
+      } else if (isa<MemRefType>(ot)) {
+        newResults.push_back(rewriter.create<memref::CastOp>(loc, ot, newResult));
+      } else {
+        llvm_unreachable("expected RankedTensorType or MemRefType");
+      }
+    }
+
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
+
 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
 ///
 /// ```
@@ -175,6 +257,11 @@ struct ResolveShapedTypeResultDimsPass final
   void runOnOperation() override;
 };
 
+struct InferStaticShapesPass final
+    : public memref::impl::InferStaticShapesPassBase<InferStaticShapesPass> {
+  void runOnOperation() override;
+};
+
 } // namespace
 
 void memref::populateResolveRankedShapedTypeResultDimsPatterns(
@@ -192,6 +279,11 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
       patterns.getContext());
 }
 
+void memref::populateReifyToInferStaticShapePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ReifyToInferStaticShapePattern>(patterns.getContext());
+}
+
 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -206,3 +298,17 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }
+
+void InferStaticShapesPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  memref::populateReifyToInferStaticShapePatterns(patterns);
+  SmallVector<Operation *> opsToSimplify;
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+  getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
+    opsToSimplify.push_back(op);
+  });
+  (void)applyOpPatternsGreedily(
+      opsToSimplify, frozenPatterns,
+      GreedyRewriteConfig().setStrictness(
+          GreedyRewriteStrictness::ExistingAndNewOps));
+}
diff --git a/mlir/test/Dialect/Tensor/infer-static-shapes.mlir b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
new file mode 100644
index 0000000000000..85ee97b906c09
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -infer-static-shapes -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL:  func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
+    -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] 
+    : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+//       CHECK: tensor.pad
+//       CHECK:   : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+    ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}



More information about the Mlir-commits mailing list