[Mlir-commits] [mlir] [mlir][vector] Add n-d deinterleave lowering (PR #94237)
Mubashar Ahmad
llvmlistbot at llvm.org
Tue Jun 4 04:23:55 PDT 2024
https://github.com/mub-at-arm updated https://github.com/llvm/llvm-project/pull/94237
>From b95894d9c3365d863d7a1c70588fcbec5d84b6e5 Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Mon, 3 Jun 2024 15:43:35 +0000
Subject: [PATCH] [mlir][vector] Add n-d deinterleave lowering
This patch implements the lowering of vector.deinterleave
for n-dimensional vectors. The n-d vector is unrolled
into a series of one-dimensional vectors, and the
deinterleave operation is then applied to each of
those one-dimensional vectors.
From:
```
%0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
```
To:
```
%2 = llvm.extractvalue %0[0] : !llvm.array<2 x vector<8xf32>>
%3 = llvm.mlir.poison : vector<8xf32>
%4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector<8xf32>
%5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector<8xf32>
%6 = llvm.insertvalue %4, %1[0] : !llvm.array<2 x vector<4xf32>>
%7 = llvm.insertvalue %5, %1[0] : !llvm.array<2 x vector<4xf32>>
%8 = llvm.extractvalue %0[1] : !llvm.array<2 x vector<8xf32>>
...etc.
```
---
.../Transforms/LowerVectorInterleave.cpp | 39 +++++++++++++++++++
.../VectorToLLVM/vector-to-llvm.mlir | 36 +++++++++++++++++
2 files changed, 75 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 77c97b2f1497c..ff1f5b89e358f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -79,6 +79,43 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
int64_t targetRank = 1;
};
+/// Unrolls an n-D `vector.deinterleave` into `vector.deinterleave` ops on
+/// slices selected by `targetRank`.
+///
+/// For every position produced by `vector::createUnrollIterator`, the
+/// matching slice is extracted from the source, deinterleaved, and its two
+/// results are inserted at that same position into the accumulated "even"
+/// (`getRes1`) and "odd" (`getRes2`) result vectors, which then replace the
+/// original op.
+///
+/// The pattern fails to match when no unroll iterator can be created for the
+/// result type (presumably when the op is already at or below `targetRank`;
+/// confirm against `createUnrollIterator`).
+class UnrollDeinterleaveOp final
+    : public OpRewritePattern<vector::DeinterleaveOp> {
+public:
+  UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank) {}
+
+  LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
+      return failure();
+
+    auto loc = op.getLoc();
+    // Both results are seeded from the same zero-filled vector. Every
+    // vector.insert yields a fresh value, so sharing the seed is safe and
+    // avoids materializing an identical constant twice.
+    Value emptyResult = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    Value evenResult = emptyResult;
+    Value oddResult = emptyResult;
+
+    for (auto position : *unrollIterator) {
+      auto extractSrc =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+      auto deinterleave =
+          rewriter.create<vector::DeinterleaveOp>(loc, extractSrc);
+      evenResult = rewriter.create<vector::InsertOp>(
+          loc, deinterleave.getRes1(), evenResult, position);
+      oddResult = rewriter.create<vector::InsertOp>(loc, deinterleave.getRes2(),
+                                                    oddResult, position);
+    }
+    rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
+    return success();
+  }
+
+private:
+  /// Rank passed to `vector::createUnrollIterator` for the result type.
+  int64_t targetRank = 1;
+};
/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
/// applicable: `sourceType` must be 1D and non-scalable.
///
@@ -117,6 +154,8 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
void mlir::vector::populateVectorInterleaveLoweringPatterns(
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
+ patterns.add<UnrollDeinterleaveOp>(targetRank, patterns.getContext(),
+ benefit); // Deinterleave unrolling uses the same target rank and benefit as interleave.
}
void mlir::vector::populateVectorInterleaveToShufflePatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 12121ea0dd70e..d0b7516f2a437 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2565,6 +2565,42 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
}
+// Fixed-size 2-D case: each row is expected to be extracted, split by two
+// llvm.shufflevector ops (lanes [0, 2, 4, 6] and [1, 3, 5, 7]), and inserted
+// into the two result arrays at the same row index.
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
+ // CHECK: %[[EXTRACT_ONE:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<8xf32>>
+ // CHECK: %[[POISON_ONE:.*]] = llvm.mlir.poison : vector<8xf32>
+ // CHECK: %[[SHUFFLE_A:.*]] = llvm.shufflevector %[[EXTRACT_ONE]], %[[POISON_ONE]] [0, 2, 4, 6] : vector<8xf32>
+ // CHECK: %[[SHUFFLE_B:.*]] = llvm.shufflevector %[[EXTRACT_ONE]], %[[POISON_ONE]] [1, 3, 5, 7] : vector<8xf32>
+ // CHECK: %[[INSERT_A:.*]] = llvm.insertvalue %[[SHUFFLE_A]], %{{.*}}[0] : !llvm.array<2 x vector<4xf32>>
+ // CHECK: %[[INSERT_B:.*]] = llvm.insertvalue %[[SHUFFLE_B]], %{{.*}}[0] : !llvm.array<2 x vector<4xf32>>
+ // CHECK: %[[EXTRACT_TWO:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<8xf32>>
+ // CHECK: %[[POISON_TWO:.*]] = llvm.mlir.poison : vector<8xf32>
+ // CHECK: %[[SHUFFLE_C:.*]] = llvm.shufflevector %[[EXTRACT_TWO]], %[[POISON_TWO]] [0, 2, 4, 6] : vector<8xf32>
+ // CHECK: %[[SHUFFLE_D:.*]] = llvm.shufflevector %[[EXTRACT_TWO]], %[[POISON_TWO]] [1, 3, 5, 7] : vector<8xf32>
+ // CHECK: %[[INSERT_C:.*]] = llvm.insertvalue %[[SHUFFLE_C]], %[[INSERT_A]][1] : !llvm.array<2 x vector<4xf32>>
+ // CHECK: %[[INSERT_D:.*]] = llvm.insertvalue %[[SHUFFLE_D]], %[[INSERT_B]][1] : !llvm.array<2 x vector<4xf32>>
+ %0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
+ return %0, %1 : vector<2x4xf32>, vector<2x4xf32>
+}
+
+// Scalable 2-D case: each row is expected to lower to the
+// llvm.intr.vector.deinterleave2 intrinsic, whose two struct fields are
+// inserted into the two result arrays at the same row index.
+// CHECK-LABEL: @vector_deinterleave_2d_scalable
+// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>)
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
+  // CHECK: %[[EXTRACT_A:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[8]xf32>>
+  // CHECK: %[[VECTOR_ONE:.*]] = "llvm.intr.vector.deinterleave2"(%[[EXTRACT_A]]) : (vector<[8]xf32>) -> !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+  // CHECK: %[[EXTRACT_B:.*]] = llvm.extractvalue %[[VECTOR_ONE]][0] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+  // CHECK: %[[EXTRACT_C:.*]] = llvm.extractvalue %[[VECTOR_ONE]][1] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+  // CHECK: %[[INSERT_A:.*]] = llvm.insertvalue %[[EXTRACT_B]], %{{.*}}[0] : !llvm.array<2 x vector<[4]xf32>>
+  // CHECK: %[[INSERT_B:.*]] = llvm.insertvalue %[[EXTRACT_C]], %{{.*}}[0] : !llvm.array<2 x vector<[4]xf32>>
+  // CHECK: %[[EXTRACT_D:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[8]xf32>>
+  // CHECK: %[[VECTOR_TWO:.*]] = "llvm.intr.vector.deinterleave2"(%[[EXTRACT_D]]) : (vector<[8]xf32>) -> !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+  // CHECK: %[[EXTRACT_E:.*]] = llvm.extractvalue %[[VECTOR_TWO]][0] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+  // CHECK: %[[EXTRACT_F:.*]] = llvm.extractvalue %[[VECTOR_TWO]][1] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+  // CHECK: %[[INSERT_C:.*]] = llvm.insertvalue %[[EXTRACT_E]], %[[INSERT_A]][1] : !llvm.array<2 x vector<[4]xf32>>
+  // CHECK: %[[INSERT_D:.*]] = llvm.insertvalue %[[EXTRACT_F]], %[[INSERT_B]][1] : !llvm.array<2 x vector<[4]xf32>>
+  %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
+  return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
+}
+
// -----
// CHECK-LABEL: func.func @vector_bitcast_2d
More information about the Mlir-commits
mailing list