[Mlir-commits] [mlir] d1fc59c - [mlir][ArmSME] Rewrite illegal `shape_casts` to `vector.transpose` ops (#82985)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 7 09:04:17 PST 2024


Author: Benjamin Maxwell
Date: 2024-03-07T17:04:12Z
New Revision: d1fc59c3b5c5ce292a6060d7a5545094cdf1b5fc

URL: https://github.com/llvm/llvm-project/commit/d1fc59c3b5c5ce292a6060d7a5545094cdf1b5fc
DIFF: https://github.com/llvm/llvm-project/commit/d1fc59c3b5c5ce292a6060d7a5545094cdf1b5fc.diff

LOG: [mlir][ArmSME] Rewrite illegal `shape_casts` to `vector.transpose` ops (#82985)

This adds a rewrite that converts illegal 2D unit-dim `shape_casts` into
`vector.transpose` ops.

E.g.

```mlir
// Case 1:
%a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32>
```

Becomes:

```mlir
// Case 1:
%a = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%t = vector.transpose %1, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32>
```

Various lowerings and drop-unit-dims patterns add such shape_casts;
however, if they do not cancel out (which they likely won't if we've
reached the vector-legalization pass), they will prevent lowering the IR.

Rewriting them as a transpose gives `LiftIllegalVectorTransposeToMemory`
a chance to eliminate the illegal types.

Added: 
    

Modified: 
    mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
    mlir/test/Dialect/ArmSME/vector-legalization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 11f8bc04b21844..31500c62c0d600 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -46,6 +46,8 @@ static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
     "op mask is unsupported for legalization/decomposition");
 static constexpr StringLiteral
     kMatchFailureNonPermutationMap("op affine map is not a permutation");
+static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
+    "expected transpose from illegal type to legal type");
 
 /// An SMESubTile represents a single SME-sized sub-tile from decomposing a
 /// larger vector type. The (`row`, `col`) are the position of the tile in the
@@ -416,6 +418,17 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
   }
 };
 
+/// Returns true if no fixed dimension comes after a scalable one in `vType`.
+static bool isLegalVectorType(VectorType vType) {
+  bool seenFixedDim = false;
+  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+    seenFixedDim |= !scalableFlag;
+    if (seenFixedDim && scalableFlag)
+      return false;
+  }
+  return true;
+}
+
 /// Lifts an illegal vector.transpose and vector.transfer_read to a
 /// memref.subview + memref.transpose, followed by a legal read.
 ///
@@ -448,16 +461,6 @@ struct LiftIllegalVectorTransposeToMemory
     : public OpRewritePattern<vector::TransposeOp> {
   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
 
-  static bool isIllegalVectorType(VectorType vType) {
-    bool seenFixedDim = false;
-    for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
-      seenFixedDim |= !scalableFlag;
-      if (seenFixedDim && scalableFlag)
-        return true;
-    }
-    return false;
-  }
-
   static Value getExtensionSource(Operation *op) {
     if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
       return op->getOperand(0);
@@ -468,9 +471,9 @@ struct LiftIllegalVectorTransposeToMemory
                                 PatternRewriter &rewriter) const override {
     auto sourceType = transposeOp.getSourceVectorType();
     auto resultType = transposeOp.getResultVectorType();
-    if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
-      return rewriter.notifyMatchFailure(
-          transposeOp, "expected transpose from illegal type to legal type");
+    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         kMatchFailureNotIllegalToLegal);
 
     // Look through extend for transfer_read.
     Value maybeRead = transposeOp.getVector();
@@ -556,6 +559,59 @@ struct LiftIllegalVectorTransposeToMemory
   }
 };
 
+/// A rewrite to turn unit dim transpose-like vector.shape_casts into
+/// vector.transposes. The shape_cast has to be from an illegal vector type to a
+/// legal one (as defined by isLegalVectorType).
+///
+/// The reasoning for this is if we've got to this pass and we still have
+/// shape_casts of illegal types, then they likely will not cancel out. Turning
+/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
+/// eliminate them.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+///  ```
+struct ConvertIllegalShapeCastOpsToTransposes
+    : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = shapeCastOp.getSourceVectorType();
+    auto resultType = shapeCastOp.getResultVectorType();
+    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(shapeCastOp,
+                                         kMatchFailureNotIllegalToLegal);
+
+    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
+    // then dim 0 is scalable and dim 1 is fixed.
+    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
+      return rewriter.notifyMatchFailure(
+          shapeCastOp, "expected source to be a 2D scalable vector with a "
+                       "trailing unit dim");
+
+    auto loc = shapeCastOp.getLoc();
+    auto transpose = rewriter.create<vector::TransposeOp>(
+        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
+
+    // Moving the unit dim keeps the linear element order, so any other legal
+    // result type (1D, or e.g. 1x1x[4]) is reached via a trailing shape_cast.
+    Value result = transpose;
+    if (transpose.getResultVectorType() != resultType)
+      result = rewriter.create<vector::ShapeCastOp>(loc, resultType, result);
+    rewriter.replaceOp(shapeCastOp, result);
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
@@ -576,7 +632,8 @@ struct VectorLegalizationPass
         });
 
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
-                 LiftIllegalVectorTransposeToMemory>(context);
+                 LiftIllegalVectorTransposeToMemory,
+                 ConvertIllegalShapeCastOpsToTransposes>(context);
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);

diff  --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index bf0b58ff4cf073..f8be697548c197 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -388,3 +388,48 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
   %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
   return %0 : vector<1x[4]xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
+// CHECK-SAME:                                      %[[SRC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_2d(%src: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+  // CHECK: vector.transpose %[[SRC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  %res = vector.shape_cast %src : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %res : vector<1x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
+// CHECK-SAME:                                      %[[SRC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_1d(%src: vector<[4]x1xf32>) -> vector<[4]xf32> {
+  // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %[[SRC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  // CHECK: vector.shape_cast %[[TRANSPOSED]] : vector<1x[4]xf32> to vector<[4]xf32>
+  %res = vector.shape_cast %src : vector<[4]x1xf32> to vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
+func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+  // CHECK-NOT: vector.shape_cast
+  %padding = arith.constant 0.0 : f32
+  %read = vector.transfer_read %memref[%a, %b], %padding {in_bounds = [false, true]} : memref<?x?xf32>, vector<[4]x1xf32>
+  %result = vector.shape_cast %read : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %result : vector<1x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
+func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+  // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
+  %padding = arith.constant 0.0 : f32
+  %read = vector.transfer_read %memref[%a, %b], %padding {in_bounds = [false, true]} : memref<?x?xf32>, vector<[4]x1xf32>
+  %result = vector.shape_cast %read : vector<[4]x1xf32> to vector<[4]xf32>
+  return %result : vector<[4]xf32>
+}


        


More information about the Mlir-commits mailing list