[Mlir-commits] [mlir] [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (PR #170037)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 1 23:30:53 PST 2025
================
@@ -2504,11 +2504,82 @@ LogicalResult ExpandShapeOp::verify() {
return success();
}
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+public:
+ using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExpandShapeOp op,
+ PatternRewriter &rewriter) const override {
+ auto cast = op.getSrc().getDefiningOp<CastOp>();
+ if (!cast)
+ return failure();
+
+ if (!CastOp::canFoldIntoConsumerOp(cast))
+ return failure();
+
+ SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
+ SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
+ SmallVector<int64_t> newOutputShapeSizes;
+ SmallVector<Value> newOperands;
+
+ // Convert output shape dims from dynamic to static where possible.
+ for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
+ auto sizeOpt = getConstantIntValue(dimSize);
+ if (sizeOpt.has_value()) {
+ newOutputShapeSizes.push_back(sizeOpt.value());
+ newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
+ continue;
+ }
+
+ newOperands.push_back(llvm::cast<Value>(dimSize));
+ newOutputShapeSizes.push_back(ShapedType::kDynamic);
+ }
+
+ if (newOperands.size() == op->getNumOperands())
+ return rewriter.notifyMatchFailure(
+ op, "no static-to-dynamic conversions found");
+
+ auto castSource = cast.getSource();
+ auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+ auto reassociationIndices = op.getReassociationIndices();
+ for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
+ int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
+ auto newOutputShapeSizesSlice =
+ ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+ int64_t newOutputDynCount =
+ llvm::count_if(newOutputShapeSizesSlice, ShapedType::isDynamic);
+ if (castSourceDynCount != newOutputDynCount)
+ return rewriter.notifyMatchFailure(
+ op, "folding cast will result in changing dynamicity in "
+ "reassociation group");
----------------
kdmitry1 wrote:
I was under the impression that memref.expand_shape does not allow expanding a single dynamic dim into multiple dims. I was wrong. I fixed the pattern to allow this and added a test for it.
https://github.com/llvm/llvm-project/pull/170037
More information about the Mlir-commits mailing list