[Mlir-commits] [mlir] 4c49144 - [mlir][memref] Refactor `ViewOpShapeFolder` (#176567)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 17 21:17:55 PST 2026


Author: Longsheng Mou
Date: 2026-01-18T13:17:50+08:00
New Revision: 4c49144a42eb1927d34b77c77f0722a97930d296

URL: https://github.com/llvm/llvm-project/commit/4c49144a42eb1927d34b77c77f0722a97930d296
DIFF: https://github.com/llvm/llvm-project/commit/4c49144a42eb1927d34b77c77f0722a97930d296.diff

LOG: [mlir][memref] Refactor `ViewOpShapeFolder` (#176567)

This PR makes the following changes to ViewOpShapeFolder:
- Add comments for `ViewOpShapeFolder`.
- Drop the redundant offset check.
- Simplify the implementation by introducing
`foldDynamicToStaticDimSizes`.
- Add missing test coverage.

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9ebf349c807aa..b782a8be19154 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3777,71 +3777,69 @@ SmallVector<OpFoldResult> ViewOp::getMixedSizes() {
 }
 
 namespace {
+/// Given a memref type and a range of values that defines its dynamic
+/// dimension sizes, turn all dynamic sizes that have a non-negative constant
+/// value into static dimension sizes.
+static MemRefType
+foldDynamicToStaticDimSizes(MemRefType type, ValueRange dynamicSizes,
+                            SmallVectorImpl<Value> &foldedDynamicSizes) {
+  SmallVector<int64_t> staticShape(type.getShape());
+  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
+         "incorrect number of dynamic sizes");
+
+  // Compute new static and dynamic sizes.
+  unsigned ctr = 0;
+  for (auto [dim, dimSize] : llvm::enumerate(type.getShape())) {
+    if (ShapedType::isStatic(dimSize))
+      continue;
+
+    Value dynamicSize = dynamicSizes[ctr++];
+    if (auto cst = getConstantIntValue(dynamicSize)) {
+      // Dynamic size must be non-negative.
+      if (cst.value() < 0) {
+        foldedDynamicSizes.push_back(dynamicSize);
+        continue;
+      }
+      staticShape[dim] = cst.value();
+    } else {
+      foldedDynamicSizes.push_back(dynamicSize);
+    }
+  }
+
+  return MemRefType::Builder(type).setShape(staticShape);
+}
 
+/// Change the result type of a `memref.view` by making originally dynamic
+/// dimensions static when their sizes come from `constant` ops.
+/// Example:
+///  ```
+///  %c5 = arith.constant 5: index
+///  %0 = memref.view %src[%offset][%c5] : memref<?xi8> to memref<?x4xf32>
+///  ```
+///  to
+///  ```
+///  %0 = memref.view %src[%offset][] : memref<?xi8> to memref<5x4xf32>
+///  ```
 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
   using Base::Base;
 
   LogicalResult matchAndRewrite(ViewOp viewOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if none of the operands are constants.
-    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
-          return matchPattern(operand, matchConstantIndex());
-        }))
-      return failure();
-
-    // Get result memref type.
-    auto memrefType = viewOp.getType();
-
-    // Get offset from old memref view type 'memRefType'.
-    int64_t oldOffset;
-    SmallVector<int64_t, 4> oldStrides;
-    if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
-      return failure();
-    assert(oldOffset == 0 && "Expected 0 offset");
-
-    SmallVector<Value, 4> newOperands;
-
-    // Offset cannot be folded into result type.
+    SmallVector<Value> foldedDynamicSizes;
+    MemRefType resultType = viewOp.getType();
+    MemRefType foldedMemRefType = foldDynamicToStaticDimSizes(
+        resultType, viewOp.getSizes(), foldedDynamicSizes);
 
-    // Fold any dynamic dim operands which are produced by a constant.
-    SmallVector<int64_t, 4> newShapeConstants;
-    newShapeConstants.reserve(memrefType.getRank());
-
-    unsigned dynamicDimPos = 0;
-    unsigned rank = memrefType.getRank();
-    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
-      int64_t dimSize = memrefType.getDimSize(dim);
-      // If this is already static dimension, keep it.
-      if (ShapedType::isStatic(dimSize)) {
-        newShapeConstants.push_back(dimSize);
-        continue;
-      }
-      auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
-      if (auto constantIndexOp =
-              dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
-        // Dynamic shape dimension will be folded.
-        newShapeConstants.push_back(constantIndexOp.value());
-      } else {
-        // Dynamic shape dimension not folded; copy operand from old memref.
-        newShapeConstants.push_back(dimSize);
-        newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
-      }
-      dynamicDimPos++;
-    }
-
-    // Create new memref type with constant folded dims.
-    MemRefType newMemRefType =
-        MemRefType::Builder(memrefType).setShape(newShapeConstants);
-    // Nothing new, don't fold.
-    if (newMemRefType == memrefType)
+    // Stop here if no dynamic size was promoted to static.
+    if (foldedMemRefType == resultType)
       return failure();
 
     // Create new ViewOp.
-    auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
-                                    viewOp.getOperand(0), viewOp.getByteShift(),
-                                    newOperands);
+    auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), foldedMemRefType,
+                                    viewOp.getSource(), viewOp.getByteShift(),
+                                    foldedDynamicSizes);
     // Insert a cast so we have the same type as the old memref type.
-    rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
+    rewriter.replaceOpWithNewOp<CastOp>(viewOp, resultType, newViewOp);
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 122906037b952..d32f8f7efc5ff 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1577,3 +1577,31 @@ func.func @non_fold_view_same_source_dynamic_size(%0: memref<?xi8>, %arg0 : inde
   %res = memref.view %0[%c0][%arg0] : memref<?xi8> to memref<?xi8>
   return %res : memref<?xi8>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @replace_view_static_dims(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xi8>, %[[ARG1:.*]]: index) -> memref<?x4xi32> {
+// CHECK:           %[[VIEW_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]][] : memref<?xi8> to memref<5x4xi32>
+// CHECK:           %[[CAST_0:.*]] = memref.cast %[[VIEW_0]] : memref<5x4xi32> to memref<?x4xi32>
+// CHECK:           return %[[CAST_0]] : memref<?x4xi32>
+// CHECK:         }
+func.func @replace_view_static_dims(%src: memref<?xi8>, %offset : index) -> memref<?x4xi32> {
+  %c5 = arith.constant 5: index
+  %res = memref.view %src[%offset][%c5] : memref<?xi8> to memref<?x4xi32>
+  return %res : memref<?x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @non_replace_view_negative_static_dims(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?xi8>, %[[ARG1:.*]]: index) -> memref<?x4xi32> {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant -1 : index
+// CHECK:           %[[VIEW_0:.*]] = memref.view %[[ARG0]]{{\[}}%[[ARG1]]]{{\[}}%[[CONSTANT_0]]] : memref<?xi8> to memref<?x4xi32>
+// CHECK:           return %[[VIEW_0]] : memref<?x4xi32>
+// CHECK:         }
+func.func @non_replace_view_negative_static_dims(%src: memref<?xi8>, %offset : index) -> memref<?x4xi32> {
+  %c-1 = arith.constant -1: index
+  %res = memref.view %src[%offset][%c-1] : memref<?xi8> to memref<?x4xi32>
+  return %res : memref<?x4xi32>
+}


        


More information about the Mlir-commits mailing list