[Mlir-commits] [mlir] [mlir][memref] Add a new `ReifyResultShapes` pass (PR #145927)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 26 12:49:22 PDT 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir

Author: Nicolas Vasilache (nicolasvasilache)

<details>
<summary>Changes</summary>

  This patch introduces the `ReifyResultShapes` pass. The pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface` operation with ranked `memref` and `tensor` results, replacing the operations with their reified versions and inserting casts when result shapes change.

  Example:
  ```mlir
  #map = affine_map<(d0) -> (-d0 + 256)>
  func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
    %0 = affine.apply #map(%arg1)
    %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
    %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index):
      tensor.yield %arg0 : f32
    } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
    return %padded : tensor<1x?x64xf32>
  }

  // mlir-opt --reify-result-shapes
  #map = affine_map<()[s0] -> (-s0 + 256)>
  func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
    %0 = affine.apply #map()[%arg1]
    %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
    %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
    ^bb0(%arg3: index, %arg4: index, %arg5: index):
      tensor.yield %arg0 : f32
    } : tensor<1x?x64xf32> to tensor<1x256x64xf32>
    %cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
    return %cast : tensor<1x?x64xf32>
  }
  ```
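
  To try the pass programmatically rather than through `mlir-opt`, a minimal driver could look like the sketch below. This assumes the `memref::createReifyResultShapesPass()` creator generated from the tablegen definition in this patch; the surrounding wrapper is illustrative only:

  ```cpp
  // Minimal sketch: run the new pass over a module in-process.
  // Assumes the pass creator generated from the `ReifyResultShapesPass`
  // tablegen definition added by this patch.
  #include "mlir/Dialect/MemRef/Transforms/Passes.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"

  static mlir::LogicalResult runReifyResultShapes(mlir::ModuleOp module) {
    mlir::PassManager pm(module.getContext());
    // Equivalent to `mlir-opt --reify-result-shapes`.
    pm.addPass(mlir::memref::createReifyResultShapesPass());
    return pm.run(module);
  }
  ```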

---
Full diff: https://github.com/llvm/llvm-project/pull/145927.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td (+40) 
- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h (+12) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp (+144) 
- (added) mlir/test/Dialect/Tensor/reify-shapes.mlir (+31) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a8d135caa74f0..4645d49cab2be 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -182,6 +182,46 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
   ];
 }
 
+def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
+  let summary = "Reifies the results of all `ReifyRankedShapedTypeOpInterface` operations";
+  let description = [{
+    This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface`
+    operation with ranked `memref` and `tensor` results, replacing the
+    operations with their reified versions and inserting casts when result
+    shapes change.
+
+    Example:
+    ```mlir
+    #map = affine_map<(d0) -> (-d0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+      %0 = affine.apply #map(%arg1)
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+      return %padded : tensor<1x?x64xf32>
+    }
+
+    // mlir-opt --reify-result-shapes
+    #map = affine_map<()[s0] -> (-s0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+      %0 = affine.apply #map()[%arg1]
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+      %cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+      return %cast : tensor<1x?x64xf32>
+    }
+    ```
+  }];
+  let dependentDialects = [
+    "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
+  ];
+}
+
 def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> {
   let summary = "Expand memref operations into easier to analyze constructs";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index c2b8cb05be922..5f9f09d7992ca 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -23,6 +23,7 @@ class RewritePatternSet;
 class RewriterBase;
 class Value;
 class ValueRange;
+class ReifyRankedShapedTypeOpInterface;
 
 namespace arith {
 class WideIntEmulationConverter;
@@ -209,6 +210,17 @@ memref::AllocaOp allocToAlloca(
     RewriterBase &rewriter, memref::AllocOp alloc,
     function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
 
+/// Reifies the results of `op`, potentially replacing `op` with a reified
+/// version. Returns `failure` if `mlir::reifyResultShapes` fails; otherwise
+/// it always succeeds. Users of this transform should always expect it to
+/// modify the IR, even when it fails. If any of the result types change, the
+/// transform inserts cast operations back to the old types to keep the IR
+/// consistent.
+///
+/// Note: This transform only works on ranked `memref` or `tensor` results,
+/// other types are ignored.
+LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
+                                  ReifyRankedShapedTypeOpInterface op);
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 637f5ec1c9f9b..9049faccadef3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   IndependenceTransforms.cpp
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
+  ReifyResultShapes.cpp
   ResolveShapedTypeResultDims.cpp
   RuntimeOpVerification.cpp
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
new file mode 100644
index 0000000000000..dcb601577f88f
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -0,0 +1,144 @@
+//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface`
+// operations with ranked `memref` and `tensor` results.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/Support/InterleavedRange.h"
+
+#define DEBUG_TYPE "reify-result-shapes"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+LogicalResult
+mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
+                                  ReifyRankedShapedTypeOpInterface op) {
+  LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
+  // Get the reified out shapes.
+  ReifiedRankedShapedTypeDims reifiedResultShapes;
+  if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+      reifiedResultShapes.empty()) {
+    return op.emitError() << "failed to get the reified shapes";
+  }
+
+  bool modified = false;
+  // Compute the new output types.
+  SmallVector<Type> outTypes;
+  for (const auto &[oldTy, reifiedShape] :
+       llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+    // Skip if it's not a memref or tensor type.
+    if (!isa<RankedTensorType, MemRefType>(oldTy)) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+
+    ShapedType shapedTy = cast<ShapedType>(oldTy);
+
+    SmallVector<int64_t> shape = llvm::to_vector(shapedTy.getShape());
+    for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // If the reified dim has no constant value, mark it dynamic.
+      if (!maybeCst.has_value()) {
+        dim = ShapedType::kDynamic;
+        continue;
+      }
+      // Set the static dim.
+      dim = *maybeCst;
+    }
+
+    // If the shape didn't change, continue.
+    if (shape == shapedTy.getShape()) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+    modified = true;
+    outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
+  }
+
+  // Return if we don't need to update.
+  if (!modified) {
+    LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
+    return success();
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
+           << " \n";
+    DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n";
+  });
+
+  // Clone the op; where a result type changed, cast the new result back to
+  // the old type.
+  Location loc = op->getLoc();
+  SmallVector<Value> newResults;
+  Operation *newOp = rewriter.clone(*op);
+  for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
+    OpResult newRes = newOp->getResult(oldRes.getResultNumber());
+    Type oldTy = oldRes.getType();
+    // Continue if the type remained invariant or is not shaped.
+    if (oldTy == reifiedTy || !isa<MemRefType, RankedTensorType>(oldTy)) {
+      newResults.push_back(newRes);
+      continue;
+    }
+
+    // Update the type.
+    newRes.setType(reifiedTy);
+    if (isa<RankedTensorType>(reifiedTy)) {
+      newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes));
+    } else {
+      assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
+      newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes));
+    }
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- reified results " << llvm::interleaved_array(newResults)
+           << "\n";
+  });
+  rewriter.replaceOp(op, newResults);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ReifyResultShapesPass final
+    : public memref::impl::ReifyResultShapesPassBase<ReifyResultShapesPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ReifyResultShapesPass::runOnOperation() {
+  SmallVector<ReifyRankedShapedTypeOpInterface> ops;
+  getOperation()->walk(
+      [&](ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); });
+  IRRewriter rewriter(&getContext());
+  for (ReifyRankedShapedTypeOpInterface op : ops) {
+    rewriter.setInsertionPoint(op);
+    if (failed(memref::reifyOpResultShapes(rewriter, op)))
+      return signalPassFailure();
+  }
+}
diff --git a/mlir/test/Dialect/Tensor/reify-shapes.mlir b/mlir/test/Dialect/Tensor/reify-shapes.mlir
new file mode 100644
index 0000000000000..5569d90f8b731
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/reify-shapes.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -reify-result-shapes %s | FileCheck %s
+
+// The test below checks concat op reification: in the first case no cast is needed, while in the second a cast is inserted.
+// CHECK-LABEL:  func.func @concat_reification
+func.func @concat_reification(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>)
+  -> (tensor<4x11x3xf32>, tensor<?x?x?xf32>) {
+  // CHECK: %[[RES0:.*]] = tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  // CHECK: %[[V0:.*]] = tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<4x7x?xf32>
+  // CHECK: %[[RES1:.*]] = tensor.cast %[[V0]] : tensor<4x7x?xf32> to tensor<?x?x?xf32>
+  %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  // CHECK: return %[[RES0]], %[[RES1]] : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+  return %1, %2 : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL:  func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] 
+    : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+  // CHECK: tensor.pad
+  // CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+  // CHECK: tensor.cast %{{.*}} : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+    ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}

``````````
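
The standalone `memref::reifyOpResultShapes` entry point can also be driven outside the pass. The sketch below mirrors the `runOnOperation` driver from the diff; only the wrapper function is new and illustrative:

```cpp
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

using namespace mlir;

// Collect the ops first, then rewrite: reifyOpResultShapes replaces the
// matched op, so rewriting during the walk would invalidate the traversal.
static LogicalResult reifyResultShapesBelow(Operation *root) {
  SmallVector<ReifyRankedShapedTypeOpInterface> ops;
  root->walk([&](ReifyRankedShapedTypeOpInterface op) { ops.push_back(op); });
  IRRewriter rewriter(root->getContext());
  for (ReifyRankedShapedTypeOpInterface op : ops) {
    rewriter.setInsertionPoint(op);
    // Per the Transforms.h contract, the IR may already have been modified
    // even when this returns failure.
    if (failed(memref::reifyOpResultShapes(rewriter, op)))
      return failure();
  }
  return success();
}
```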

</details>


https://github.com/llvm/llvm-project/pull/145927


More information about the Mlir-commits mailing list