[Mlir-commits] [mlir] [mlir][ArmSME] Rewrite illegal `shape_casts` to `vector.transpose` ops (PR #82985)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Mar 7 08:32:30 PST 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/82985
>From 3f032032cd847f399b56bfe0f7c8cd77bdf429e2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 26 Feb 2024 10:51:41 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Rewrite illegal `shape_casts` to
`vector.transpose` ops
This adds a rewrite that converts illegal 2D unit-dim `shape_casts`
into `vector.transpose` ops.
E.g.
```mlir
// Case 1:
%a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32>
```
Becomes:
```mlir
// Case 1:
%a = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%t = vector.transpose %1, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32>
```
Various lowerings and drop-unit-dims patterns add such shape_casts;
however, if they do not cancel out (which they likely won't if we've
reached the vector-legalization pass), they will prevent lowering the IR.
Rewriting them as a transpose gives `LiftIllegalVectorTransposeToMemory`
a chance to eliminate the illegal types.
---
.../ArmSME/Transforms/VectorLegalization.cpp | 85 ++++++++++++++++---
.../Dialect/ArmSME/vector-legalization.mlir | 45 ++++++++++
2 files changed, 116 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 11f8bc04b21844..f06d68bcf661a3 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -46,6 +46,8 @@ static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
"op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
kMatchFailureNonPermutationMap("op affine map is not a permutation");
+/// Shared by the transpose and shape_cast patterns, so the wording is op-agnostic.
+static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
+ "expected op from illegal vector type to legal vector type");
/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
@@ -416,6 +418,17 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
}
};
+/// Returns true if no fixed-size dim follows a scalable dim, e.g. 1x[4] is
+/// legal but [4]x1 is not.
+bool isLegalVectorType(VectorType vType) {
+ bool seenScalableDim = false;
+ for (bool isScalableDim : vType.getScalableDims()) {
+ // A fixed dim appearing after any scalable dim makes the type illegal.
+ if (seenScalableDim && !isScalableDim)
+ return false;
+ seenScalableDim |= isScalableDim;
+ }
+ return true;
+}
+
/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
@@ -448,16 +461,6 @@ struct LiftIllegalVectorTransposeToMemory
: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
- static bool isIllegalVectorType(VectorType vType) {
- bool seenFixedDim = false;
- for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
- seenFixedDim |= !scalableFlag;
- if (seenFixedDim && scalableFlag)
- return true;
- }
- return false;
- }
-
static Value getExtensionSource(Operation *op) {
if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
return op->getOperand(0);
@@ -468,9 +471,9 @@ struct LiftIllegalVectorTransposeToMemory
PatternRewriter &rewriter) const override {
auto sourceType = transposeOp.getSourceVectorType();
auto resultType = transposeOp.getResultVectorType();
- if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
- return rewriter.notifyMatchFailure(
- transposeOp, "expected transpose from illegal type to legal type");
+ if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(transposeOp,
+ kMatchFailureNotIllegalToLegal);
// Look through extend for transfer_read.
Value maybeRead = transposeOp.getVector();
@@ -556,6 +559,59 @@ struct LiftIllegalVectorTransposeToMemory
}
};
+/// A rewrite to turn unit dim transpose-like vector.shape_casts into
+/// vector.transposes. The shape_cast has to be from an illegal vector type to a
+/// legal one (as defined by isLegalVectorType), and the source must be a 2D
+/// vector with a trailing unit dim (i.e. of the form `[N]x1`).
+///
+/// The reasoning for this is if we've got to this pass and we still have
+/// shape_casts of illegal types, then they likely will not cancel out. Turning
+/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
+/// eliminate them.
+///
+/// If the shape_cast's result type is not the transposed type (`1x[N]`), e.g.
+/// it is `[N]` or `1x1x[N]`, a (legal) shape_cast from the transpose to the
+/// result type is emitted after the transpose.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// %1 = vector.shape_cast %b : vector<[4]x1xf32> to vector<[4]xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// %t = vector.transpose %b, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// %1 = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
+struct ConvertIllegalShapeCastOpsToTransposes
+ : public OpRewritePattern<vector::ShapeCastOp> {
+ using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = shapeCastOp.getSourceVectorType();
+ auto resultType = shapeCastOp.getResultVectorType();
+ if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(shapeCastOp,
+ kMatchFailureNotIllegalToLegal);
+
+ // Note: If we know the that is source is an illegal vector type (and 2D)
+ // then dim 0 is scalable and dim 1 is fixed.
+ if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
+ return rewriter.notifyMatchFailure(
+ shapeCastOp, "expected source to be a 2D scalable vector with a "
+ "trailing unit dim");
+
+ auto loc = shapeCastOp.getLoc();
+ auto transpose = rewriter.create<vector::TransposeOp>(
+ loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
+
+ // Moving a unit dim does not reorder any elements, so the transpose holds
+ // the same data as the source. Any legal result other than the transposed
+ // type itself (e.g. `[4]`, `1x1x[4]`) needs a reshape; replacing the op
+ // directly with the transpose would produce a value of the wrong type.
+ if (transpose.getResultVectorType() != resultType)
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
+ transpose);
+ else
+ rewriter.replaceOp(shapeCastOp, transpose);
+
+ return success();
+ }
+};
+
struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
@@ -576,7 +632,8 @@ struct VectorLegalizationPass
});
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
- LiftIllegalVectorTransposeToMemory>(context);
+ LiftIllegalVectorTransposeToMemory,
+ ConvertIllegalShapeCastOpsToTransposes>(context);
// Note: High benefit to ensure masked outer products are lowered first.
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index bf0b58ff4cf073..f8be697548c197 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -388,3 +388,48 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
%0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}
+
+// -----
+
+// A [4]x1 -> 1x[4] shape_cast (illegal source, legal result) is expected to be
+// rewritten into a single vector.transpose with permutation [1, 0].
// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+ // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+ %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
+ return %0 : vector<1x[4]xf32>
+}
+
+// -----
+
+// A [4]x1 -> [4] shape_cast is expected to become a vector.transpose to 1x[4]
+// followed by a (legal) 1x[4] -> [4] shape_cast.
// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+ // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
+ %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+
+// -----
+
+// A shape_cast of an illegal [4]x1 transfer_read is expected to end up as a
+// legal 1x[4] read (of a transformed memref), with no shape_cast remaining.
// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
+func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+ // CHECK-NOT: vector.shape_cast
+ %pad = arith.constant 0.0 : f32
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
+ %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
+ return %cast : vector<1x[4]xf32>
+}
+
+// -----
+
+// Same as above but casting to [4]: the read is expected to become a legal
+// 1x[4] read, and the original [4]x1 -> [4] shape_cast must not remain.
// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
+func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+ // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
+ %pad = arith.constant 0.0 : f32
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
+ %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
+ return %cast : vector<[4]xf32>
+}
>From 517e5f0552ac41417f7252a3cf3ce9f75ac5df3f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 7 Mar 2024 16:29:05 +0000
Subject: [PATCH 2/2] Fixups
---
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index f06d68bcf661a3..31500c62c0d600 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -591,7 +591,7 @@ struct ConvertIllegalShapeCastOpsToTransposes
return rewriter.notifyMatchFailure(shapeCastOp,
kMatchFailureNotIllegalToLegal);
- // Note: If we know the that is source is an illegal vector type (and 2D)
+ // Note: If we know that `sourceType` is an illegal vector type (and 2D)
// then dim 0 is scalable and dim 1 is fixed.
if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
return rewriter.notifyMatchFailure(
More information about the Mlir-commits
mailing list