[Mlir-commits] [mlir] [mlir][memref] Add a new `ReifyResultShapes` pass (PR #145927)

Kunwar Grover llvmlistbot at llvm.org
Tue Jul 1 06:17:45 PDT 2025


================
@@ -0,0 +1,151 @@
+//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface`
+// operations with ranked `memref` and `tensor` results.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/Support/InterleavedRange.h"
+
+#define DEBUG_TYPE "reify-result-shapes"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+LogicalResult
+mlir::memref::reifyOpResultShapes(RewriterBase &rewriter,
+                                  ReifyRankedShapedTypeOpInterface op) {
+  LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
+  // Get the reified out shapes.
+  ReifiedRankedShapedTypeDims reifiedResultShapes;
+  if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+      reifiedResultShapes.empty()) {
+    return op->emitWarning() << "failed to get the reified shapes";
+  }
+
+  bool modified = false;
+  // Compute the new output types.
+  SmallVector<Type> outTypes;
+  for (const auto &[oldTy, reifiedShape] :
+       llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+    // Skip if it's not a memref or tensor type.
+    if (!isa<RankedTensorType, MemRefType>(oldTy)) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+
+    ShapedType shapedTy = dyn_cast<ShapedType>(oldTy);
+
+    SmallVector<int64_t> shape = llvm::to_vector(shapedTy.getShape());
+    for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // If the reified dim is dynamic set it appropriately.
+      if (!maybeCst.has_value()) {
+        dim = ShapedType::kDynamic;
+        continue;
+      }
+      // Set the static dim.
+      dim = *maybeCst;
+    }
+
+    // If the shape didn't change continue.
+    if (shape == shapedTy.getShape()) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+    modified = true;
+    outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
+  }
+
+  // Return if we don't need to update.
+  if (!modified) {
+    LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
+    return success();
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
+           << " \n";
+    DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << " \n";
+  });
+
+  // We now have outTypes that need to be turned to cast ops.
+  Location loc = op->getLoc();
+  SmallVector<Value> newResults;
+  Operation *newOp = rewriter.clone(*op);
----------------
Groverkss wrote:

Add a comment here noting that this may not be safe for all operations and currently only works for `tensor.pad` and `tensor.concat`.

https://github.com/llvm/llvm-project/pull/145927


More information about the Mlir-commits mailing list