[Mlir-commits] [mlir] [mlir][memref] Add a new `ReifyResultShapes` pass (PR #145927)

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jul 1 06:31:37 PDT 2025


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/145927

>From 465c6606d2f30eb9a8654e07b9da7d6d64e65b97 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Thu, 26 Jun 2025 18:48:11 +0200
Subject: [PATCH 1/5] [mlir][memref] Add a new InferStaticShapes pass for
 ReifyRankedShapedTypeOpInterface

---
 .../mlir/Dialect/MemRef/Transforms/Passes.td  |  13 ++
 .../Dialect/MemRef/Transforms/Transforms.h    |   4 +
 .../ResolveShapedTypeResultDims.cpp           | 126 ++++++++++++++++++
 .../Dialect/Tensor/infer-static-shapes.mlir   |  18 +++
 4 files changed, 161 insertions(+)
 create mode 100644 mlir/test/Dialect/Tensor/infer-static-shapes.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a8d135caa74f0..2406b47538ddc 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -182,6 +182,19 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
   ];
 }
 
+def InferStaticShapesPass : Pass<"infer-static-shapes"> {
+  let summary = "Resolve memref.dim of result values";
+  let description = [{
+    The pass resolves memref.dim of result of operations that
+    implement the `InferShapedTypeOpInterface` or
+    `ReifyRankedShapedTypeOpInterface` in terms of shapes of its
+    operands.
+  }];
+  let dependentDialects = [
+    "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
+  ];
+}
+
 def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> {
   let summary = "Expand memref operations into easier to analyze constructs";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index c2b8cb05be922..b069d5f284597 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -57,6 +57,10 @@ void populateResolveRankedShapedTypeResultDimsPatterns(
 /// terms of shapes of its input operands.
 void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
 
+/// Appends patterns that allow making ReifyRankedShapedTypeOpInterface ops
+/// shapes more static.
+void populateReifyToInferStaticShapePatterns(RewritePatternSet &patterns);
+
 /// Appends patterns for expanding memref operations that modify the metadata
 /// (sizes, offset, strides) of a memref into easier to analyze constructs.
 void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 89a3895d06ba5..919b3fbc95479 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,13 +20,22 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/InterleavedRange.h"
+
+#define DEBUG_TYPE "resolve-shaped-type"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 
 namespace mlir {
 namespace memref {
 #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
+#define GEN_PASS_DEF_INFERSTATICSHAPESPASS
 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
 } // namespace memref
 } // namespace mlir
@@ -105,6 +114,99 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
   }
 };
 
+struct ReifyToInferStaticShapePattern
+    : public OpInterfaceRewritePattern<ReifyRankedShapedTypeOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(
+        { DBGS() << "ReifyToInferStaticShapePattern on " << op << "\n"; });
+
+    bool rewriteToMoreStatic = false;
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+    if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+        reifiedResultShapes.empty()) {
+      LLVM_DEBUG({ DBGS() << "reifyResultShapes failed\n"; });
+      return failure();
+    }
+
+    SmallVector<Type> newTypes;
+    for (auto [t, reifiedShape] :
+         llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+      ShapedType st = dyn_cast<ShapedType>(t);
+      if (!st)
+        continue;
+
+      SmallVector<int64_t> newShape;
+      for (const auto &[s, ofr] :
+           llvm::zip_equal(st.getShape(), reifiedShape)) {
+        std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+        // Reification does not add static information, just use existing shape.
+        if (!maybeCst.has_value()) {
+          newShape.push_back(s);
+          continue;
+        }
+        int64_t cst = *maybeCst;
+        assert((ShapedType::isDynamic(s) || s == cst) &&
+               "constants must agree!");
+        newShape.push_back(cst);
+      }
+
+      if (newShape == st.getShape()) {
+        newTypes.push_back(t);
+        continue;
+      }
+
+      rewriteToMoreStatic = true;
+      Type newType = st.cloneWith(newShape, st.getElementType());
+      newTypes.push_back(newType);
+    }
+
+    LLVM_DEBUG({
+      DBGS() << "--oldTypes: " << llvm::interleaved_array(op->getResultTypes())
+             << " \n";
+      DBGS() << "--newTypes: " << llvm::interleaved_array(newTypes) << " \n";
+    });
+    if (!rewriteToMoreStatic) {
+      LLVM_DEBUG({ DBGS() << "not more static\n"; });
+      return failure();
+    }
+
+    // We now have newTypes that need to be turned to tensor::CastOp.
+    Location loc = op->getLoc();
+    SmallVector<Value> newResults;
+    Operation *newOp = rewriter.clone(*op);
+    for (auto [nt, oldVal] : llvm::zip(newTypes, op->getResults())) {
+      Type ot = oldVal.getType();
+      OpResult newResult = newOp->getResult(oldVal.getResultNumber());
+      if (ot == nt) {
+        newResults.push_back(newResult);
+        continue;
+      }
+      newResult.setType(nt);
+      if (isa<RankedTensorType>(nt)) {
+        newResults.push_back(
+            rewriter.create<tensor::CastOp>(loc, ot, newResult));
+      } else if (isa<MemRefType>(nt)) {
+        newResults.push_back(
+            rewriter.create<memref::CastOp>(loc, ot, newResult));
+      } else {
+        llvm_unreachable("expected RankedTensorType or MemRefType");
+      }
+    }
+
+    LLVM_DEBUG({
+      op->getParentOp()->dump();
+      DBGS() << "replace op " << *op << "\n";
+      DBGS() << "with newResults " << llvm::interleaved_array(newResults)
+             << "\n\n\n\n";
+    });
+    rewriter.replaceAllOpUsesWith(op, newResults);
+    return success();
+  }
+};
+
 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
 ///
 /// ```
@@ -175,6 +277,11 @@ struct ResolveShapedTypeResultDimsPass final
   void runOnOperation() override;
 };
 
+struct InferStaticShapesPass final
+    : public memref::impl::InferStaticShapesPassBase<InferStaticShapesPass> {
+  void runOnOperation() override;
+};
+
 } // namespace
 
 void memref::populateResolveRankedShapedTypeResultDimsPatterns(
@@ -192,6 +299,11 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
       patterns.getContext());
 }
 
+void memref::populateReifyToInferStaticShapePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ReifyToInferStaticShapePattern>(patterns.getContext());
+}
+
 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -206,3 +318,17 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }
+
+void InferStaticShapesPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  patterns.add<ReifyToInferStaticShapePattern>(&getContext());
+  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+  SmallVector<Operation *> opsToSimplify;
+  getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
+    opsToSimplify.push_back(op);
+  });
+  (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns,
+                                GreedyRewriteConfig().setStrictness(
+                                    GreedyRewriteStrictness::ExistingOps));
+}
diff --git a/mlir/test/Dialect/Tensor/infer-static-shapes.mlir b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
new file mode 100644
index 0000000000000..1712ce7df38b1
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -infer-static-shapes -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL:  func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
+    -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] 
+    : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+//       CHECK: tensor.pad
+//       CHECK:   : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+    ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}

>From 1f4fba78948dec85d558236f15678c10cd289fcc Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Thu, 26 Jun 2025 19:42:35 +0000
Subject: [PATCH 2/5] rename transform to reify-result-shapes

---
 .../mlir/Dialect/MemRef/Transforms/Passes.td  |  39 ++++-
 .../Dialect/MemRef/Transforms/Transforms.h    |  16 +-
 .../Dialect/MemRef/Transforms/CMakeLists.txt  |   1 +
 .../MemRef/Transforms/ReifyResultShapes.cpp   | 144 ++++++++++++++++++
 .../ResolveShapedTypeResultDims.cpp           | 126 ---------------
 .../Dialect/Tensor/infer-static-shapes.mlir   |  18 ---
 mlir/test/Dialect/Tensor/reify-shapes.mlir    |  31 ++++
 7 files changed, 221 insertions(+), 154 deletions(-)
 create mode 100644 mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
 delete mode 100644 mlir/test/Dialect/Tensor/infer-static-shapes.mlir
 create mode 100644 mlir/test/Dialect/Tensor/reify-shapes.mlir

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 2406b47538ddc..4645d49cab2be 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -182,13 +182,40 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
   ];
 }
 
-def InferStaticShapesPass : Pass<"infer-static-shapes"> {
-  let summary = "Resolve memref.dim of result values";
+def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
+  let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations";
   let description = [{
-    The pass resolves memref.dim of result of operations that
-    implement the `InferShapedTypeOpInterface` or
-    `ReifyRankedShapedTypeOpInterface` in terms of shapes of its
-    operands.
+    This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface`
+    operation with ranked `memref` and `tensor` results. Replacing the
+    operations with their reified versions, and inserting casts when results
+    shapes are updated.
+
+    Example:
+    ```mlir
+    #map = affine_map<(d0) -> (-d0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+      %0 = affine.apply #map(%arg1)
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+      return %padded : tensor<1x?x64xf32>
+    }
+
+    // mlir-opt --reify-result-shapes
+    #map = affine_map<()[s0] -> (-s0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+      %0 = affine.apply #map()[%arg1]
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+      %cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+      return %cast : tensor<1x?x64xf32>
+    }
+    ```
   }];
   let dependentDialects = [
     "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index b069d5f284597..5f9f09d7992ca 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -23,6 +23,7 @@ class RewritePatternSet;
 class RewriterBase;
 class Value;
 class ValueRange;
+class ReifyRankedShapedTypeOpInterface;
 
 namespace arith {
 class WideIntEmulationConverter;
@@ -57,10 +58,6 @@ void populateResolveRankedShapedTypeResultDimsPatterns(
 /// terms of shapes of its input operands.
 void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
 
-/// Appends patterns that allow making ReifyRankedShapedTypeOpInterface ops
-/// shapes more static.
-void populateReifyToInferStaticShapePatterns(RewritePatternSet &patterns);
-
 /// Appends patterns for expanding memref operations that modify the metadata
 /// (sizes, offset, strides) of a memref into easier to analyze constructs.
 void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
@@ -213,6 +210,17 @@ memref::AllocaOp allocToAlloca(
     RewriterBase &rewriter, memref::AllocOp alloc,
     function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
 
+/// Reifies the results of `op`, potentially replacing `op` with a reified
+/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
+/// otherwise it always succeeds. Users of this transform should always expect
+/// it to modify the IR, even when it fails. If any of the result types changes,
+/// the transform will insert cast operations to the old type to keep the IR
+/// consistent.
+///
+/// Note: This transform only works on ranked `memref` or `tensor` results,
+/// other types are ignored.
+LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
+                                  ReifyRankedShapedTypeOpInterface op);
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 637f5ec1c9f9b..9049faccadef3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   IndependenceTransforms.cpp
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
+  ReifyResultShapes.cpp
   ResolveShapedTypeResultDims.cpp
   RuntimeOpVerification.cpp
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
new file mode 100644
index 0000000000000..dcb601577f88f
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -0,0 +1,144 @@
+//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface`
+// operations with ranked `memref` and `tensor` results.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/Support/InterleavedRange.h"
+
+#define DEBUG_TYPE "reify-result-shapes"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+LogicalResult
+mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
+                                  ReifyRankedShapedTypeOpInterface op) {
+  LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
+  // Get the reified out shapes.
+  ReifiedRankedShapedTypeDims reifiedResultShapes;
+  if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+      reifiedResultShapes.empty()) {
+    return op.emitError() << "failed to get the reified shapes";
+  }
+
+  bool modified = false;
+  // Compute the new output types.
+  SmallVector<Type> outTypes;
+  for (const auto &[oldTy, reifiedShape] :
+       llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+    // Skip if it's not a memref or tensor type.
+    if (!isa<RankedTensorType, MemRefType>(oldTy)) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+
+    ShapedType shapedTy = dyn_cast<ShapedType>(oldTy);
+
+    SmallVector<int64_t> shape = llvm::to_vector(shapedTy.getShape());
+    for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // If the reified dim is dynamic set it appropriately.
+      if (!maybeCst.has_value()) {
+        dim = ShapedType::kDynamic;
+        continue;
+      }
+      // Set the static dim.
+      dim = *maybeCst;
+    }
+
+    // If the shape didn't change continue.
+    if (shape == shapedTy.getShape()) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+    modified = true;
+    outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
+  }
+
+  // Return if we don't need to update.
+  if (!modified) {
+    LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
+    return success();
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
+           << " \n";
+    DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n";
+  });
+
+  // We now have outTypes that need to be turned to cast ops.
+  Location loc = op->getLoc();
+  SmallVector<Value> newResults;
+  Operation *newOp = rewriter.clone(*op);
+  for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
+    OpResult newRes = newOp->getResult(oldRes.getResultNumber());
+    Type oldTy = oldRes.getType();
+    // Continue if the type remained invariant or is not shaped.
+    if (oldTy == reifiedTy || !isa<MemRefType, RankedTensorType>(oldTy)) {
+      newResults.push_back(newRes);
+      continue;
+    }
+
+    // Update the type.
+    newRes.setType(reifiedTy);
+    if (isa<RankedTensorType>(reifiedTy)) {
+      newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes));
+    } else {
+      assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
+      newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes));
+    }
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- reified results " << llvm::interleaved_array(newResults)
+           << "\n";
+  });
+  rewriter.replaceOp(op, newResults);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ReifyResultShapesPass final
+    : public memref::impl::ReifyResultShapesPassBase<ReifyResultShapesPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ReifyResultShapesPass::runOnOperation() {
+  SmallVector<ReifyRankedShapedTypeOpInterface> ops;
+  getOperation()->walk(
+      [&](ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); });
+  IRRewriter rewriter(&getContext());
+  for (ReifyRankedShapedTypeOpInterface op : ops) {
+    rewriter.setInsertionPoint(op);
+    if (failed(memref::reifyOpResultShapes(rewriter, op)))
+      return signalPassFailure();
+  }
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 919b3fbc95479..89a3895d06ba5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -20,22 +20,13 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/InterleavedRange.h"
-
-#define DEBUG_TYPE "resolve-shaped-type"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 
 namespace mlir {
 namespace memref {
 #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
-#define GEN_PASS_DEF_INFERSTATICSHAPESPASS
 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
 } // namespace memref
 } // namespace mlir
@@ -114,99 +105,6 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
   }
 };
 
-struct ReifyToInferStaticShapePattern
-    : public OpInterfaceRewritePattern<ReifyRankedShapedTypeOpInterface> {
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(ReifyRankedShapedTypeOpInterface op,
-                                PatternRewriter &rewriter) const override {
-    LLVM_DEBUG(
-        { DBGS() << "ReifyToInferStaticShapePattern on " << op << "\n"; });
-
-    bool rewriteToMoreStatic = false;
-    ReifiedRankedShapedTypeDims reifiedResultShapes;
-    if (failed(reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
-        reifiedResultShapes.empty()) {
-      LLVM_DEBUG({ DBGS() << "reifyResultShapes failed\n"; });
-      return failure();
-    }
-
-    SmallVector<Type> newTypes;
-    for (auto [t, reifiedShape] :
-         llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
-      ShapedType st = dyn_cast<ShapedType>(t);
-      if (!st)
-        continue;
-
-      SmallVector<int64_t> newShape;
-      for (const auto &[s, ofr] :
-           llvm::zip_equal(st.getShape(), reifiedShape)) {
-        std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
-        // Reification does not add static information, just use existing shape.
-        if (!maybeCst.has_value()) {
-          newShape.push_back(s);
-          continue;
-        }
-        int64_t cst = *maybeCst;
-        assert((ShapedType::isDynamic(s) || s == cst) &&
-               "constants must agree!");
-        newShape.push_back(cst);
-      }
-
-      if (newShape == st.getShape()) {
-        newTypes.push_back(t);
-        continue;
-      }
-
-      rewriteToMoreStatic = true;
-      Type newType = st.cloneWith(newShape, st.getElementType());
-      newTypes.push_back(newType);
-    }
-
-    LLVM_DEBUG({
-      DBGS() << "--oldTypes: " << llvm::interleaved_array(op->getResultTypes())
-             << " \n";
-      DBGS() << "--newTypes: " << llvm::interleaved_array(newTypes) << " \n";
-    });
-    if (!rewriteToMoreStatic) {
-      LLVM_DEBUG({ DBGS() << "not more static\n"; });
-      return failure();
-    }
-
-    // We now have newTypes that need to be turned to tensor::CastOp.
-    Location loc = op->getLoc();
-    SmallVector<Value> newResults;
-    Operation *newOp = rewriter.clone(*op);
-    for (auto [nt, oldVal] : llvm::zip(newTypes, op->getResults())) {
-      Type ot = oldVal.getType();
-      OpResult newResult = newOp->getResult(oldVal.getResultNumber());
-      if (ot == nt) {
-        newResults.push_back(newResult);
-        continue;
-      }
-      newResult.setType(nt);
-      if (isa<RankedTensorType>(nt)) {
-        newResults.push_back(
-            rewriter.create<tensor::CastOp>(loc, ot, newResult));
-      } else if (isa<MemRefType>(nt)) {
-        newResults.push_back(
-            rewriter.create<memref::CastOp>(loc, ot, newResult));
-      } else {
-        llvm_unreachable("expected RankedTensorType or MemRefType");
-      }
-    }
-
-    LLVM_DEBUG({
-      op->getParentOp()->dump();
-      DBGS() << "replace op " << *op << "\n";
-      DBGS() << "with newResults " << llvm::interleaved_array(newResults)
-             << "\n\n\n\n";
-    });
-    rewriter.replaceAllOpUsesWith(op, newResults);
-    return success();
-  }
-};
-
 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
 ///
 /// ```
@@ -277,11 +175,6 @@ struct ResolveShapedTypeResultDimsPass final
   void runOnOperation() override;
 };
 
-struct InferStaticShapesPass final
-    : public memref::impl::InferStaticShapesPassBase<InferStaticShapesPass> {
-  void runOnOperation() override;
-};
-
 } // namespace
 
 void memref::populateResolveRankedShapedTypeResultDimsPatterns(
@@ -299,11 +192,6 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
       patterns.getContext());
 }
 
-void memref::populateReifyToInferStaticShapePatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ReifyToInferStaticShapePattern>(patterns.getContext());
-}
-
 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
@@ -318,17 +206,3 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }
-
-void InferStaticShapesPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  patterns.add<ReifyToInferStaticShapePattern>(&getContext());
-  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-
-  SmallVector<Operation *> opsToSimplify;
-  getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
-    opsToSimplify.push_back(op);
-  });
-  (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns,
-                                GreedyRewriteConfig().setStrictness(
-                                    GreedyRewriteStrictness::ExistingOps));
-}
diff --git a/mlir/test/Dialect/Tensor/infer-static-shapes.mlir b/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
deleted file mode 100644
index 1712ce7df38b1..0000000000000
--- a/mlir/test/Dialect/Tensor/infer-static-shapes.mlir
+++ /dev/null
@@ -1,18 +0,0 @@
-// RUN: mlir-opt -infer-static-shapes -split-input-file %s | FileCheck %s
-
-// CHECK-LABEL:  func.func @pad_reification
-func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
-    -> tensor<1x?x64xf32> {
-  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
-  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] 
-    : tensor<64x?x64xf32> to tensor<1x?x64xf32>
-
-//       CHECK: tensor.pad
-//       CHECK:   : tensor<1x?x64xf32> to tensor<1x256x64xf32>
-  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
-    ^bb0(%a: index, %b: index, %c: index):
-    tensor.yield %cst : f32
-  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
-
-  return %padded : tensor<1x?x64xf32>
-}
diff --git a/mlir/test/Dialect/Tensor/reify-shapes.mlir b/mlir/test/Dialect/Tensor/reify-shapes.mlir
new file mode 100644
index 0000000000000..5569d90f8b731
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/reify-shapes.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -reify-result-shapes  %s | FileCheck %s
+
+// The test below checks concat op reification. In the first case, no cast is inserted while on the second a cast gets inserted.
+// CHECK-LABEL:  func.func @concat_reification
+func.func @concat_reification(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>)
+  -> (tensor<4x11x3xf32>, tensor<?x?x?xf32>) {
+  // CHECK: %[[RES0:.*]] = tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  // CHECK: %[[V0:.*]] = tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<4x7x?xf32>
+  // CHECK: %[[RES1:.*]] = tensor.cast %[[V0]] : tensor<4x7x?xf32> to tensor<?x?x?xf32>
+  %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  // CHECK: return %[[RES0]], %[[RES1]] : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+  return %1, %2 : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL:  func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] 
+    : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+  // CHECK: tensor.pad
+  // CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+  // CHECK: tensor.cast %{{.*}} : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+    ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}

>From ba51026882343871ebb06ced928b657c82c3a2f7 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Fri, 27 Jun 2025 11:33:21 +0200
Subject: [PATCH 3/5] Don't fail the pass if we can't make shapes more static.

Also, only allow tensor::PadOp and tensor::ConcatOp for now, as more extensive testing showed that
other ops are not ready yet (e.g., tensor::ExtractSliceOp and tensor::InsertSliceOp).
---
 .../MemRef/Transforms/ReifyResultShapes.cpp        | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
index dcb601577f88f..b00a0f2103d43 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -40,7 +40,7 @@ mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
   ReifiedRankedShapedTypeDims reifiedResultShapes;
   if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
       reifiedResultShapes.empty()) {
-    return op.emitError() << "failed to get the reified shapes";
+    return op->emitWarning() << "failed to get the reified shapes";
   }
 
   bool modified = false;
@@ -133,12 +133,16 @@ struct ReifyResultShapesPass final
 
 void ReifyResultShapesPass::runOnOperation() {
   SmallVector<ReifyRankedShapedTypeOpInterface> ops;
-  getOperation()->walk(
-      [&](ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); });
+  getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
+    // Some ops have rigid type checkers and need to update their operands.
+    // Only admit the ones that are explicitly supported for now.
+    if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
+      return;
+    ops.push_back(op);
+  });
   IRRewriter rewriter(&getContext());
   for (ReifyRankedShapedTypeOpInterface op : ops) {
     rewriter.setInsertionPoint(op);
-    if (failed(memref::reifyOpResultShapes(rewriter, op)))
-      return signalPassFailure();
+    (void)memref::reifyOpResultShapes(rewriter, op);
   }
 }

>From 08a68236424bd424d82cbee3cdbaf8cf00923dcc Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Mon, 30 Jun 2025 15:36:31 +0200
Subject: [PATCH 4/5] Update pass documentation

---
 .../mlir/Dialect/MemRef/Transforms/Passes.td  | 40 ++++++++++++++-----
 .../MemRef/Transforms/ReifyResultShapes.cpp   |  7 +++-
 2 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 4645d49cab2be..f3e40aaa29075 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -183,19 +183,38 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
 }
 
 def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
-  let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations";
+  let summary ="Reifies the results of `tensor::PadOp` and `tensor::ConcatOp`.";
   let description = [{
-    This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface`
-    operation with ranked `memref` and `tensor` results. Replacing the
-    operations with their reified versions, and inserting casts when results
-    shapes are updated.
+    This pass reifies the shapes of a subset of `ReifyRankedShapedTypeOpInterface`
+    ops with `tensor` results.
+    
+    The pass currently only supports result shape type reification for:
+      - tensor::PadOp
+      - tensor::ConcatOp
+    It addresses a representation gap where implicit op semantics are needed to
+    infer static result types from dynamic operands.
+    But it does so by using `ReifyRankedShapedTypeOpInterface` as the source of
+    truth rather than the op itself. As a consequence, this cannot generalize
+    today.
+
+    TODO: in the future, we should consider coupling this information with op
+    "transfer functions" (e.g. `IndexingMapOpInterface`) to provide a source of
+    truth that can work across result shape inference, canonicalization and op
+    verifiers.
+
+    The pass replaces the operations with their reified versions when more
+    static information can be derived, and inserts casts back to the original
+    types when result shapes are updated.
 
     Example:
     ```mlir
     #map = affine_map<(d0) -> (-d0 + 256)>
-    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
+        -> tensor<1x?x64xf32>
+    {
       %0 = affine.apply #map(%arg1)
-      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
+        : tensor<64x?x64xf32> to tensor<1x?x64xf32>
       %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
       ^bb0(%arg3: index, %arg4: index, %arg5: index):
         tensor.yield %arg0 : f32
@@ -205,9 +224,12 @@ def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
 
     // mlir-opt --reify-result-shapes
     #map = affine_map<()[s0] -> (-s0 + 256)>
-    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
+        -> tensor<1x?x64xf32>
+    {
       %0 = affine.apply #map()[%arg1]
-      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
+        : tensor<64x?x64xf32> to tensor<1x?x64xf32>
       %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
       ^bb0(%arg3: index, %arg4: index, %arg5: index):
         tensor.yield %arg0 : f32
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
index b00a0f2103d43..0a8aacb9d15a1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "llvm/Support/InterleavedRange.h"
 
@@ -134,8 +135,10 @@ struct ReifyResultShapesPass final
 void ReifyResultShapesPass::runOnOperation() {
   SmallVector<ReifyRankedShapedTypeOpInterface> ops;
   getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
-    // Some ops have rigid type checkers and need to update their operands.
-    // Only admit the ones that are explicitly supported for now.
+    // Handle ops that are not DPS and that do not carry an tied operand shapes.
+    // For now, limit to tensor::PadOp and tensor::ConcatOp.
+    if (isa<DestinationStyleOpInterface>(op.getOperation()))
+      return;
     if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
       return;
     ops.push_back(op);

>From b45abc7dcb48f84fad4d49b491fbd200eac71ab1 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Tue, 1 Jul 2025 15:30:26 +0200
Subject: [PATCH 5/5] Address review comments

---
 .../Dialect/MemRef/Transforms/Transforms.h    | 12 -----------
 .../MemRef/Transforms/ReifyResultShapes.cpp   | 20 +++++++++++++------
 2 files changed, 14 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 5f9f09d7992ca..33e3d94f02b1c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -209,18 +209,6 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
 memref::AllocaOp allocToAlloca(
     RewriterBase &rewriter, memref::AllocOp alloc,
     function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
-
-/// Reifies the results of `op`, potentially replacing `op` with a reified
-/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
-/// otherwise it always succeeds. Users of this transform should always expect
-/// it to modify the IR, even when it fails. If any of the result types changes,
-/// the transform will insert cast operations to the old type to keep the IR
-/// consistent.
-///
-/// Note: This transform only works on ranked `memref` or `tensor` results,
-/// other types are ignored.
-LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
-                                  ReifyRankedShapedTypeOpInterface op);
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
index 0a8aacb9d15a1..e6b9e2f7e8213 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -33,9 +33,14 @@ namespace memref {
 
 using namespace mlir;
 
-LogicalResult
-mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
-                                  ReifyRankedShapedTypeOpInterface op) {
+/// Reifies the results of `op`, potentially replacing `op` with a reified
+/// version. Returns `failure` if `mlir::reifyResultShapes` returned failure,
+/// otherwise it always succeeds. Users of this transform should always expect
+/// it to modify the IR, even when it fails. If any of the result types changes,
+/// the transform will insert cast operations to the old type to keep the IR
+/// consistent.
+static LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
+                                         ReifyRankedShapedTypeOpInterface op) {
   LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
   // Get the reified out shapes.
   ReifiedRankedShapedTypeDims reifiedResultShapes;
@@ -93,6 +98,11 @@ mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
   // We now have outTypes that need to be turned to cast ops.
   Location loc = op->getLoc();
   SmallVector<Value> newResults;
+  // TODO: `mlir::reifyResultShapes` and op verifiers may not agree atm.
+  // This is a confluence problem that will need to be addressed.
+  // For now, we know PadOp and ConcatOp are fine.
+  assert((isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation())) &&
+         "incorrect op");
   Operation *newOp = rewriter.clone(*op);
   for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
     OpResult newRes = newOp->getResult(oldRes.getResultNumber());
@@ -137,8 +147,6 @@ void ReifyResultShapesPass::runOnOperation() {
   getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
     // Handle ops that are not DPS and that do not carry an tied operand shapes.
     // For now, limit to tensor::PadOp and tensor::ConcatOp.
-    if (isa<DestinationStyleOpInterface>(op.getOperation()))
-      return;
     if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
       return;
     ops.push_back(op);
@@ -146,6 +154,6 @@ void ReifyResultShapesPass::runOnOperation() {
   IRRewriter rewriter(&getContext());
   for (ReifyRankedShapedTypeOpInterface op : ops) {
     rewriter.setInsertionPoint(op);
-    (void)memref::reifyOpResultShapes(rewriter, op);
+    (void)reifyOpResultShapes(rewriter, op);
   }
 }



More information about the Mlir-commits mailing list