[Mlir-commits] [mlir] b87a80d - [mlir][vector] Add n-d deinterleave lowering (#94237)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 7 02:57:04 PDT 2024
Author: Mubashar Ahmad
Date: 2024-06-07T10:57:00+01:00
New Revision: b87a80d4ebca9e1c065f0d2762e500078c4badca
URL: https://github.com/llvm/llvm-project/commit/b87a80d4ebca9e1c065f0d2762e500078c4badca
DIFF: https://github.com/llvm/llvm-project/commit/b87a80d4ebca9e1c065f0d2762e500078c4badca.diff
LOG: [mlir][vector] Add n-d deinterleave lowering (#94237)
This patch implements lowering of vector.deinterleave
for n-dimensional vectors. The n-D vector is unrolled
into a series of one-dimensional vectors, and
vector.deinterleave is then applied to each of them.
From:
```
%0, %1 = vector.deinterleave %arg0 : vector<2x8xi32> -> vector<2x4xi32>
```
To:
```
%cst = arith.constant dense<0> : vector<2x4xi32>
%0 = vector.extract %arg0[0] : vector<8xi32> from vector<2x8xi32>
%res1, %res2 = vector.deinterleave %0 : vector<8xi32> -> vector<4xi32>
%1 = vector.insert %res1, %cst [0] : vector<4xi32> into vector<2x4xi32>
%2 = vector.insert %res2, %cst [0] : vector<4xi32> into vector<2x4xi32>
%3 = vector.extract %arg0[1] : vector<8xi32> from vector<2x8xi32>
%res1_0, %res2_1 = vector.deinterleave %3 : vector<8xi32> -> vector<4xi32>
%4 = vector.insert %res1_0, %1 [1] : vector<4xi32> into vector<2x4xi32>
%5 = vector.insert %res2_1, %2 [1] : vector<4xi32> into vector<2x4xi32>
...etc.
```
Added:
mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
Modified:
mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 77c97b2f1497c..f7e01c7b12e4f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -79,6 +79,73 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
int64_t targetRank = 1;
};
+/// A one-shot unrolling of vector.deinterleave to the `targetRank`.
+///
+/// Example:
+///
+/// ```mlir
+/// %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
+/// ```
+/// Would be unrolled to:
+/// ```mlir
+/// %result = arith.constant dense<0> : vector<1x2x3x4xi64>
+/// %0 = vector.extract %a[0, 0, 0]                 ─┐
+///        : vector<8xi64> from vector<1x2x3x8xi64>  |
+/// %1, %2 = vector.deinterleave %0                  |
+///        : vector<8xi64> -> vector<4xi64>          | -- Initial deinterleave
+/// %3 = vector.insert %1, %result [0, 0, 0]         |    operation unrolled.
+///        : vector<4xi64> into vector<1x2x3x4xi64>  |
+/// %4 = vector.insert %2, %result [0, 0, 0]         |
+///        : vector<4xi64> into vector<1x2x3x4xi64> ┘
+/// %5 = vector.extract %a[0, 0, 1]                 ─┐
+///        : vector<8xi64> from vector<1x2x3x8xi64>  |
+/// %6, %7 = vector.deinterleave %5                  |
+///        : vector<8xi64> -> vector<4xi64>          | -- Same sequence repeated
+/// %8 = vector.insert %6, %3 [0, 0, 1]              |    for each remaining
+///        : vector<4xi64> into vector<1x2x3x4xi64>  |    leading-dim position,
+/// %9 = vector.insert %7, %4 [0, 0, 1]              |    5 more times in this
+///        : vector<4xi64> into vector<1x2x3x4xi64> ┘    case.
+/// ```
+///
+/// Each of the two results is threaded through its own chain of vector.insert
+/// ops, both chains starting from one shared zero constant.
+///
+/// Note: If any leading dimension before the `targetRank` is scalable the
+/// unrolling will stop before the scalable dimension.
+class UnrollDeinterleaveOp final
+    : public OpRewritePattern<vector::DeinterleaveOp> {
+public:
+  UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank) {}
+
+  LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    // Positions are computed from the result type. Source and result differ
+    // only in the trailing dim (see the example above), so each position
+    // addresses a matching slice of both.
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    // No iterator: nothing to unroll for this type (presumably the rank is
+    // already <= targetRank), so leave the op untouched.
+    if (!unrollIterator)
+      return failure();
+
+    auto loc = op.getLoc();
+    // Seed both results with zeros; each unrolled slice is inserted over it
+    // in the loop below.
+    Value emptyResult = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    Value evenResult = emptyResult;
+    Value oddResult = emptyResult;
+
+    for (auto position : *unrollIterator) {
+      auto extractSrc =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+      auto deinterleave =
+          rewriter.create<vector::DeinterleaveOp>(loc, extractSrc);
+      evenResult = rewriter.create<vector::InsertOp>(
+          loc, deinterleave.getRes1(), evenResult, position);
+      oddResult = rewriter.create<vector::InsertOp>(loc, deinterleave.getRes2(),
+                                                    oddResult, position);
+    }
+    rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
+    return success();
+  }
+
+private:
+  int64_t targetRank = 1;
+};
/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
/// applicable: `sourceType` must be 1D and non-scalable.
///
@@ -116,7 +183,8 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
+/// Adds the n-D unrolling patterns for vector.interleave and
+/// vector.deinterleave to `patterns`, both using `targetRank` and `benefit`.
void mlir::vector::populateVectorInterleaveLoweringPatterns(
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
- patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
+  // Deinterleave unrolling is registered next to interleave unrolling, so one
+  // populate call handles both ops with the same target rank and benefit.
+ patterns.add<UnrollInterleaveOp, UnrollDeinterleaveOp>(
+ targetRank, patterns.getContext(), benefit);
}
void mlir::vector::populateVectorInterleaveToShufflePatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 12121ea0dd70e..54dcf07053906 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2565,6 +2565,22 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
}
+// Fixed-size 2-D deinterleave: after conversion the output should contain
+// llvm.shufflevector and no 2-D vector.deinterleave should remain (presumably
+// the op is first unrolled to 1-D deinterleaves, which lower to shuffles).
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
+ // CHECK: llvm.shufflevector
+ // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x8xf32>
+ %0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
+ return %0, %1 : vector<2x4xf32>, vector<2x4xf32>
+}
+
+// 2-D deinterleave with a scalable trailing dim: the output should contain
+// the llvm.intr.vector.deinterleave2 intrinsic and no 2-D vector.deinterleave.
+// CHECK-LABEL: @vector_deinterleave_2d_scalable
+// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>)
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
+  // CHECK: llvm.intr.vector.deinterleave2
+  // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x[8]xf32>
+  %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
+  return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
+}
+
// -----
// CHECK-LABEL: func.func @vector_bitcast_2d
diff --git a/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
new file mode 100644
index 0000000000000..53f4a8970c794
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// Fixed-size 2-D: each row is extracted, deinterleaved as a 1-D vector, and
+// its two halves are inserted at the same row index of the two results, which
+// both start from one zero constant.
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0>
+ // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
+ // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]]
+ // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0]
+ // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0]
+ // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
+ // CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]]
+ // CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1]
+ // CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1]
+ // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x4xi32>, vector<2x4xi32>
+ %0, %1 = vector.deinterleave %a : vector<2x8xi32> -> vector<2x4xi32>
+ return %0, %1 : vector<2x4xi32>, vector<2x4xi32>
+}
+
+// Scalable trailing dim only: the fixed leading dim is still unrolled, giving
+// one 1-D scalable deinterleave per row.
+// CHECK-LABEL: @vector_deinterleave_2d_scalable
+// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>)
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0>
+ // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
+ // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]]
+ // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0]
+ // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0]
+ // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
+ // CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]]
+ // CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1]
+ // CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1]
+ // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x[4]xi32>, vector<2x[4]xi32>
+ %0, %1 = vector.deinterleave %a : vector<2x[8]xi32> -> vector<2x[4]xi32>
+ return %0, %1 : vector<2x[4]xi32>, vector<2x[4]xi32>
+}
+
+// 4-D: all three leading dims are unrolled (1 * 2 * 3 = 6 slices). The first
+// slice is checked in detail; the remaining five deinterleaves are counted.
+// CHECK-LABEL: @vector_deinterleave_4d
+// CHECK-SAME: %[[SRC:.*]]: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>)
+func.func @vector_deinterleave_4d(%a: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>) {
+ // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0, 0, 0] : vector<8xi64> from vector<1x2x3x8xi64>
+ // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] : vector<8xi64> -> vector<4xi64>
+ // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
+ // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
+ // CHECK-COUNT-5: vector.deinterleave %{{.*}} : vector<8xi64> -> vector<4xi64>
+ %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
+ return %0, %1 : vector<1x2x3x4xi64>, vector<1x2x3x4xi64>
+}
+
+// CHECK-LABEL: @vector_deinterleave_nd_with_scalable_dim
+func.func @vector_deinterleave_nd_with_scalable_dim(
+ %a: vector<1x3x[2]x2x3x8xf16>) -> (vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>) {
+ // The scalable [2] dim blocks unrolling, so only the two fixed leading dims
+ // (1x3) are unrolled: 3 deinterleaves on vector<[2]x2x3x8xf16> slices.
+ // CHECK-COUNT-3: vector.deinterleave %{{.*}} : vector<[2]x2x3x8xf16>
+ %0, %1 = vector.deinterleave %a: vector<1x3x[2]x2x3x8xf16> -> vector<1x3x[2]x2x3x4xf16>
+ return %0, %1 : vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>
+}
+
+// Test driver: match every func.func in the module and apply the vector
+// `lower_interleave` pattern set to it (presumably this is what registers the
+// deinterleave unrolling exercised above; verify against the transform op).
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.lower_interleave
+ } : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list