[Mlir-commits] [mlir] [mlir][vector] Add `SwapShapeCastOfTranspose` canonicalizer pattern (PR #100933)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jul 28 07:19:30 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

A pattern to swap shape_cast(transpose) with transpose(shape_cast) if the shape_cast only drops unit dimensions.

This simplifies the transpose, making it more likely to be matched by further patterns.

Example:

BEFORE:
```mlir
%0 = vector.transpose %vector, [3, 0, 1, 2]
       : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
```

AFTER:
```mlir
%0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
```

Note: This moves the pattern out of the ArmSME vector legalization and makes it a general `vector.shape_cast` canonicalization pattern, as it is useful for lowerings outside of ArmSME.

---
Full diff: https://github.com/llvm/llvm-project/pull/100933.diff


4 Files Affected:

- (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+3-90) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+89-1) 
- (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (-26) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+26) 


``````````diff
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 53df7af00aee8..ed6cd3d0cdbbc 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -774,94 +774,6 @@ struct ConvertIllegalShapeCastOpsToTransposes
   }
 };
 
-/// Returns an iterator over the dims (inc scalability) of a VectorType.
-static auto getDims(VectorType vType) {
-  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
-}
-
-/// Helper to drop (fixed-size) unit dims from a VectorType.
-static VectorType dropUnitDims(VectorType vType) {
-  SmallVector<bool> scalableFlags;
-  SmallVector<int64_t> dimSizes;
-  for (auto dim : getDims(vType)) {
-    if (dim == std::make_tuple(1, false))
-      continue;
-    auto [size, scalableFlag] = dim;
-    dimSizes.push_back(size);
-    scalableFlags.push_back(scalableFlag);
-  }
-  return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
-}
-
-/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
-/// shape_cast only drops unit dimensions.
-///
-/// This simplifies the transpose making it possible for other legalization
-/// rewrites to handle it.
-///
-/// Example:
-///
-///  BEFORE:
-///  ```mlir
-///  %0 = vector.transpose %vector, [3, 0, 1, 2]
-///         : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
-///  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///  %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
-///  %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-///  ```
-struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto transposeOp =
-        shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
-    if (!transposeOp)
-      return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
-
-    auto resultType = shapeCastOp.getResultVectorType();
-    if (resultType.getRank() <= 1)
-      return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
-
-    if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
-      return rewriter.notifyMatchFailure(
-          shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
-
-    auto transposeSourceVectorType = transposeOp.getSourceVectorType();
-    auto transposeSourceDims =
-        llvm::to_vector(getDims(transposeSourceVectorType));
-
-    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
-    SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
-    int64_t droppedDims = 0;
-    for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
-      droppedDimsBefore[i] = droppedDims;
-      if (dim == std::make_tuple(1, false))
-        ++droppedDims;
-    }
-
-    // Drop unit dims from transpose permutation.
-    auto perm = transposeOp.getPermutation();
-    SmallVector<int64_t> newPerm;
-    for (int64_t idx : perm) {
-      if (transposeSourceDims[idx] == std::make_tuple(1, false))
-        continue;
-      newPerm.push_back(idx - droppedDimsBefore[idx]);
-    }
-
-    auto loc = shapeCastOp.getLoc();
-    auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
-        loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
-                                                     newShapeCastOp, newPerm);
-    return success();
-  }
-};
-
 /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
 /// the ZA state. This workaround rewrite to support these transposes when ZA is
 /// available.
@@ -1027,8 +939,9 @@ struct VectorLegalizationPass
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                  LiftIllegalVectorTransposeToMemory,
                  ConvertIllegalShapeCastOpsToTransposes,
-                 SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
-        context);
+                 LowerIllegalTransposeStoreViaZA>(context);
+    vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
+
     // Note: These two patterns are added with a high benefit to ensure:
     //  - Masked outer products are handled before unmasked ones
     //  - Multi-tile writes are lowered as a store loop (if possible)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d297c40760cd8..2da46aa86d74b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5480,12 +5480,100 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
   }
 };
 
+/// Returns an iterator over the dims (inc scalability) of a VectorType.
+static auto getDims(VectorType vType) {
+  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
+/// Helper to drop (fixed-size) unit dims from a VectorType.
+static VectorType dropUnitDims(VectorType vType) {
+  SmallVector<bool> scalableFlags;
+  SmallVector<int64_t> dimSizes;
+  for (auto dim : getDims(vType)) {
+    if (dim == std::make_tuple(1, false))
+      continue;
+    auto [size, scalableFlag] = dim;
+    dimSizes.push_back(size);
+    scalableFlags.push_back(scalableFlag);
+  }
+  return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
+}
+
+/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
+/// shape_cast only drops unit dimensions.
+///
+/// This simplifies the transpose making it more likely to be matched by further
+/// patterns.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %0 = vector.transpose %vector, [3, 0, 1, 2]
+///         : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+///  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+///  %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+///  ```
+struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeOp =
+        shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
+    if (!transposeOp)
+      return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
+
+    auto resultType = shapeCastOp.getResultVectorType();
+    if (resultType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
+
+    if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
+      return rewriter.notifyMatchFailure(
+          shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
+
+    auto transposeSourceVectorType = transposeOp.getSourceVectorType();
+    auto transposeSourceDims =
+        llvm::to_vector(getDims(transposeSourceVectorType));
+
+    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+    SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
+    int64_t droppedDims = 0;
+    for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
+      droppedDimsBefore[i] = droppedDims;
+      if (dim == std::make_tuple(1, false))
+        ++droppedDims;
+    }
+
+    // Drop unit dims from transpose permutation.
+    auto perm = transposeOp.getPermutation();
+    SmallVector<int64_t> newPerm;
+    for (int64_t idx : perm) {
+      if (transposeSourceDims[idx] == std::make_tuple(1, false))
+        continue;
+      newPerm.push_back(idx - droppedDimsBefore[idx]);
+    }
+
+    auto loc = shapeCastOp.getLoc();
+    auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
+        loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
+                                                     newShapeCastOp, newPerm);
+    return success();
+  }
+};
+
 } // namespace
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
-              ShapeCastBroadcastFolder>(context);
+              ShapeCastBroadcastFolder, SwapShapeCastOfTranspose>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index adc02adb6e974..458906a187982 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -646,29 +646,3 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
   vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>,  memref<?x?xf32>
   return
 }
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose(
-// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
-func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
-  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-  // CHECK: return %[[TRANSPOSE]]
-  %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
-  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
-  return %1 : vector<[4]x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
-// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
-func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
-  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
-  // CHECK: return %[[TRANSPOSE]]
-  %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
-  %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
-  return %1 : vector<[4]x4xf32>
-}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e71a6eb02ea46..f1a1120bf874e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -867,6 +867,32 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
 
 // -----
 
+// CHECK-LABEL: @swap_shape_cast_of_transpose(
+// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
+  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  // CHECK: return %[[TRANSPOSE]]
+  %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+  %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+  return %1 : vector<[4]x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
+// CHECK-SAME:                                %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
+func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
+  // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
+  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+  // CHECK: return %[[TRANSPOSE]]
+  %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
+  %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
+  return %1 : vector<[4]x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index

``````````

</details>


https://github.com/llvm/llvm-project/pull/100933


More information about the Mlir-commits mailing list