[Mlir-commits] [mlir] [mlir][memref] Add a new `ReifyResultShapes` pass (PR #145927)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 26 12:49:22 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Nicolas Vasilache (nicolasvasilache)
<details>
<summary>Changes</summary>
This patch introduces the `ReifyResultShapes` pass. This pass reifies the shapes of every `ReifyRankedShapedTypeOpInterface` operation with ranked `memref` and `tensor` results, replacing the operations with their reified versions and inserting casts when result shapes are updated.
Example:
```mlir
#map = affine_map<(d0) -> (-d0 + 256)>
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
%0 = affine.apply #map(%arg1)
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
^bb0(%arg3: index, %arg4: index, %arg5: index):
tensor.yield %arg0 : f32
} : tensor<1x?x64xf32> to tensor<1x?x64xf32>
return %padded : tensor<1x?x64xf32>
}
// mlir-opt --reify-result-shapes
#map = affine_map<()[s0] -> (-s0 + 256)>
func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
%0 = affine.apply #map()[%arg1]
%extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
%padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
^bb0(%arg3: index, %arg4: index, %arg5: index):
tensor.yield %arg0 : f32
} : tensor<1x?x64xf32> to tensor<1x256x64xf32>
%cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
return %cast : tensor<1x?x64xf32>
}
---
Full diff: https://github.com/llvm/llvm-project/pull/145927.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td (+40)
- (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h (+12)
- (modified) mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp (+144)
- (added) mlir/test/Dialect/Tensor/reify-shapes.mlir (+31)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a8d135caa74f0..4645d49cab2be 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -182,6 +182,46 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
];
}
+def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
+  let summary = "Reifies the result shapes of `ReifyRankedShapedTypeOpInterface` operations";
+  let description = [{
+    This pass reifies the result shapes of `ReifyRankedShapedTypeOpInterface`
+    operations with ranked `memref` and `tensor` results. Operations whose
+    result types change are replaced with their reified versions, and casts
+    back to the original result types are inserted so that existing users are
+    unaffected.
+
+    Example:
+    ```mlir
+    #map = affine_map<(d0) -> (-d0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+      %0 = affine.apply #map(%arg1)
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+      return %padded : tensor<1x?x64xf32>
+    }
+
+    // mlir-opt --reify-result-shapes
+    #map = affine_map<()[s0] -> (-s0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+      %0 = affine.apply #map()[%arg1]
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+      %cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+      return %cast : tensor<1x?x64xf32>
+    }
+    ```
+  }];
+  let dependentDialects = [
+    "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
+  ];
+}
+
def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> {
let summary = "Expand memref operations into easier to analyze constructs";
let description = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index c2b8cb05be922..5f9f09d7992ca 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -23,6 +23,7 @@ class RewritePatternSet;
class RewriterBase;
class Value;
class ValueRange;
+class ReifyRankedShapedTypeOpInterface;
namespace arith {
class WideIntEmulationConverter;
@@ -209,6 +210,17 @@ memref::AllocaOp allocToAlloca(
RewriterBase &rewriter, memref::AllocOp alloc,
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
+/// Reifies the result shapes of `op`. If the type of any ranked `memref` or
+/// `tensor` result changes as a consequence, `op` is replaced with a clone
+/// whose results carry the reified types. Cast operations back to the old
+/// result types are inserted so that existing users keep seeing the original
+/// types.
+///
+/// Returns `failure` and emits an error on `op` if the result shapes could not
+/// be reified. Otherwise it succeeds, whether or not `op` was replaced. Users
+/// of this transform should always expect it to modify the IR, even when it
+/// fails, since reification is performed with `rewriter` and may create ops.
+///
+/// Note: This transform only works on ranked `memref` or `tensor` results;
+/// other result types are left unchanged.
+///
+/// NOTE(review): retyping an op's results is only valid if the op's verifier
+/// accepts the refined result type; confirm this for each op kind passed in.
+LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
+ ReifyRankedShapedTypeOpInterface op);
} // namespace memref
} // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 637f5ec1c9f9b..9049faccadef3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
IndependenceTransforms.cpp
MultiBuffer.cpp
NormalizeMemRefs.cpp
+ ReifyResultShapes.cpp
ResolveShapedTypeResultDims.cpp
RuntimeOpVerification.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
new file mode 100644
index 0000000000000..dcb601577f88f
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -0,0 +1,144 @@
+//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface`
+// operations with ranked `memref` and `tensor` results.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/Support/InterleavedRange.h"
+
+#define DEBUG_TYPE "reify-result-shapes"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+/// Reifies the result shapes of `op` and, when that yields a more precise
+/// type for any ranked `tensor`/`memref` result, replaces `op` with a clone
+/// whose results carry the refined types. Casts back to the original types are
+/// inserted so that existing users keep seeing the types they expect.
+///
+/// Only dynamic dimensions are refined. A dimension that is already static is
+/// never loosened or changed, so every new type is at least as static as, and
+/// cast-compatible with, the type it replaces.
+///
+/// Returns failure (after emitting an error on `op`) if reification fails or
+/// does not produce a shape for every shaped result. The IR may have been
+/// modified even on failure, since reification builds ops with `rewriter`.
+LogicalResult
+mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
+                                  ReifyRankedShapedTypeOpInterface op) {
+  LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
+  // Get the reified out shapes.
+  ReifiedRankedShapedTypeDims reifiedResultShapes;
+  if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)))
+    return op.emitError() << "failed to get the reified shapes";
+
+  bool modified = false;
+  // Compute the new output types. Per the `reifyResultShapes` contract,
+  // `reifiedResultShapes` holds one entry per *shaped* result only, so it is
+  // indexed separately from the op results rather than zipped with them.
+  SmallVector<Type> outTypes;
+  outTypes.reserve(op->getNumResults());
+  size_t shapedIdx = 0;
+  for (Type oldTy : op->getResultTypes()) {
+    // Non-shaped results have no reified entry and are left untouched.
+    if (!isa<ShapedType>(oldTy)) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+    if (shapedIdx >= reifiedResultShapes.size())
+      return op.emitError() << "failed to get the reified shapes";
+    ArrayRef<OpFoldResult> reifiedShape = reifiedResultShapes[shapedIdx++];
+
+    // Skip shaped results that are not ranked memrefs or tensors; they still
+    // consume their reified entry above.
+    if (!isa<RankedTensorType, MemRefType>(oldTy)) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+
+    auto shapedTy = cast<ShapedType>(oldTy);
+    SmallVector<int64_t> shape = llvm::to_vector(shapedTy.getShape());
+    for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
+      // Never loosen or alter a dimension that is already static.
+      if (!ShapedType::isDynamic(dim))
+        continue;
+      // Refine the dynamic dim when reification proved it to be a constant.
+      if (std::optional<int64_t> maybeCst = getConstantIntValue(ofr))
+        dim = *maybeCst;
+    }
+
+    // If the shape didn't change, keep the original type.
+    if (shape == shapedTy.getShape()) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+    modified = true;
+    outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
+  }
+
+  // Return if we don't need to update.
+  if (!modified) {
+    LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
+    return success();
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
+           << " \n";
+    DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n";
+  });
+
+  // Clone the op, retype the refined results of the clone, and cast each
+  // refined result back to its original type for the existing users.
+  Location loc = op->getLoc();
+  Operation *newOp = rewriter.clone(*op);
+  SmallVector<Value> newResults;
+  newResults.reserve(op->getNumResults());
+  for (auto [reifiedTy, oldRes] : llvm::zip_equal(outTypes, op->getResults())) {
+    OpResult newRes = newOp->getResult(oldRes.getResultNumber());
+    Type oldTy = oldRes.getType();
+    // Only ranked tensor/memref results can differ from their old type.
+    if (oldTy == reifiedTy) {
+      newResults.push_back(newRes);
+      continue;
+    }
+
+    // Update the type, notifying the rewriter of the in-place modification.
+    rewriter.modifyOpInPlace(newOp, [&] { newRes.setType(reifiedTy); });
+    if (isa<RankedTensorType>(reifiedTy)) {
+      newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes));
+    } else {
+      assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
+      newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes));
+    }
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- reified results " << llvm::interleaved_array(newResults)
+           << "\n";
+  });
+  rewriter.replaceOp(op, newResults);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pass driver: collects the supported ops up front, so that replacing them
+/// does not invalidate the walk, then reifies each one in turn.
+struct ReifyResultShapesPass final
+    : public memref::impl::ReifyResultShapesPassBase<ReifyResultShapesPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ReifyResultShapesPass::runOnOperation() {
+  SmallVector<ReifyRankedShapedTypeOpInterface> ops;
+  getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
+    // Retyping results in place is only valid for ops whose verifier accepts
+    // a result type that is more static than the one implied by their
+    // operands. That does not hold in general: e.g. a `tensor.empty` whose
+    // dynamic-size operand is a constant would get a static result type while
+    // still carrying a dynamic-size operand, and DPS ops tie their result
+    // types to their inits. For now, only handle ops known to tolerate the
+    // refinement.
+    if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
+      return;
+    ops.push_back(op);
+  });
+  IRRewriter rewriter(&getContext());
+  for (ReifyRankedShapedTypeOpInterface op : ops) {
+    rewriter.setInsertionPoint(op);
+    if (failed(memref::reifyOpResultShapes(rewriter, op)))
+      return signalPassFailure();
+  }
+}
diff --git a/mlir/test/Dialect/Tensor/reify-shapes.mlir b/mlir/test/Dialect/Tensor/reify-shapes.mlir
new file mode 100644
index 0000000000000..5569d90f8b731
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/reify-shapes.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -reify-result-shapes %s | FileCheck %s
+
+// Tests `tensor.concat` reification. The first concat already has a fully
+// static result type, so it is left untouched and no cast is inserted. The
+// second concat's result type is refined, so the op is retyped and a
+// `tensor.cast` back to the original type is inserted.
+// CHECK-LABEL: func.func @concat_reification
+func.func @concat_reification(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>)
+ -> (tensor<4x11x3xf32>, tensor<?x?x?xf32>) {
+ // CHECK: %[[RES0:.*]] = tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+ %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+ // CHECK: %[[V0:.*]] = tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<4x7x?xf32>
+ // CHECK: %[[RES1:.*]] = tensor.cast %[[V0]] : tensor<4x7x?xf32> to tensor<?x?x?xf32>
+ %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ // CHECK: return %[[RES0]], %[[RES1]] : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+ return %1, %2 : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+}
+
+// Tests `tensor.pad` reification: the source has `%idx` rows in dim 1 and the
+// high padding is `256 - %idx`, so the padded dim 1 reifies to 256. The pad is
+// retyped to a static shape, and the cast back to the original type is what
+// gets returned.
+// CHECK-LABEL: func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1]
+      : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+  // CHECK: %[[PAD:.*]] = tensor.pad
+  // CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+  // CHECK: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+  // CHECK: return %[[CAST]] : tensor<1x?x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+  ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/145927
More information about the Mlir-commits
mailing list