[Mlir-commits] [mlir] [mlir][ArmSME] Pattern to swap shape_cast(transpose) with transpose(shape_cast) (PR #100731)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Jul 26 04:56:50 PDT 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/100731
This applies when the shape_cast is simply for dropping unit dims, and the result rank is >= 2.
This simplifies the transpose making it possible for other ArmSME legalization patterns to handle it.
Example:
BEFORE:
```mlir
%0 = vector.transpose %vector, [3, 0, 1, 2]
: vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
```
AFTER:
```mlir
%0 = vector.shape_cast %vector : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
```
>From f91360b27c1d9d0016f40cb9332d38dcb0919af4 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 26 Jul 2024 10:25:46 +0000
Subject: [PATCH] [mlir][ArmSME] Pattern to swap shape_cast(transpose) with
transpose(shape_cast)
This applies when the shape_cast is simply for dropping unit dims, and
the result rank is >= 2.
This simplifies the transpose making it possible for other ArmSME
legalization patterns to handle it.
Example:
BEFORE:
```mlir
%0 = vector.transpose %vector, [3, 0, 1, 2]
: vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
```
AFTER:
```mlir
%0 = vector.shape_cast %vector : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
```
---
.../ArmSME/Transforms/VectorLegalization.cpp | 91 ++++++++++++++++++-
.../Dialect/ArmSME/vector-legalization.mlir | 26 ++++++
2 files changed, 116 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 004428bd343e2..7efe320bd5371 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -774,6 +774,94 @@ struct ConvertIllegalShapeCastOpsToTransposes
}
};
+/// Returns a range of (size, isScalable) pairs, one per dim of `vType`.
+static auto getDims(VectorType vType) {
+ return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
+/// Returns `vType` without its fixed-size unit dims; scalable `[1]` dims stay.
+static VectorType dropUnitDims(VectorType vType) {
+ SmallVector<bool> scalableFlags;
+ SmallVector<int64_t> dimSizes;
+ for (auto dim : getDims(vType)) {
+ if (dim == std::make_tuple(1, false))
+ continue; // Skip fixed-size unit dims (size 1, not scalable).
+ auto [size, scalableFlag] = dim;
+ dimSizes.push_back(size);
+ scalableFlags.push_back(scalableFlag);
+ }
+ return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
+}
+
+/// A pattern to swap shape_cast(transpose) with transpose(shape_cast) if the
+/// shape_cast only drops fixed-size unit dims and its result has rank >= 2.
+///
+/// This simplifies the transpose making it possible for other legalization
+/// rewrites to handle it.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %0 = vector.transpose %vector, [3, 0, 1, 2]
+/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %0 = vector.shape_cast %vector : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// ```
+struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto transposeOp =
+ shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp)
+ return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
+
+ auto resultType = shapeCastOp.getResultVectorType();
+ if (resultType.getRank() <= 1)
+ return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
+
+ if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
+ return rewriter.notifyMatchFailure(
+ shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
+
+ auto transposeSourceVectorType = transposeOp.getSourceVectorType();
+ auto transposeSourceDims =
+ llvm::to_vector(getDims(transposeSourceVectorType));
+
+ // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+ SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
+ int64_t droppedDims = 0;
+ for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
+ droppedDimsBefore[i] = droppedDims;
+ if (dim == std::make_tuple(1, false))
+ ++droppedDims;
+ }
+
+ // Drop unit dims from the permutation and renumber the remaining indices.
+ auto perm = transposeOp.getPermutation();
+ SmallVector<int64_t> newPerm;
+ for (int64_t idx : perm) {
+ if (transposeSourceDims[idx] == std::make_tuple(1, false))
+ continue;
+ newPerm.push_back(idx - droppedDimsBefore[idx]);
+ }
+
+ auto loc = shapeCastOp.getLoc();
+ auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
+ loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
+ rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
+ newShapeCastOp, newPerm);
+ return success();
+ }
+};
+
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This workaround rewrite to support these transposes when ZA is
/// available.
@@ -939,7 +1027,8 @@ struct VectorLegalizationPass
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
ConvertIllegalShapeCastOpsToTransposes,
- LowerIllegalTransposeStoreViaZA>(context);
+ SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
+ context);
// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 458906a187982..adc02adb6e974 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -646,3 +646,29 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @swap_shape_cast_of_transpose(
+// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ // CHECK: return %[[TRANSPOSE]]
+ %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+ %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+ return %1 : vector<[4]x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
+// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
+func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ // CHECK: return %[[TRANSPOSE]]
+ %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
+ %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
+ return %1 : vector<[4]x4xf32>
+}
More information about the Mlir-commits
mailing list