[Mlir-commits] [mlir] [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (PR #170037)

Quinn Dawkins llvmlistbot at llvm.org
Mon Dec 1 11:15:55 PST 2025


================
@@ -2504,11 +2504,82 @@ LogicalResult ExpandShapeOp::verify() {
   return success();
 }
 
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+public:
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto cast = op.getSrc().getDefiningOp<CastOp>();
+    if (!cast)
+      return failure();
+
+    if (!CastOp::canFoldIntoConsumerOp(cast))
+      return failure();
+
+    SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
+    SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
+    SmallVector<int64_t> newOutputShapeSizes;
+    SmallVector<Value> newOperands;
+
+    // Convert output shape dims from dynamic to static where possible.
+    for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
+      auto sizeOpt = getConstantIntValue(dimSize);
+      if (sizeOpt.has_value()) {
+        newOutputShapeSizes.push_back(sizeOpt.value());
+        newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
+        continue;
+      }
+
+      newOperands.push_back(llvm::cast<Value>(dimSize));
+      newOutputShapeSizes.push_back(ShapedType::kDynamic);
+    }
+
+    if (newOperands.size() == op->getNumOperands())
+      return rewriter.notifyMatchFailure(
+          op, "no static-to-dynamic conversions found");
+
+    auto castSource = cast.getSource();
+    auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+    auto reassociationIndices = op.getReassociationIndices();
+    for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
+      int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
+      auto newOutputShapeSizesSlice =
+          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+      int64_t newOutputDynCount =
+          llvm::count_if(newOutputShapeSizesSlice, ShapedType::isDynamic);
+      if (castSourceDynCount != newOutputDynCount)
+        return rewriter.notifyMatchFailure(
+            op, "folding cast will result in changing dynamicity in "
+                "reassociation group");
----------------
qedawkins wrote:

I'm not sure I follow this check. Doesn't it cause this pattern to fail for any cast of the following form?

```
%0 = memref.cast ... : memref<...x?x...> to memref<...x?x...> // other dims are casted, this dynamic one is unaffected.
%s0 = ... : index
%s1 = ... : index
%1 = memref.expand_shape %0 [..., [n, n+1], ...] output_shape [..., %s0, %s1, ...]
                         : memref<...x?x...> to memref<...x?x?x...>
```

That is, one dynamic dimension is expanded into two (or more) dynamic dimensions, and that source dimension is unaffected by the cast. In this case, if I'm reading the logic here correctly, the expanded group would get `castSourceDynCount = 1` and `newOutputDynCount = 2`, even though folding the cast should be perfectly fine.

There is a similar pattern for `tensor.expand_shape` here in case it helps as a reference: https://github.com/llvm/llvm-project/blob/e6ae2462bd6dcf583ccd13c6627fe3ffe8a17f2c/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp#L2173


https://github.com/llvm/llvm-project/pull/170037


More information about the Mlir-commits mailing list