[Mlir-commits] [mlir] [mlir][vector] Remove redundant shape_cast(shape_cast(x)) pattern (PR #135447)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 14 11:44:47 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

<details>
<summary>Changes</summary>

This PR removes one `OpRewritePattern`, `shape_cast(shape_cast(x)) -> x`, that is already handled by `ShapeCastOp::fold`.

Note that this might affect downstream users who indirectly call `populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit)` and then use `patterns` with a `GreedyRewriteConfig config` that has `config.fold = false`. (The only user I've checked is IREE, which never uses `config.fold = false`.)

---
Full diff: https://github.com/llvm/llvm-project/pull/135447.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (-4) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (-2) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (-2) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+1-67) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7de4a6a315750..ce97847172197 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -306,10 +306,6 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                   const UnrollVectorOptions &options,
                                   PatternBenefit benefit = 1);
 
-/// Collect a set of vector.shape_cast folding patterns.
-void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
-                                      PatternBenefit benefit = 1);
-
 /// Collect a set of leading one dimension removal patterns.
 ///
 /// These patterns insert vector.shape_cast to remove leading one dimensions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index fda3baf3aa390..68a44ea889470 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -8,7 +8,6 @@
 
 #include <numeric>
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -577,5 +576,4 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
            CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim,
            CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim,
            CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit);
-  populateShapeCastFoldingPatterns(patterns, benefit);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 62dfd439b0ad1..999fb9c415886 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -976,7 +976,6 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
   patterns
       .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
           patterns.getContext(), benefit);
-  populateShapeCastFoldingPatterns(patterns);
 }
 
 void mlir::vector::populateFlattenVectorTransferPatterns(
@@ -985,6 +984,5 @@ void mlir::vector::populateFlattenVectorTransferPatterns(
   patterns.add<FlattenContiguousRowMajorTransferReadPattern,
                FlattenContiguousRowMajorTransferWritePattern>(
       patterns.getContext(), targetVectorBitwidth, benefit);
-  populateShapeCastFoldingPatterns(patterns, benefit);
   populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d50d5fe96f49a..89839d0440d3c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -16,36 +16,24 @@
 #include <cstdint>
 #include <functional>
 #include <optional>
-#include <type_traits>
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
 
-#include "llvm/ADT/DenseSet.h"
-#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "vector-to-vector"
 
@@ -71,54 +59,6 @@ static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
 
 namespace {
 
-/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
-//
-// Example:
-//
-//  The following MLIR with cancelling ShapeCastOps:
-//
-//   %0 = source : vector<5x4x2xf32>
-//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
-//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
-//   %3 = user %2 : vector<5x4x2xf32>
-//
-//  Should canonicalize to the following:
-//
-//   %0 = source : vector<5x4x2xf32>
-//   %1 = user %0 : vector<5x4x2xf32>
-//
-struct ShapeCastOpFolder : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    // Check if 'shapeCastOp' has vector source/result type.
-    auto sourceVectorType =
-        dyn_cast_or_null<VectorType>(shapeCastOp.getSource().getType());
-    auto resultVectorType =
-        dyn_cast_or_null<VectorType>(shapeCastOp.getResult().getType());
-    if (!sourceVectorType || !resultVectorType)
-      return failure();
-
-    // Check if shape cast op source operand is also a shape cast op.
-    auto sourceShapeCastOp = dyn_cast_or_null<vector::ShapeCastOp>(
-        shapeCastOp.getSource().getDefiningOp());
-    if (!sourceShapeCastOp)
-      return failure();
-    auto operandSourceVectorType =
-        cast<VectorType>(sourceShapeCastOp.getSource().getType());
-    auto operandResultVectorType = sourceShapeCastOp.getType();
-
-    // Check if shape cast operations invert each other.
-    if (operandSourceVectorType != resultVectorType ||
-        operandResultVectorType != sourceVectorType)
-      return failure();
-
-    rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource());
-    return success();
-  }
-};
-
 /// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
 /// Ex:
 /// ```
@@ -2113,11 +2053,6 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
   patterns.add<FoldI1Select>(patterns.getContext(), benefit);
 }
 
-void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
-                                                    PatternBenefit benefit) {
-  patterns.add<ShapeCastOpFolder>(patterns.getContext(), benefit);
-}
-
 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   // TODO: Consider either:
@@ -2126,8 +2061,7 @@ void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
   //  * better naming to distinguish this and
   //    populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
   patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
-               DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
-      patterns.getContext(), benefit);
+               DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(

``````````

</details>


https://github.com/llvm/llvm-project/pull/135447


More information about the Mlir-commits mailing list