[Mlir-commits] [mlir] [mlir][memref] Add a new `ReifyResultShapes` pass (PR #145927)
Kunwar Grover
llvmlistbot at llvm.org
Tue Jul 1 06:17:45 PDT 2025
================
@@ -0,0 +1,151 @@
+//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface`
+// operations with ranked `memref` and `tensor` results.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/Support/InterleavedRange.h"
+
+#define DEBUG_TYPE "reify-result-shapes"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+LogicalResult
+mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
+ ReifyRankedShapedTypeOpInterface op) {
+ LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
+ // Get the reified out shapes.
+ ReifiedRankedShapedTypeDims reifiedResultShapes;
+ if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+ reifiedResultShapes.empty()) {
+ return op->emitWarning() << "failed to get the reified shapes";
+ }
+
+ bool modified = false;
+ // Compute the new output types.
+ SmallVector<Type> outTypes;
+ for (const auto &[oldTy, reifiedShape] :
+ llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+ // Skip if it's not a memref or tensor type.
+ if (!isa<RankedTensorType, MemRefType>(oldTy)) {
+ outTypes.push_back(oldTy);
+ continue;
+ }
+
+ ShapedType shapedTy = dyn_cast<ShapedType>(oldTy);
+
+ SmallVector<int64_t> shape = llvm::to_vector(shapedTy.getShape());
+ for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
+ std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+ // If the reified dim is dynamic set it appropriately.
+ if (!maybeCst.has_value()) {
+ dim = ShapedType::kDynamic;
+ continue;
+ }
+ // Set the static dim.
+ dim = *maybeCst;
+ }
+
+ // If the shape didn't change continue.
+ if (shape == shapedTy.getShape()) {
+ outTypes.push_back(oldTy);
+ continue;
+ }
+ modified = true;
+ outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
+ }
+
+ // Return if we don't need to update.
+ if (!modified) {
+ LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
+ return success();
+ }
+
+ LLVM_DEBUG({
+ DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
+ << " \n";
+ DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n";
+ });
+
+ // We now have outTypes that need to be turned to cast ops.
+ Location loc = op->getLoc();
+ SmallVector<Value> newResults;
+ Operation *newOp = rewriter.clone(*op);
----------------
Groverkss wrote:
Add a comment here that this may not be safe for all operations and currently only works for tensor.pad and tensor.concat.
https://github.com/llvm/llvm-project/pull/145927
More information about the Mlir-commits mailing list