[Mlir-commits] [mlir] [MLIR] Add shape propagation through tensor.pad (PR #136681)

Hyunsung Lee llvmlistbot at llvm.org
Tue Jul 29 02:47:46 PDT 2025


https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/136681

>From 0d8c636d4fd4d5c9636cfd3599c804e4a89e81e6 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 22 Apr 2025 19:04:02 +0900
Subject: [PATCH 1/9] Add FoldReshapeWithProducerPadOpByExpansion

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 142 ++++++++++++++++++
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  |  51 ++++++-
 2 files changed, 191 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index bf70597d5ddfe..dd4ac89e98090 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1101,6 +1101,146 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
+struct FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+    auto isZeroPadding = [](OpFoldResult padValue) -> bool {
+      if (auto attr = dyn_cast<Attribute>(padValue)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+          return intAttr.getInt() == 0;
+      }
+
+      if (auto val = dyn_cast<Value>(padValue)) {
+        if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+          if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+            return attr.getInt() == 0;
+        }
+      }
+
+      // When the padding is dynamic and not a known constant, we cannot prove
+      // it is zero, so conservatively return false.
+      return false;
+    };
+
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[idx];
+      OpFoldResult h = high[idx];
+      if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h)))
+        return failure();
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      for (size_t i = 0; i < reInd.size(); ++i) {
+        newLow.push_back(padOp.getMixedLowPad()[idx]);
+        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+      }
+    }
+
+    Location loc = expandOp.getLoc();
+    auto finalType = cast<RankedTensorType>(expandOp.getType());
+    ArrayRef<int64_t> finalShape = finalType.getShape();
+
+    SmallVector<OpFoldResult> expandedShape;
+    for (int64_t dimSize : finalShape) {
+      if (dimSize == ShapedType::kDynamic) {
+        expandedShape.push_back(OpFoldResult{});
+      } else {
+        expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+      }
+    }
+
+    for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[inDimIdx];
+      OpFoldResult h = high[inDimIdx];
+
+      if (!isZeroPadding(l) || !isZeroPadding(h)) {
+        auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+        int64_t originalSize = srcType.getDimSize(inDimIdx);
+
+        OpFoldResult originalSizeOFR;
+        if (originalSize == ShapedType::kDynamic) {
+          Value orgSizeVal =
+              rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
+          originalSizeOFR = orgSizeVal;
+        } else {
+          originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+        }
+
+        for (auto outDimIdx : outGroup) {
+          expandedShape[outDimIdx] = originalSizeOFR;
+        }
+      }
+    }
+
+    for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
+      if (dimSize == ShapedType::kDynamic &&
+          !isa<Value>(expandedShape[outDimIdx]) &&
+          !isa<Attribute>(expandedShape[outDimIdx])) {
+        Value actualSize =
+            rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
+        expandedShape[outDimIdx] = actualSize;
+      }
+    }
+
+    SmallVector<int64_t> staticExpandedShape;
+    for (OpFoldResult dim : expandedShape) {
+      if (auto attr = dyn_cast<Attribute>(dim)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+          staticExpandedShape.push_back(intAttr.getInt());
+        } else {
+          staticExpandedShape.push_back(ShapedType::kDynamic);
+        }
+      } else {
+        staticExpandedShape.push_back(ShapedType::kDynamic);
+      }
+    }
+
+    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        RankedTensorType::get(staticExpandedShape,
+                              padOp.getSource().getType().getElementType()),
+        padOp.getSource(), reassociations);
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(expandOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -2249,6 +2389,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                     controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..3ea0babfa3b9d 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
-                                         %arg1 : tensor<?x?xi32>, 
+                                         %arg1 : tensor<?x?xi32>,
                                          %sz0: index, %sz1: index) ->
                                          tensor<?x?x4x5xi32>
 {
@@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
 // -----
 
 func.func @reshape_as_consumer_permutation_with_multiple_results
-  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index, 
+  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
    %sz1: index, %sz2: index, %sz3: index, %sz4: index)
     -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
   %c:2 = linalg.generic {
@@ -893,3 +893,50 @@ func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:       ins(%[[EXPANDED]] :
 //      CHECK:   return %[[GENERIC]]
+
+// -----
+
+func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0   = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
+    ^bb0(%i: index, %j: index, %k: index):
+      tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+//      CHECK: func @fold_tensor_pad_with_expand(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+//  CHECK-DAG:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+//      CHECK:   %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
+//      CHECK:   %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+//      CHECK:   ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+//      CHECK:     tensor.yield %[[CST]] : f32
+//      CHECK:   } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+//      CHECK:   return %[[PADDED]] : tensor<32x16x258x258xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_expand_dynamic_pad_zero(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0   = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[%c0, %c1, %c1] high[%c0, %c1, %c1] {
+    ^bb0(%i: index, %j: index, %k: index):
+      tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+//      CHECK: func @fold_tensor_pad_with_expand_dynamic_pad_zero(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+//      CHECK:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+//      CHECK:   %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]]
+//      CHECK:   %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+//      CHECK:   ^bb0(
+//      CHECK:     tensor.yield %[[CST]] : f32
+//      CHECK:   return %[[PADDED]]

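A note on the pattern's precondition: the low/high amounts are simply replicated across each reassociation group (e.g. low[0, 1, 1] with reassociation [[0, 1], [2], [3]] becomes low[0, 0, 1, 1] on the new pad), which is only sound when every group of size > 1 carries zero padding. The following hand-written sketch (not one of the patch's tests; the function name and sizes are illustrative) shows a case the pattern deliberately rejects, because the padded dimension is split by the expand_shape:

func.func @no_fold_nonzero_pad_on_split_dim(%arg0: tensor<512x256xf32>) -> tensor<2x257x256xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // Dimension 0 carries non-zero padding and is split by the expand_shape
  // below, so the pattern returns failure and leaves the IR unchanged.
  %padded = tensor.pad %arg0 low[1, 0] high[1, 0] {
    ^bb0(%i: index, %j: index):
      tensor.yield %cst : f32
  } : tensor<512x256xf32> to tensor<514x256xf32>
  %expanded = tensor.expand_shape %padded [[0, 1], [2]] output_shape [2, 257, 256] : tensor<514x256xf32> into tensor<2x257x256xf32>
  return %expanded : tensor<2x257x256xf32>
}
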
>From 57ec65705339764b1a472f32b382c015909b25e8 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Mon, 14 Jul 2025 09:59:33 +0900
Subject: [PATCH 2/9] add collapse_shape

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 176 +++++++++++++++---
 .../fuse-with-reshape-by-collapsing.mlir      |  53 +++++-
 2 files changed, 204 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 39eed6dd4cba4..e65228ae0e3eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -26,6 +26,8 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/LogicalResult.h"
 #include <optional>
 #include <utility>
 
@@ -1100,6 +1102,20 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+bool isZero(OpFoldResult value) {
+  if (auto attr = dyn_cast<Attribute>(value)) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+      return intAttr.getInt() == 0;
+  }
+  if (auto val = dyn_cast<Value>(value)) {
+    if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+      if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+        return attr.getInt() == 0;
+    }
+  }
+  return false;
+}
+
 /// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
 /// by bubbling the expand_shape before the pad.
 struct FoldReshapeWithProducerPadOpByExpansion
@@ -1125,41 +1141,29 @@ struct FoldReshapeWithProducerPadOpByExpansion
                                          "fusion blocked by control function");
     }
 
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue) {
+      return rewriter.notifyMatchFailure(
+          expandOp, "cannot fold with non-constant padding value");
+    }
+
     SmallVector<ReassociationIndices> reassociations =
         expandOp.getReassociationIndices();
     SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
     SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
 
-    auto isZeroPadding = [](OpFoldResult padValue) -> bool {
-      if (auto attr = dyn_cast<Attribute>(padValue)) {
-        if (auto intAttr = dyn_cast<IntegerAttr>(attr))
-          return intAttr.getInt() == 0;
-      }
-
-      if (auto val = dyn_cast<Value>(padValue)) {
-        if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
-          if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
-            return attr.getInt() == 0;
-        }
-      }
-
-      // When the padding is dynamic and not a known constant, we cannot prove
-      // it is zero, so conservatively return false.
-      return false;
-    };
-
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
       OpFoldResult l = low[idx];
       OpFoldResult h = high[idx];
-      if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h)))
+      if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
         return failure();
     }
 
     SmallVector<OpFoldResult> newLow, newHigh;
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
       for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(padOp.getMixedLowPad()[idx]);
-        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+        newLow.push_back(low[idx]);
+        newHigh.push_back(high[idx]);
       }
     }
 
@@ -1176,11 +1180,11 @@ struct FoldReshapeWithProducerPadOpByExpansion
       }
     }
 
-    for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+    for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
       OpFoldResult l = low[inDimIdx];
       OpFoldResult h = high[inDimIdx];
 
-      if (!isZeroPadding(l) || !isZeroPadding(h)) {
+      if (!isZero(l) || !isZero(h)) {
         auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
         int64_t originalSize = srcType.getDimSize(inDimIdx);
 
@@ -1193,7 +1197,7 @@ struct FoldReshapeWithProducerPadOpByExpansion
           originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
         }
 
-        for (auto outDimIdx : outGroup) {
+        for (auto outDimIdx : reInd) {
           expandedShape[outDimIdx] = originalSizeOFR;
         }
       }
@@ -1240,6 +1244,125 @@ struct FoldReshapeWithProducerPadOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
+/// by bubbling the collapse_shape before the pad.
+struct FoldReshapeWithProducerPadOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+
+  FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = collapseOp.getSrc().getDefiningOp<tensor::PadOp>();
+
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&collapseOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(collapseOp,
+                                         "fusion blocked by control function");
+    }
+
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue) {
+      return rewriter.notifyMatchFailure(
+          collapseOp, "cannot fold with non-constant padding value");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        collapseOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() > 1) {
+        for (auto dimIdx : reInd) {
+          if (!isZero(low[dimIdx]) || !isZero(high[dimIdx])) {
+            return failure();
+          }
+        }
+      }
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      newLow.push_back(low[reInd[0]]);
+      newHigh.push_back(high[reInd[0]]);
+    }
+
+    Location loc = collapseOp.getLoc();
+    auto resultType = collapseOp.getResultType();
+
+    auto finalType = cast<RankedTensorType>(collapseOp.getType());
+    ArrayRef<int64_t> finalShape = finalType.getShape();
+
+    SmallVector<OpFoldResult> collapsedShape;
+    for (int64_t dimSize : finalShape) {
+      if (dimSize == ShapedType::kDynamic) {
+        collapsedShape.push_back(OpFoldResult{});
+      } else {
+        collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+      }
+    }
+
+    for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[reInd[0]];
+      OpFoldResult h = high[reInd[0]];
+
+      if (!isZero(l) || !isZero(h)) {
+        auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+        int64_t originalSize = srcType.getDimSize(reInd[0]);
+
+        OpFoldResult originalSizeOFR;
+        if (originalSize == ShapedType::kDynamic) {
+          Value orgSizeVal =
+              rewriter.create<tensor::DimOp>(loc, padOp.getSource(), reInd[0]);
+          originalSizeOFR = orgSizeVal;
+        } else {
+          originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+        }
+        collapsedShape[inDimIdx] = originalSizeOFR;
+      }
+    }
+
+    SmallVector<int64_t> staticCollapsedShape;
+    for (OpFoldResult dim : collapsedShape) {
+      if (auto attr = dyn_cast<Attribute>(dim)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+          staticCollapsedShape.push_back(intAttr.getInt());
+        } else {
+          staticCollapsedShape.push_back(ShapedType::kDynamic);
+        }
+      } else {
+        staticCollapsedShape.push_back(ShapedType::kDynamic);
+      }
+    }
+
+    auto newCollapseType = RankedTensorType::get(
+        staticCollapsedShape, padOp.getSource().getType().getElementType());
+    auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
+        loc, newCollapseType, padOp.getSource(), reassociations);
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, resultType, newCollapseOp.getResult(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(collapseOp, newPadOp.getResult());
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -2388,6 +2511,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
                                                       controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 2bf3d21c35526..0ac1686361bf7 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -232,7 +232,7 @@ func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1
   %1 = linalg.generic {
       indexing_maps = [#map0, #map0],
       iterator_types = ["parallel", "parallel"]}
-      ins(%0 : tensor<?x?xf32>) 
+      ins(%0 : tensor<?x?xf32>)
       outs(%init : tensor<?x?xf32>) {
         ^bb0(%b0 : f32, %b1 : f32):
           %out = arith.negf %b0 : f32
@@ -858,3 +858,54 @@ func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1:
 //       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
 //  CHECK-SAME:     tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
 //       CHECK:   return %[[COLLAPSED]] : tensor<512x192x?xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_collapse(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
+  %padded = tensor.pad %0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+    ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
+      tensor.yield %cst : f32
+  } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+  %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
+    : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
+  return %collapsed : tensor<512x258x258xf32>
+}
+//      CHECK: func @fold_tensor_pad_with_collapse(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
+//  CHECK-DAG:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
+//      CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
+// CHECK-SAME:       : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
+//      CHECK:   %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
+//      CHECK:   ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index):
+//      CHECK:     tensor.yield %[[CST]] : f32
+//      CHECK:   } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+//      CHECK:   return %[[PADDED]] : tensor<512x258x258xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_collapse_dynamic_pad_zero(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
+  %padded = tensor.pad %0 low[%c0, %c0, %c1, %c1] high[%c0, %c0, %c1, %c1] {
+    ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
+      tensor.yield %cst : f32
+  } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+  %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
+    : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
+  return %collapsed : tensor<512x258x258xf32>
+}
+//      CHECK: func @fold_tensor_pad_with_collapse_dynamic_pad_zero(
+// CHECK-SAME:     %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
+//      CHECK:   %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+//      CHECK:   %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
+//      CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
+// CHECK-SAME:       : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
+//      CHECK:   %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
+//      CHECK:   ^bb0(
+//      CHECK:     tensor.yield %[[CST]] : f32
+//      CHECK:   return %[[PADDED]]

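Analogous to the expansion case, any group that the collapse_shape merges must be entirely unpadded; only groups of size 1 may carry padding, and their low/high amounts are carried over verbatim (newLow/newHigh take low[reInd[0]]/high[reInd[0]]). A hand-written sketch of a rejected case (illustrative only, not one of the patch's tests):

func.func @no_fold_nonzero_pad_in_collapsed_group(%arg0: tensor<32x16x256xf32>) -> tensor<528x256xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // Dimension 0 is padded and belongs to the collapsed group [0, 1], so the
  // pattern bails out instead of bubbling the collapse_shape before the pad.
  %padded = tensor.pad %arg0 low[1, 0, 0] high[0, 0, 0] {
    ^bb0(%i: index, %j: index, %k: index):
      tensor.yield %cst : f32
  } : tensor<32x16x256xf32> to tensor<33x16x256xf32>
  %collapsed = tensor.collapse_shape %padded [[0, 1], [2]] : tensor<33x16x256xf32> into tensor<528x256xf32>
  return %collapsed : tensor<528x256xf32>
}
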
>From 737d4a4c776030cf9154aeb10039c870a6a63211 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 19 Jul 2025 07:37:30 +0900
Subject: [PATCH 3/9] fix upon review

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 75 ++++++-------------
 1 file changed, 22 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e65228ae0e3eb..05dbb7cd7ba43 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1102,20 +1102,6 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
-bool isZero(OpFoldResult value) {
-  if (auto attr = dyn_cast<Attribute>(value)) {
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
-      return intAttr.getInt() == 0;
-  }
-  if (auto val = dyn_cast<Value>(value)) {
-    if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
-      if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
-        return attr.getInt() == 0;
-    }
-  }
-  return false;
-}
-
 /// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
 /// by bubbling the expand_shape before the pad.
 struct FoldReshapeWithProducerPadOpByExpansion
@@ -1152,19 +1138,17 @@ struct FoldReshapeWithProducerPadOpByExpansion
     SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
     SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
 
-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      OpFoldResult l = low[idx];
-      OpFoldResult h = high[idx];
-      if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
-        return failure();
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+      if (reInd.size() > 1 &&
+          (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)))
+        return rewriter.notifyMatchFailure(
+            expandOp, "fusion blocked by non-zero padding");
     }
 
     SmallVector<OpFoldResult> newLow, newHigh;
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(low[idx]);
-        newHigh.push_back(high[idx]);
-      }
+      newLow.append(reInd.size(), low[idx]);
+      newHigh.append(reInd.size(), high[idx]);
     }
 
     Location loc = expandOp.getLoc();
@@ -1184,7 +1168,7 @@ struct FoldReshapeWithProducerPadOpByExpansion
       OpFoldResult l = low[inDimIdx];
       OpFoldResult h = high[inDimIdx];
 
-      if (!isZero(l) || !isZero(h)) {
+      if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
         auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
         int64_t originalSize = srcType.getDimSize(inDimIdx);
 
@@ -1196,10 +1180,8 @@ struct FoldReshapeWithProducerPadOpByExpansion
         } else {
           originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
         }
-
-        for (auto outDimIdx : reInd) {
-          expandedShape[outDimIdx] = originalSizeOFR;
-        }
+        assert(reInd.size() == 1 && "expected single dimension");
+        expandedShape[reInd[0]] = originalSizeOFR;
       }
     }
 
@@ -1207,36 +1189,24 @@ struct FoldReshapeWithProducerPadOpByExpansion
       if (dimSize == ShapedType::kDynamic &&
           !isa<Value>(expandedShape[outDimIdx]) &&
           !isa<Attribute>(expandedShape[outDimIdx])) {
-        Value actualSize =
-            rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
-        expandedShape[outDimIdx] = actualSize;
+        expandedShape[outDimIdx] =
+            tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
       }
     }
 
     SmallVector<int64_t> staticExpandedShape;
-    for (OpFoldResult dim : expandedShape) {
-      if (auto attr = dyn_cast<Attribute>(dim)) {
-        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-          staticExpandedShape.push_back(intAttr.getInt());
-        } else {
-          staticExpandedShape.push_back(ShapedType::kDynamic);
-        }
-      } else {
-        staticExpandedShape.push_back(ShapedType::kDynamic);
-      }
-    }
+    std::tie(staticExpandedShape, std::ignore) =
+        decomposeMixedValues(expandedShape);
 
     auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
         loc,
         RankedTensorType::get(staticExpandedShape,
                               padOp.getSource().getType().getElementType()),
-        padOp.getSource(), reassociations);
+        padOp.getSource(), reassociations, expandedShape);
 
-    auto newPadOp = rewriter.create<tensor::PadOp>(
-        loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+    rewriter.replaceOpWithNewOp<tensor::PadOp>(
+        expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
         padOp.getConstantPaddingValue(), padOp.getNofold());
-
-    rewriter.replaceOp(expandOp, newPadOp.getResult());
     return success();
   }
 
@@ -1284,7 +1254,8 @@ struct FoldReshapeWithProducerPadOpByCollapsing
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
       if (reInd.size() > 1) {
         for (auto dimIdx : reInd) {
-          if (!isZero(low[dimIdx]) || !isZero(high[dimIdx])) {
+          if (!isConstantIntValue(low[dimIdx], 0) ||
+              !isConstantIntValue(high[dimIdx], 0)) {
             return failure();
           }
         }
@@ -1316,7 +1287,7 @@ struct FoldReshapeWithProducerPadOpByCollapsing
       OpFoldResult l = low[reInd[0]];
       OpFoldResult h = high[reInd[0]];
 
-      if (!isZero(l) || !isZero(h)) {
+      if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
         auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
         int64_t originalSize = srcType.getDimSize(reInd[0]);
 
@@ -1350,12 +1321,10 @@ struct FoldReshapeWithProducerPadOpByCollapsing
     auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
         loc, newCollapseType, padOp.getSource(), reassociations);
 
-    auto newPadOp = rewriter.create<tensor::PadOp>(
-        loc, resultType, newCollapseOp.getResult(), newLow, newHigh,
+    rewriter.replaceOpWithNewOp<tensor::PadOp>(
+        collapseOp, resultType, newCollapseOp.getResult(), newLow, newHigh,
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
-    rewriter.replaceOp(collapseOp, newPadOp.getResult());
-
     return success();
   }
 

>From d8ca03657a6445a99cbe4558b7b8990a703c3c57 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 19 Jul 2025 07:44:43 +0900
Subject: [PATCH 4/9] fix upon review

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp       | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05dbb7cd7ba43..6499a3387efca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1304,17 +1304,8 @@ struct FoldReshapeWithProducerPadOpByCollapsing
     }
 
     SmallVector<int64_t> staticCollapsedShape;
-    for (OpFoldResult dim : collapsedShape) {
-      if (auto attr = dyn_cast<Attribute>(dim)) {
-        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-          staticCollapsedShape.push_back(intAttr.getInt());
-        } else {
-          staticCollapsedShape.push_back(ShapedType::kDynamic);
-        }
-      } else {
-        staticCollapsedShape.push_back(ShapedType::kDynamic);
-      }
-    }
+    std::tie(staticCollapsedShape, std::ignore) =
+        decomposeMixedValues(collapsedShape);
 
     auto newCollapseType = RankedTensorType::get(
         staticCollapsedShape, padOp.getSource().getType().getElementType());

>From 17a24473a1c881cd6ac16f8a3cf5d7665ee477fe Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 22 Jul 2025 15:48:27 +0900
Subject: [PATCH 5/9] fix upon review

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 31 +++----------------
 1 file changed, 5 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6499a3387efca..1ec3bd2ac8f1d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -26,8 +26,6 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/LogicalResult.h"
 #include <optional>
 #include <utility>
 
@@ -1169,19 +1167,10 @@ struct FoldReshapeWithProducerPadOpByExpansion
       OpFoldResult h = high[inDimIdx];
 
       if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
-        auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
-        int64_t originalSize = srcType.getDimSize(inDimIdx);
-
-        OpFoldResult originalSizeOFR;
-        if (originalSize == ShapedType::kDynamic) {
-          Value orgSizeVal =
-              rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
-          originalSizeOFR = orgSizeVal;
-        } else {
-          originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
-        }
         assert(reInd.size() == 1 && "expected single dimension");
-        expandedShape[reInd[0]] = originalSizeOFR;
+        expandedShape[reInd[0]] =
+            tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx);
+        ;
       }
     }
 
@@ -1288,18 +1277,8 @@ struct FoldReshapeWithProducerPadOpByCollapsing
       OpFoldResult h = high[reInd[0]];
 
       if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
-        auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
-        int64_t originalSize = srcType.getDimSize(reInd[0]);
-
-        OpFoldResult originalSizeOFR;
-        if (originalSize == ShapedType::kDynamic) {
-          Value orgSizeVal =
-              rewriter.create<tensor::DimOp>(loc, padOp.getSource(), reInd[0]);
-          originalSizeOFR = orgSizeVal;
-        } else {
-          originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
-        }
-        collapsedShape[inDimIdx] = originalSizeOFR;
+        collapsedShape[inDimIdx] =
+            tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]);
       }
     }
 

>From 9ee8e08f9a8e48403df169499d2cb9b765156017 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 22 Jul 2025 16:54:23 +0900
Subject: [PATCH 6/9] fix upon review

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp    | 16 +++-------------
 1 file changed, 3 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 1ec3bd2ac8f1d..e99de0e78eabe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1150,17 +1150,8 @@ struct FoldReshapeWithProducerPadOpByExpansion
     }
 
     Location loc = expandOp.getLoc();
-    auto finalType = cast<RankedTensorType>(expandOp.getType());
-    ArrayRef<int64_t> finalShape = finalType.getShape();
-
-    SmallVector<OpFoldResult> expandedShape;
-    for (int64_t dimSize : finalShape) {
-      if (dimSize == ShapedType::kDynamic) {
-        expandedShape.push_back(OpFoldResult{});
-      } else {
-        expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
-      }
-    }
+    ArrayRef<int64_t> finalShape = expandOp.getResultType().getShape();
+    SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();
 
     for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
       OpFoldResult l = low[inDimIdx];
@@ -1260,8 +1251,7 @@ struct FoldReshapeWithProducerPadOpByCollapsing
     Location loc = collapseOp.getLoc();
     auto resultType = collapseOp.getResultType();
 
-    auto finalType = cast<RankedTensorType>(collapseOp.getType());
-    ArrayRef<int64_t> finalShape = finalType.getShape();
+    ArrayRef<int64_t> finalShape = collapseOp.getResultType().getShape();
 
     SmallVector<OpFoldResult> collapsedShape;
     for (int64_t dimSize : finalShape) {

>From 3b916457029e89f871394df0f7a25cdf0b674aff Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 26 Jul 2025 13:56:22 +0900
Subject: [PATCH 7/9] fix upon review

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 21 ++++---------------
 1 file changed, 4 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e99de0e78eabe..0687502cd1092 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1136,23 +1136,19 @@ struct FoldReshapeWithProducerPadOpByExpansion
     SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
     SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
 
-    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
-      if (reInd.size() > 1 &&
-          (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)))
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() > 1 && (!isConstantIntValue(low[idx], 0) ||
+                               !isConstantIntValue(high[idx], 0)))
         return rewriter.notifyMatchFailure(
             expandOp, "fusion blocked by non-zero padding");
-    }
 
-    SmallVector<OpFoldResult> newLow, newHigh;
-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
       newLow.append(reInd.size(), low[idx]);
       newHigh.append(reInd.size(), high[idx]);
     }
 
     Location loc = expandOp.getLoc();
-    ArrayRef<int64_t> finalShape = expandOp.getResultType().getShape();
     SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();
-
     for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
       OpFoldResult l = low[inDimIdx];
       OpFoldResult h = high[inDimIdx];
@@ -1165,15 +1161,6 @@ struct FoldReshapeWithProducerPadOpByExpansion
       }
     }
 
-    for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
-      if (dimSize == ShapedType::kDynamic &&
-          !isa<Value>(expandedShape[outDimIdx]) &&
-          !isa<Attribute>(expandedShape[outDimIdx])) {
-        expandedShape[outDimIdx] =
-            tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
-      }
-    }
-
     SmallVector<int64_t> staticExpandedShape;
     std::tie(staticExpandedShape, std::ignore) =
         decomposeMixedValues(expandedShape);

>From 9c38ad58f93dbf18fc42a2284fd86cf24f34d2cb Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Sat, 26 Jul 2025 14:24:13 +0900
Subject: [PATCH 8/9] fix upon review

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 23 +++++--------------
 1 file changed, 6 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 0687502cd1092..86e287bae6cf5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1239,32 +1239,21 @@ struct FoldReshapeWithProducerPadOpByCollapsing
     auto resultType = collapseOp.getResultType();
 
     ArrayRef<int64_t> finalShape = collapseOp.getResultType().getShape();
-
-    SmallVector<OpFoldResult> collapsedShape;
-    for (int64_t dimSize : finalShape) {
-      if (dimSize == ShapedType::kDynamic) {
-        collapsedShape.push_back(OpFoldResult{});
-      } else {
-        collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
-      }
-    }
-
+    SmallVector<int64_t> collapsedShape(finalShape.begin(), finalShape.end());
     for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
       OpFoldResult l = low[reInd[0]];
       OpFoldResult h = high[reInd[0]];
-
       if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
-        collapsedShape[inDimIdx] =
+        auto mixedSize =
             tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]);
+        auto dimSize = getConstantIntValue(mixedSize);
+        assert(dimSize.has_value() && "Expected static dimension");
+        collapsedShape[inDimIdx] = *dimSize;
       }
     }
 
-    SmallVector<int64_t> staticCollapsedShape;
-    std::tie(staticCollapsedShape, std::ignore) =
-        decomposeMixedValues(collapsedShape);
-
     auto newCollapseType = RankedTensorType::get(
-        staticCollapsedShape, padOp.getSource().getType().getElementType());
+        collapsedShape, padOp.getSource().getType().getElementType());
     auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
         loc, newCollapseType, padOp.getSource(), reassociations);
 

>From 0faf0849388e7180e3606f9403a6b72cb07048e6 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 29 Jul 2025 18:46:34 +0900
Subject: [PATCH 9/9] fix upon review

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 25 ++++++-------------
 1 file changed, 8 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 86e287bae6cf5..a038d3c95c0d5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1217,20 +1217,15 @@ struct FoldReshapeWithProducerPadOpByCollapsing
         collapseOp.getReassociationIndices();
     SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
     SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
-
+    SmallVector<OpFoldResult> newLow, newHigh;
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      if (reInd.size() > 1) {
-        for (auto dimIdx : reInd) {
-          if (!isConstantIntValue(low[dimIdx], 0) ||
-              !isConstantIntValue(high[dimIdx], 0)) {
-            return failure();
-          }
-        }
+      if (reInd.size() > 1 && llvm::any_of(reInd, [&](int64_t dimIdx) {
+            return !isConstantIntValue(low[dimIdx], 0) ||
+                   !isConstantIntValue(high[dimIdx], 0);
+          })) {
+        return failure();
       }
-    }
 
-    SmallVector<OpFoldResult> newLow, newHigh;
-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
       newLow.push_back(low[reInd[0]]);
       newHigh.push_back(high[reInd[0]]);
     }
@@ -1244,16 +1239,12 @@ struct FoldReshapeWithProducerPadOpByCollapsing
       OpFoldResult l = low[reInd[0]];
       OpFoldResult h = high[reInd[0]];
       if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
-        auto mixedSize =
-            tensor::getMixedSize(rewriter, loc, padOp.getSource(), reInd[0]);
-        auto dimSize = getConstantIntValue(mixedSize);
-        assert(dimSize.has_value() && "Expected static dimension");
-        collapsedShape[inDimIdx] = *dimSize;
+        collapsedShape[inDimIdx] = padOp.getSourceType().getShape()[reInd[0]];
       }
     }
 
     auto newCollapseType = RankedTensorType::get(
-        collapsedShape, padOp.getSource().getType().getElementType());
+        collapsedShape, padOp.getSourceType().getElementType());
     auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
         loc, newCollapseType, padOp.getSource(), reassociations);
 

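For completeness, a minimal sketch of how a downstream pass could pull in these new patterns through the two population hooks this PR extends; the helper function name and the always-true control predicate are illustrative assumptions, not part of the patch:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Registers all reshape-fusion patterns, including the new pad-propagation
// patterns, with a control function that accepts every candidate operand.
static void populateReshapePadPropagationPatterns(RewritePatternSet &patterns) {
  linalg::ControlFusionFn controlFn = [](OpOperand *) { return true; };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
  linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
}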

