[Mlir-commits] [mlir] [mlir][memref] Refactor `ViewOpShapeFolder` (PR #176567)
Longsheng Mou
llvmlistbot at llvm.org
Sat Jan 17 05:11:26 PST 2026
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/176567
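This PR makes the following changes to `ViewOpShapeFolder`:
- Add documentation comments for `ViewOpShapeFolder`.
- Drop the redundant offset check on the result type.
- Simplify the implementation by introducing the `foldDynamicToStaticDimSizes` helper.
- Leave a size dynamic when its constant value is negative, instead of folding it into the result shape.
- Add tests for folding a constant size and for keeping a negative constant size dynamic.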
>From 204987f4f34db5b201fac941e9f7e6834cc93797 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 17 Jan 2026 21:01:45 +0800
Subject: [PATCH] [mlir][memref] Refactor `ViewOpShapeFolder`
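This PR makes the following changes to `ViewOpShapeFolder`:
- Add documentation comments for `ViewOpShapeFolder`.
- Drop the redundant offset check on the result type.
- Simplify the implementation by introducing the `foldDynamicToStaticDimSizes` helper.
- Leave a size dynamic when its constant value is negative, instead of folding it into the result shape.
- Add tests for folding a constant size and for keeping a negative constant size dynamic.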
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 106 ++++++++++-----------
mlir/test/Dialect/MemRef/canonicalize.mlir | 28 ++++++
2 files changed, 80 insertions(+), 54 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e0f7a8b452a1d..48636a95fd989 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3763,71 +3763,69 @@ OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
}
+/// Given a memref type and the values defining its dynamic dimension sizes,
+/// return a new memref type in which every dynamic size defined by a
+/// non-negative constant becomes a static size. Sizes that are not constant,
+/// or whose constant value is negative, stay dynamic and are appended, in
+/// order, to `foldedDynamicSizes`; these are exactly the dynamic sizes the
+/// returned type still requires. If nothing folds, the returned type is equal
+/// to `type`.
+static MemRefType
+foldDynamicToStaticDimSizes(MemRefType type, ValueRange dynamicSizes,
+                            SmallVector<Value> &foldedDynamicSizes) {
+  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
+         "incorrect number of dynamic sizes");
+  SmallVector<int64_t> staticShape(type.getShape());
+
+  // Walk the shape, consuming one entry of `dynamicSizes` per dynamic dim.
+  unsigned dynamicIdx = 0;
+  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
+    if (!type.isDynamicDim(i))
+      continue;
+    Value dynamicSize = dynamicSizes[dynamicIdx++];
+    std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
+    // A negative constant is not a valid static dimension size, so such a
+    // size (like any non-constant one) is kept as a dynamic operand.
+    if (cst && *cst >= 0)
+      staticShape[i] = *cst;
+    else
+      foldedDynamicSizes.push_back(dynamicSize);
+  }
+
+  return MemRefType::Builder(type).setShape(staticShape);
+}
+
namespace {
+
+/// Make the result type of a `memref.view` more static: every dynamic
+/// dimension whose size operand is a non-negative constant becomes a static
+/// dimension, and that size operand is dropped. A `memref.cast` back to the
+/// original result type is inserted so existing users see the same type.
+/// Example:
+/// ```
+///   %c5 = arith.constant 5 : index
+///   %0 = memref.view %src[%offset][%c5] : memref<?xi8> to memref<?x4xf32>
+/// ```
+/// is rewritten to
+/// ```
+///   %v = memref.view %src[%offset][] : memref<?xi8> to memref<5x4xf32>
+///   %0 = memref.cast %v : memref<5x4xf32> to memref<?x4xf32>
+/// ```
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
using Base::Base;
LogicalResult matchAndRewrite(ViewOp viewOp,
PatternRewriter &rewriter) const override {
- // Return if none of the operands are constants.
- if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
- return matchPattern(operand, matchConstantIndex());
- }))
- return failure();
-
- // Get result memref type.
- auto memrefType = viewOp.getType();
-
- // Get offset from old memref view type 'memRefType'.
- int64_t oldOffset;
- SmallVector<int64_t, 4> oldStrides;
- if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
- return failure();
- assert(oldOffset == 0 && "Expected 0 offset");
-
- SmallVector<Value, 4> newOperands;
-
- // Offset cannot be folded into result type.
-
- // Fold any dynamic dim operands which are produced by a constant.
- SmallVector<int64_t, 4> newShapeConstants;
- newShapeConstants.reserve(memrefType.getRank());
-
- unsigned dynamicDimPos = 0;
- unsigned rank = memrefType.getRank();
- for (unsigned dim = 0, e = rank; dim < e; ++dim) {
- int64_t dimSize = memrefType.getDimSize(dim);
- // If this is already static dimension, keep it.
- if (ShapedType::isStatic(dimSize)) {
- newShapeConstants.push_back(dimSize);
- continue;
- }
- auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
- if (auto constantIndexOp =
- dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
- // Dynamic shape dimension will be folded.
- newShapeConstants.push_back(constantIndexOp.value());
- } else {
- // Dynamic shape dimension not folded; copy operand from old memref.
- newShapeConstants.push_back(dimSize);
- newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
- }
- dynamicDimPos++;
- }
+ SmallVector<Value> foldedDynamicSizes;
+ MemRefType resultType = viewOp.getType();
+ MemRefType foldedMemRefType = foldDynamicToStaticDimSizes(
+ resultType, viewOp.getSizes(), foldedDynamicSizes);
- // Create new memref type with constant folded dims.
- MemRefType newMemRefType =
- MemRefType::Builder(memrefType).setShape(newShapeConstants);
- // Nothing new, don't fold.
- if (newMemRefType == memrefType)
+ // Stop here if no dynamic size was promoted to static.
+ if (foldedMemRefType == resultType)
return failure();
// Create new ViewOp.
- auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
- viewOp.getOperand(0), viewOp.getByteShift(),
- newOperands);
+ auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), foldedMemRefType,
+ viewOp.getSource(), viewOp.getByteShift(),
+ foldedDynamicSizes);
// Insert a cast so we have the same type as the old memref type.
- rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
+ rewriter.replaceOpWithNewOp<CastOp>(viewOp, resultType, newViewOp);
return success();
}
};
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 122906037b952..d32f8f7efc5ff 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1577,3 +1577,31 @@ func.func @non_fold_view_same_source_dynamic_size(%0: memref<?xi8>, %arg0 : inde
%res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
return %res : memref<?xi8>
}
+
+// -----
+
+// A constant dynamic size is folded into the static result shape, and a
+// memref.cast restores the original result type.
+// CHECK-LABEL: func.func @replace_view_static_dims(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xi8>, %[[ARG1:.*]]: index) -> memref<?x4xi32> {
+// CHECK: %[[VIEW_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]][] : memref<?xi8> to memref<5x4xi32>
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[VIEW_0]] : memref<5x4xi32> to memref<?x4xi32>
+// CHECK: return %[[CAST_0]] : memref<?x4xi32>
+// CHECK: }
+func.func @replace_view_static_dims(%src: memref<?xi8>, %offset : index) -> memref<?x4xi32> {
+ %c5 = arith.constant 5: index
+ %res = memref.view %src[%offset][%c5] : memref<?xi8> to memref<?x4xi32>
+ return %res : memref<?x4xi32>
+}
+
+// -----
+
+// A negative constant is not a valid static dimension size, so the size stays
+// dynamic and the view is left unchanged.
+// CHECK-LABEL: func.func @non_replace_view_negative_static_dims(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xi8>, %[[ARG1:.*]]: index) -> memref<?x4xi32> {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant -1 : index
+// CHECK: %[[VIEW_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]]{{\[}}%[[CONSTANT_0]]] : memref<?xi8> to memref<?x4xi32>
+// CHECK: return %[[VIEW_0]] : memref<?x4xi32>
+// CHECK: }
+func.func @non_replace_view_negative_static_dims(%src: memref<?xi8>, %offset : index) -> memref<?x4xi32> {
+ %c-1 = arith.constant -1: index
+ %res = memref.view %src[%offset][%c-1] : memref<?xi8> to memref<?x4xi32>
+ return %res : memref<?x4xi32>
+}
More information about the Mlir-commits
mailing list