[Mlir-commits] [mlir] [mlir][memref] Refactor `ViewOpShapeFolder` (PR #176567)

Longsheng Mou llvmlistbot at llvm.org
Sat Jan 17 08:05:07 PST 2026


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/176567

>From 204987f4f34db5b201fac941e9f7e6834cc93797 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sat, 17 Jan 2026 21:01:45 +0800
Subject: [PATCH 1/2] [mlir][memref] Refactor `ViewOpShapeFolder`

This PR makes the following changes to ViewOpShapeFolder:
- Add comments for `ViewOpShapeFolder`.
- Drop the redundant `getStridesAndOffset` check, whose offset was only used in
  an assertion that it is 0.
- Simplify the implementation by introducing `foldDynamicToStaticDimSizes`.
- Add missing test coverage.
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 106 ++++++++++-----------
 mlir/test/Dialect/MemRef/canonicalize.mlir |  28 ++++++
 2 files changed, 80 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e0f7a8b452a1d..48636a95fd989 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3763,71 +3763,69 @@ OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
 }
 
 namespace {
+/// Given a memref type and the values defining its dynamic dimension sizes,
+/// turn all dynamic sizes that are non-negative constants into
+/// static dimension sizes.
+static MemRefType
+foldDynamicToStaticDimSizes(MemRefType type, ValueRange dynamicSizes,
+                            SmallVector<Value> &foldedDynamicSizes) {
+  SmallVector<int64_t> staticShape(type.getShape());
+  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
+         "incorrect number of dynamic sizes");
+
+  // Compute new static and dynamic sizes.
+  unsigned ctr = 0;
+  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
+    if (type.isDynamicDim(i)) {
+      Value dynamicSize = dynamicSizes[ctr++];
+      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
+      if (cst.has_value()) {
+        // Dynamic size must be non-negative.
+        if (cst.value() < 0) {
+          foldedDynamicSizes.push_back(dynamicSize);
+          continue;
+        }
+        staticShape[i] = *cst;
+      } else {
+        foldedDynamicSizes.push_back(dynamicSize);
+      }
+    }
+  }
+
+  return MemRefType::Builder(type).setShape(staticShape);
+}
 
+/// Promote dynamic `memref.view` result sizes that are non-negative constants
+/// to static sizes and `memref.cast` back to the original result type.
+/// Negative constants are left dynamic because sizes cannot be negative.
+/// Example:
+///  ```
+///  %c5 = arith.constant 5: index
+///  %0 = memref.view %src[%offset][%c5] : memref<?xi8> to memref<?x4xf32>
+///  // is rewritten to
+///  %v = memref.view %src[%offset][] : memref<?xi8> to memref<5x4xf32>
+///  %0 = memref.cast %v : memref<5x4xf32> to memref<?x4xf32>
+///  ```
 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
   using Base::Base;
 
   LogicalResult matchAndRewrite(ViewOp viewOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if none of the operands are constants.
-    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
-          return matchPattern(operand, matchConstantIndex());
-        }))
-      return failure();
-
-    // Get result memref type.
-    auto memrefType = viewOp.getType();
-
-    // Get offset from old memref view type 'memRefType'.
-    int64_t oldOffset;
-    SmallVector<int64_t, 4> oldStrides;
-    if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
-      return failure();
-    assert(oldOffset == 0 && "Expected 0 offset");
-
-    SmallVector<Value, 4> newOperands;
-
-    // Offset cannot be folded into result type.
-
-    // Fold any dynamic dim operands which are produced by a constant.
-    SmallVector<int64_t, 4> newShapeConstants;
-    newShapeConstants.reserve(memrefType.getRank());
-
-    unsigned dynamicDimPos = 0;
-    unsigned rank = memrefType.getRank();
-    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
-      int64_t dimSize = memrefType.getDimSize(dim);
-      // If this is already static dimension, keep it.
-      if (ShapedType::isStatic(dimSize)) {
-        newShapeConstants.push_back(dimSize);
-        continue;
-      }
-      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
-      if (auto constantIndexOp =
-              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
-        // Dynamic shape dimension will be folded.
-        newShapeConstants.push_back(constantIndexOp.value());
-      } else {
-        // Dynamic shape dimension not folded; copy operand from old memref.
-        newShapeConstants.push_back(dimSize);
-        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
-      }
-      dynamicDimPos++;
-    }
+    SmallVector<Value> foldedDynamicSizes;
+    MemRefType resultType = viewOp.getType();
+    MemRefType foldedMemRefType = foldDynamicToStaticDimSizes(
+        resultType, viewOp.getSizes(), foldedDynamicSizes);
 
-    // Create new memref type with constant folded dims.
-    MemRefType newMemRefType =
-        MemRefType::Builder(memrefType).setShape(newShapeConstants);
-    // Nothing new, don't fold.
-    if (newMemRefType == memrefType)
+    // Stop here if no dynamic size was promoted to static.
+    if (foldedMemRefType == resultType)
       return failure();
 
     // Create new ViewOp.
-    auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
-                                    viewOp.getOperand(0), viewOp.getByteShift(),
-                                    newOperands);
+    auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), foldedMemRefType,
+                                    viewOp.getSource(), viewOp.getByteShift(),
+                                    foldedDynamicSizes);
     // Insert a cast so we have the same type as the old memref type.
-    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
+    rewriter.replaceOpWithNewOp<CastOp>(viewOp, resultType, newViewOp);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 122906037b952..d32f8f7efc5ff 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1577,3 +1577,31 @@ func.func @non_fold_view_same_source_dynamic_size(%0: memref<?xi8>, %arg0 : inde
   %res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
   return %res : memref<?xi8>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @replace_view_static_dims(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xi8>, %[[ARG1:.*]]: index) -> memref<?x4xi32> {
+// CHECK:           %[[VIEW_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]][] : memref<?xi8> to memref<5x4xi32>
+// CHECK:           %[[CAST_0:.*]] = memref.cast %[[VIEW_0]] : memref<5x4xi32> to memref<?x4xi32>
+// CHECK:           return %[[CAST_0]] : memref<?x4xi32>
+// CHECK:         }
+func.func @replace_view_static_dims(%src: memref<?xi8>, %offset : index) -> memref<?x4xi32> {
+  %c5 = arith.constant 5: index
+  %res = memref.view %src[%offset][%c5] : memref<?xi8> to memref<?x4xi32>
+  return %res : memref<?x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @non_replace_view_negative_static_dims(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xi8>, %[[ARG1:.*]]: index) -> memref<?x4xi32> {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant -1 : index
+// CHECK:           %[[VIEW_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]]{{\[}}%[[CONSTANT_0]]] : memref<?xi8> to memref<?x4xi32>
+// CHECK:           return %[[VIEW_0]] : memref<?x4xi32>
+// CHECK:         }
+func.func @non_replace_view_negative_static_dims(%src: memref<?xi8>, %offset : index) -> memref<?x4xi32> {
+  %c-1 = arith.constant -1: index
+  %res = memref.view %src[%offset][%c-1] : memref<?xi8> to memref<?x4xi32>
+  return %res : memref<?x4xi32>
+}

>From 5f47a547058dd2d9ba8aa8619de50854b4480dd0 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Sun, 18 Jan 2026 00:03:18 +0800
Subject: [PATCH 2/2] Use early exits and `continue` to simplify the code

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 26 ++++++++++++------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 48636a95fd989..bc82a2dd5931e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3768,27 +3768,27 @@ namespace {
 /// static dimension sizes.
 static MemRefType
 foldDynamicToStaticDimSizes(MemRefType type, ValueRange dynamicSizes,
-                            SmallVector<Value> &foldedDynamicSizes) {
+                            SmallVectorImpl<Value> &foldedDynamicSizes) {
   SmallVector<int64_t> staticShape(type.getShape());
   assert(type.getNumDynamicDims() == dynamicSizes.size() &&
          "incorrect number of dynamic sizes");
 
   // Compute new static and dynamic sizes.
   unsigned ctr = 0;
-  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
-    if (type.isDynamicDim(i)) {
-      Value dynamicSize = dynamicSizes[ctr++];
-      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
-      if (cst.has_value()) {
-        // Dynamic size must be non-negative.
-        if (cst.value() < 0) {
-          foldedDynamicSizes.push_back(dynamicSize);
-          continue;
-        }
-        staticShape[i] = *cst;
-      } else {
+  for (auto [dim, dimSize] : llvm::enumerate(type.getShape())) {
+    if (ShapedType::isStatic(dimSize))
+      continue;
+
+    Value dynamicSize = dynamicSizes[ctr++];
+    if (auto cst = getConstantIntValue(dynamicSize)) {
+      // A negative constant cannot be a static size; keep it dynamic.
+      if (cst.value() < 0) {
         foldedDynamicSizes.push_back(dynamicSize);
+        continue;
       }
+      staticShape[dim] = cst.value();
+    } else {
+      foldedDynamicSizes.push_back(dynamicSize);
     }
   }
 



More information about the Mlir-commits mailing list