[Mlir-commits] [mlir] [mlir][VectorOps] Add fold vector.shuffle -> vector.interleave (4/4) (PR #80968)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Mar 6 05:05:18 PST 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/80968
>From d45d47e4a8b0ef29a6b08a729d4cd4622757b4ab Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 7 Feb 2024 10:04:53 +0000
Subject: [PATCH 1/2] [mlir][VectorOps] Add fold vector.shuffle ->
vector.interleave
This adds a canonicalization that rewrites fixed-size vector.shuffle ops
performing a 1-D interleave into a vector.interleave operation.
For example:
```mlir
%0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xi32>, vector<2xi32>
```
to:
```mlir
%0 = vector.interleave %a, %b : vector<2xi32>
```
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 43 +++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 23 ++++++++++++
2 files changed, 65 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5be6a628904cdf..59ca634ce3e3d6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2479,11 +2479,52 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
}
};
+/// Rewrites a fixed-size 1-D vector.shuffle whose mask interleaves its
+/// operands, i.e. [0, N, 1, N+1, ..., N-1, 2N-1], into a vector.interleave.
+class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+ if (resultType.isScalable())
+ return rewriter.notifyMatchFailure(
+ op, "ShuffleOp can't represent a scalable interleave");
+
+ if (resultType.getRank() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "ShuffleOp can't represent an n-D interleave");
+
+ VectorType sourceType = op.getV1VectorType();
+ if (sourceType != op.getV2VectorType() ||
+ ArrayRef<int64_t>{sourceType.getNumElements() * 2} !=
+ resultType.getShape()) {
+ return rewriter.notifyMatchFailure(
+ op, "ShuffleOp types don't match an interleave");
+ }
+ // Expect mask [0, N, 1, N+1, ...] where N = elements per operand.
+ ArrayAttr shuffleMask = op.getMask();
+ int64_t resultVectorSize = resultType.getNumElements();
+ for (int64_t i = 0, e = resultVectorSize / 2; i < e; ++i) {
+ int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
+ int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
+ if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
+ return rewriter.notifyMatchFailure(op,
+ "ShuffleOp mask not interleaving");
+ }
+
+ rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
+ return success();
+ }
+};
+
} // namespace
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
+ results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
+ context); // ShuffleInterleave: interleaving shuffle -> vector.interleave.
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e6f045e12e5197..4c73a6271786e6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2567,3 +2567,26 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
return %r : vector<1x100x4x5xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
+// CHECK-SAME: %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
+func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>)
+    -> vector<2xf64> {
+ // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
+ // CHECK: return %[[ZIP]]
+ %0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
+ return %0 : vector<2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
+// CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
+func.func @rank_1_shuffle_to_interleave(%lhs: vector<6xi32>, %rhs: vector<6xi32>) -> vector<12xi32> {
+ // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
+ // CHECK: return %[[ZIP]]
+ %zip = vector.shuffle %lhs, %rhs [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
+ return %zip : vector<12xi32>
+}
>From 657461153ca4cb898c1d1110e1a6e979fa2279b6 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 6 Mar 2024 13:04:02 +0000
Subject: [PATCH 2/2] Fixups
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 59ca634ce3e3d6..75f6220ad8f3fa 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2498,8 +2498,7 @@ class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
VectorType sourceType = op.getV1VectorType();
if (sourceType != op.getV2VectorType() ||
- ArrayRef<int64_t>{sourceType.getNumElements() * 2} !=
- resultType.getShape()) {
+ sourceType.getNumElements() * 2 != resultType.getNumElements()) {
return rewriter.notifyMatchFailure(
op, "ShuffleOp types don't match an interleave");
}
More information about the Mlir-commits
mailing list