[Mlir-commits] [mlir] Revert "[mlir][Vector] Add fold transpose(shape_cast) -> shape_cast (#73951)" (PR #74579)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 6 02:14:07 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Adam Paszke (apaszke)

<details>
<summary>Changes</summary>

Revert "[mlir][Vector] Add fold transpose(shape_cast) -> shape_cast (#73951)"

This reverts commit f42b7615b862bb5f77981f619f92877eb20adf54.

The fold pattern is incorrect. It compares only the sizes of the non-unit dims before and after the transpose and never looks at the permutation itself. As a result, it happily replaces a pattern such as
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
%23 = vector.transpose %22, [1, 0] : vector<256x256xf32> to vector<256x256xf32>
```
with
```
%22 = vector.shape_cast %21 : vector<1x256x256xf32> to vector<256x256xf32>
```
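which is incorrect: the `[1, 0]` transpose swaps the two 256-sized dims, so element `[i, j]` must end up at `[j, i]`. Dropping the transpose leaves every element in its original position. The size-only check passes simply because both swapped dims have the same size.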

---
Full diff: https://github.com/llvm/llvm-project/pull/74579.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-46) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (-12) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index caffd344848b3..c462b23e1133f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5548,57 +5548,12 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just
-/// permutes a unit dim from the result of the shape_cast.
-class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransposeOp transpOp,
-                                PatternRewriter &rewriter) const override {
-    Value transposeSrc = transpOp.getVector();
-    auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>();
-    if (!shapeCastOp)
-      return rewriter.notifyMatchFailure(
-          transpOp, "TransposeOp source is not ShapeCastOp");
-
-    auto sourceType = transpOp.getSourceVectorType();
-    auto resultType = transpOp.getResultVectorType();
-
-    auto filterUnitDims = [](VectorType type) {
-      return llvm::make_filter_range(
-          llvm::zip_equal(type.getShape(), type.getScalableDims()),
-          [&](auto dim) {
-            auto [size, isScalable] = dim;
-            return size != 1 || isScalable;
-          });
-    };
-
-    auto sourceWithoutUnitDims = filterUnitDims(sourceType);
-    auto resultWithoutUnitDims = filterUnitDims(resultType);
-
-    // If this transpose just permutes a unit dim, then we can fold it into the
-    // shape_cast.
-    for (auto [srcDim, resDim] :
-         llvm::zip_equal(sourceWithoutUnitDims, resultWithoutUnitDims)) {
-      if (srcDim != resDim)
-        return rewriter.notifyMatchFailure(transpOp,
-                                           "TransposeOp permutes non-unit dim");
-    }
-
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resultType,
-                                                     shapeCastOp.getSource());
-
-    return success();
-  };
-};
-
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>(
-      context);
+              TransposeFolder, FoldTransposeSplat>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6bfb477ecf972..1021c73cc57d3 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -67,18 +67,6 @@ func.func @create_mask_transpose_to_transposed_create_mask(
 
 // -----
 
-// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast
-//  CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
-func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> {
-  //     CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32>
-  // CHECK-NOT: vector.transpose
-  %0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
-  %1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %1 : vector<1x[4]xf32>
-}
-
-// -----
-
 // CHECK-LABEL: extract_from_create_mask
 //  CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
 func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {

``````````

</details>


https://github.com/llvm/llvm-project/pull/74579


More information about the Mlir-commits mailing list