[Mlir-commits] [mlir] [mlir][ArmSME] Rewrite illegal `shape_casts` to `vector.transpose` ops (PR #82985)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Feb 26 03:33:09 PST 2024


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/82985

This adds a rewrite that converts illegal 2D unit-dim `shape_casts` into `vector.transpose` ops.

E.g.

```mlir
// Case 1:
%a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32>
```

Becomes:

```
// Case 1:
%a = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%t = vector.transpose %1, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32>
```

Various lowerings and drop-unit-dims patterns add such shape_casts; however, if they do not cancel out (which they likely won't if we've reached the vector-legalization pass), they will prevent lowering the IR.

Rewriting them as a transpose gives `LiftIllegalVectorTransposeToMemory` a chance to eliminate the illegal types.

>From 9865dd54ea1a65c6c3a96b43542e1be6bc41bff3 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 26 Feb 2024 10:51:41 +0000
Subject: [PATCH] [mlir][ArmSME] Rewrite illegal `shape_casts` to
 `vector.transpose` ops

This adds a rewrite that converts illegal 2D unit-dim `shape_casts`
into `vector.transpose` ops.

E.g.

```mlir
// Case 1:
%a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32>
```

Becomes:

```
// Case 1:
%a = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%t = vector.transpose %1, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32>
```

Various lowerings and drop-unit-dims patterns add such shape_casts;
however, if they do not cancel out (which they likely won't if we've
reached the vector-legalization pass), they will prevent lowering the IR.

Rewriting them as a transpose gives `LiftIllegalVectorTransposeToMemory`
a chance to eliminate the illegal types.
---
 .../ArmSME/Transforms/VectorLegalization.cpp  | 85 ++++++++++++++++---
 .../Dialect/ArmSME/vector-legalization.mlir   | 45 ++++++++++
 2 files changed, 116 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 11f8bc04b21844..55b20e5a477d4e 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -46,6 +46,8 @@ static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
     "op mask is unsupported for legalization/decomposition");
 static constexpr StringLiteral
     kMatchFailureNonPermutationMap("op affine map is not a permutation");
+static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
+    "expected transpose from illegal type to legal type");
 
 /// An SMESubTile represents a single SME-sized sub-tile from decomposing a
 /// larger vector type. The (`row`, `col`) are the position of the tile in the
@@ -416,6 +418,17 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
   }
 };
 
+/// A vector type where no fixed dimension comes after a scalable dimension.
+bool isLegalVectorType(VectorType vType) {
+  bool seenFixedDim = false;
+  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+    seenFixedDim |= !scalableFlag;
+    if (seenFixedDim && scalableFlag)
+      return false;
+  }
+  return true;
+}
+
 /// Lifts an illegal vector.transpose and vector.transfer_read to a
 /// memref.subview + memref.transpose, followed by a legal read.
 ///
@@ -448,16 +461,6 @@ struct LiftIllegalVectorTransposeToMemory
     : public OpRewritePattern<vector::TransposeOp> {
   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
 
-  static bool isIllegalVectorType(VectorType vType) {
-    bool seenFixedDim = false;
-    for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
-      seenFixedDim |= !scalableFlag;
-      if (seenFixedDim && scalableFlag)
-        return true;
-    }
-    return false;
-  }
-
   static Value getExtensionSource(Operation *op) {
     if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
       return op->getOperand(0);
@@ -468,9 +471,9 @@ struct LiftIllegalVectorTransposeToMemory
                                 PatternRewriter &rewriter) const override {
     auto sourceType = transposeOp.getSourceVectorType();
     auto resultType = transposeOp.getResultVectorType();
-    if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
-      return rewriter.notifyMatchFailure(
-          transposeOp, "expected transpose from illegal type to legal type");
+    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         kMatchFailureNotIllegalToLegal);
 
     // Look through extend for transfer_read.
     Value maybeRead = transposeOp.getVector();
@@ -556,6 +559,59 @@ struct LiftIllegalVectorTransposeToMemory
   }
 };
 
+/// A rewrite to turn unit dim transpose-like vector.shape_cast into a
+/// vector.transpose. The shape_cast has to be from an illegal vector type to a
+/// legal one (as defined by isLegalVectorType).
+///
+/// The reasoning for this is if we've got to this pass and we still have
+/// shape_casts of illegal types, then they likely will not cancel out. Turning
+/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
+/// eliminate them.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+///  ```
+struct ConvertIllegalShapeCastOpsToTransposes
+    : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = shapeCastOp.getSourceVectorType();
+    auto resultType = shapeCastOp.getResultVectorType();
+    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(shapeCastOp,
+                                         kMatchFailureNotIllegalToLegal);
+
+    // Note: If we know that the source is an illegal vector type (and 2D)
+    // then dim 0 is scalable and dim 1 is fixed.
+    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
+      return rewriter.notifyMatchFailure(
+          shapeCastOp, "expected source to be a 2D scalable vector with a "
+                       "trailing unit dim");
+
+    auto loc = shapeCastOp.getLoc();
+    auto transpose = rewriter.create<vector::TransposeOp>(
+        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
+
+    if (resultType.getRank() == 1)
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
+                                                       transpose);
+    else
+      rewriter.replaceOp(shapeCastOp, transpose);
+
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
@@ -576,7 +632,8 @@ struct VectorLegalizationPass
         });
 
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
-                 LiftIllegalVectorTransposeToMemory>(context);
+                 LiftIllegalVectorTransposeToMemory,
+                 ConvertIllegalShapeCastOpsToTransposes>(context);
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index bf0b58ff4cf073..f8be697548c197 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -388,3 +388,48 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
   %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
   return %0 : vector<1x[4]xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
+// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+  // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %0 : vector<1x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
+// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
+  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
+  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
+  return %0 : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
+func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+  // CHECK-NOT: vector.shape_cast
+  %pad = arith.constant 0.0 : f32
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
+  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %cast : vector<1x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
+func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+  // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
+  %pad = arith.constant 0.0 : f32
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
+  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
+  return %cast : vector<[4]xf32>
+}



More information about the Mlir-commits mailing list