[Mlir-commits] [mlir] [mlir][vector] Add n-d deinterleave lowering (PR #94237)
Mubashar Ahmad
llvmlistbot at llvm.org
Wed Jun 5 11:49:30 PDT 2024
https://github.com/mub-at-arm updated https://github.com/llvm/llvm-project/pull/94237
>From 471d9a803568d74b67abeba482a77e2b16744182 Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Mon, 3 Jun 2024 15:43:35 +0000
Subject: [PATCH] [mlir][vector] Add n-d deinterleave lowering
This patch implements the lowering of vector.deinterleave
for vectors of n dimensions. The process involves
unrolling the n-d vector into a series of
one-dimensional vectors. The deinterleave
operation is then applied to each of these vectors.
From:
```
%0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
```
To:
```
%2 = llvm.extractvalue %0[0] : !llvm.array<2 x vector<8xf32>>
%3 = llvm.mlir.poison : vector<8xf32>
%4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector<8xf32>
%5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector<8xf32>
%6 = llvm.insertvalue %4, %1[0] : !llvm.array<2 x vector<4xf32>>
%7 = llvm.insertvalue %5, %1[0] : !llvm.array<2 x vector<4xf32>>
%8 = llvm.extractvalue %0[1] : !llvm.array<2 x vector<8xf32>>
...etc.
```
---
.../Transforms/LowerVectorInterleave.cpp | 62 ++++++++++++++++-
.../VectorToLLVM/vector-to-llvm.mlir | 16 +++++
...ctor-deinterleave-lowering-transforms.mlir | 68 +++++++++++++++++++
3 files changed, 145 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 77c97b2f1497c..591f00df7ad18 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -79,6 +79,65 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
int64_t targetRank = 1;
};
+/// A one-shot unrolling of vector.deinterleave to the `targetRank`.
+///
+/// Example:
+///
+/// ```mlir
+/// %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
+/// ```
+/// Would be unrolled to:
+/// ```mlir
+/// %result = arith.constant dense<0> : vector<1x2x3x4xi64>
+/// %0 = vector.extract %a[0, 0, 0]                 ─┐
+///        : vector<8xi64> from vector<1x2x3x8xi64>  |
+/// %1, %2 = vector.deinterleave %0                  |
+///        : vector<8xi64> -> vector<4xi64>          | - Repeated 6x, once for
+/// %3 = vector.insert %1, %result [0, 0, 0]         |   each leading position.
+///        : vector<4xi64> into vector<1x2x3x4xi64>  |   Later iterations insert
+/// %4 = vector.insert %2, %result [0, 0, 0]         |   into the previous even
+///        : vector<4xi64> into vector<1x2x3x4xi64> ─┘   and odd results.
+/// ```
+///
+/// The even (first) and odd (second) results both start from the same
+/// zero-initialized vector and are built up with one vector.insert per
+/// unrolled position. The pattern fails to match when
+/// `vector::createUnrollIterator` returns no iterator for the result type.
+///
+/// Note: If any leading dimension before the `targetRank` is scalable the
+/// unrolling will stop before the scalable dimension.
+class UnrollDeinterleaveOp final
+    : public OpRewritePattern<vector::DeinterleaveOp> {
+public:
+  UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank) {}
+
+  LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
+      return failure();
+
+    auto loc = op.getLoc();
+    // A single zero constant seeds both results: each vector.insert below
+    // yields a new value (it is reassigned to evenResult/oddResult), so the
+    // even and odd chains can safely start from the same SSA value.
+    Value emptyResult = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    Value evenResult = emptyResult;
+    Value oddResult = emptyResult;
+
+    for (auto position : *unrollIterator) {
+      // Extract the lower-rank slice at `position`, deinterleave it, and
+      // insert the two halves at the same position of each result.
+      auto extractSrc =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+      auto deinterleave =
+          rewriter.create<vector::DeinterleaveOp>(loc, extractSrc);
+      evenResult = rewriter.create<vector::InsertOp>(
+          loc, deinterleave.getRes1(), evenResult, position);
+      oddResult = rewriter.create<vector::InsertOp>(loc, deinterleave.getRes2(),
+                                                    oddResult, position);
+    }
+    rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
+    return success();
+  }
+
+private:
+  /// Rank the deinterleave is unrolled down to.
+  int64_t targetRank = 1;
+};
/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
/// applicable: `sourceType` must be 1D and non-scalable.
///
@@ -116,7 +175,8 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
void mlir::vector::populateVectorInterleaveLoweringPatterns(
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
- patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
+  // Register the interleave and deinterleave unrolling patterns with the same
+  // target rank and benefit.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<UnrollInterleaveOp>(targetRank, context, benefit);
+  patterns.add<UnrollDeinterleaveOp>(targetRank, context, benefit);
}
void mlir::vector::populateVectorInterleaveToShufflePatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 12121ea0dd70e..54dcf07053906 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2565,6 +2565,22 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
}
+// -----
+
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
+  // Unrolled into two 1-D deinterleaves, each lowered to a pair of shuffles.
+  // CHECK-NOT: = vector.deinterleave
+  // CHECK-COUNT-4: llvm.shufflevector
+  // CHECK-NOT: = vector.deinterleave
+  %0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
+  return %0, %1 : vector<2x4xf32>, vector<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_deinterleave_2d_scalable
+// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>)
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
+  // Unrolled into two 1-D scalable deinterleaves, each lowered to the
+  // deinterleave2 intrinsic. ("= vector.deinterleave" does not match the
+  // intrinsic's name, so the NOT checks only catch unlowered ops.)
+  // CHECK-NOT: = vector.deinterleave
+  // CHECK-COUNT-2: llvm.intr.vector.deinterleave2
+  // CHECK-NOT: = vector.deinterleave
+  %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
+  return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
+}
+
// -----
// CHECK-LABEL: func.func @vector_bitcast_2d
diff --git a/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
new file mode 100644
index 0000000000000..b3335863043b0
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>) {
+  // Both results are seeded from the same zero constant.
+  // CHECK: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
+  // CHECK: %[[EVEN_0:.*]], %[[ODD_0:.*]] = vector.deinterleave %[[SRC_0]]
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[EVEN_0]], %[[CST]] [0]
+  // CHECK: %[[RES_1:.*]] = vector.insert %[[ODD_0]], %[[CST]] [0]
+  // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
+  // CHECK: %[[EVEN_1:.*]], %[[ODD_1:.*]] = vector.deinterleave %[[SRC_1]]
+  // CHECK: %[[RES_2:.*]] = vector.insert %[[EVEN_1]], %[[RES_0]] [1]
+  // CHECK: %[[RES_3:.*]] = vector.insert %[[ODD_1]], %[[RES_1]] [1]
+  // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x4xi32>, vector<2x4xi32>
+  %0, %1 = vector.deinterleave %a : vector<2x8xi32> -> vector<2x4xi32>
+  return %0, %1 : vector<2x4xi32>, vector<2x4xi32>
+}
+
+// CHECK-LABEL: @vector_deinterleave_2d_scalable
+// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>)
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>) {
+  // Only the fixed leading dim is unrolled; the trailing [8] stays scalable.
+  // CHECK: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
+  // CHECK: %[[EVEN_0:.*]], %[[ODD_0:.*]] = vector.deinterleave %[[SRC_0]]
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[EVEN_0]], %[[CST]] [0]
+  // CHECK: %[[RES_1:.*]] = vector.insert %[[ODD_0]], %[[CST]] [0]
+  // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
+  // CHECK: %[[EVEN_1:.*]], %[[ODD_1:.*]] = vector.deinterleave %[[SRC_1]]
+  // CHECK: %[[RES_2:.*]] = vector.insert %[[EVEN_1]], %[[RES_0]] [1]
+  // CHECK: %[[RES_3:.*]] = vector.insert %[[ODD_1]], %[[RES_1]] [1]
+  // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x[4]xi32>, vector<2x[4]xi32>
+  %0, %1 = vector.deinterleave %a : vector<2x[8]xi32> -> vector<2x[4]xi32>
+  return %0, %1 : vector<2x[4]xi32>, vector<2x[4]xi32>
+}
+
+// CHECK-LABEL: @vector_deinterleave_4d
+// CHECK-SAME: %[[SRC:.*]]: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>)
+func.func @vector_deinterleave_4d(%a: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>) {
+  // The first of the 1x2x3 = 6 unrolled positions is checked in full.
+  // CHECK: %[[SLICE:.*]] = vector.extract %[[SRC]][0, 0, 0] : vector<8xi64> from vector<1x2x3x8xi64>
+  // CHECK: %[[EVEN:.*]], %[[ODD:.*]] = vector.deinterleave %[[SLICE]] : vector<8xi64> -> vector<4xi64>
+  // CHECK: %[[ACC_EVEN:.*]] = vector.insert %[[EVEN]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
+  // CHECK: %[[ACC_ODD:.*]] = vector.insert %[[ODD]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
+  // The remaining five positions each produce one 1-D deinterleave.
+  // CHECK-COUNT-5: vector.deinterleave %{{.*}} : vector<8xi64> -> vector<4xi64>
+  %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
+  return %0, %1 : vector<1x2x3x4xi64>, vector<1x2x3x4xi64>
+}
+
+// CHECK-LABEL: @vector_deinterleave_nd_with_scalable_dim
+func.func @vector_deinterleave_nd_with_scalable_dim(
+    %a : vector<1x3x[2]x2x3x8xf16>)
+    -> (vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>) {
+  // Unrolling cannot go past the scalable [2] dim, so only the leading 1x3
+  // dims are unrolled, leaving three deinterleaves on the remaining slices.
+  // CHECK-COUNT-3: vector.deinterleave %{{.*}} : vector<[2]x2x3x8xf16>
+  %0, %1 = vector.deinterleave %a : vector<1x3x[2]x2x3x8xf16> -> vector<1x3x[2]x2x3x4xf16>
+  return %0, %1 : vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>
+}
+
+module attributes {transform.with_named_sequence} {
+  // Apply the `lower_interleave` vector patterns to every func.func in the
+  // payload module.
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.vector.lower_interleave
+    } : !transform.any_op
+    transform.yield
+  }
+}
More information about the Mlir-commits
mailing list