[Mlir-commits] [mlir] b87a80d - [mlir][vector] Add n-d deinterleave lowering (#94237)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 7 02:57:04 PDT 2024


Author: Mubashar Ahmad
Date: 2024-06-07T10:57:00+01:00
New Revision: b87a80d4ebca9e1c065f0d2762e500078c4badca

URL: https://github.com/llvm/llvm-project/commit/b87a80d4ebca9e1c065f0d2762e500078c4badca
DIFF: https://github.com/llvm/llvm-project/commit/b87a80d4ebca9e1c065f0d2762e500078c4badca.diff

LOG: [mlir][vector] Add n-d deinterleave lowering (#94237)

This patch implements the lowering of vector.deinterleave
for n-dimensional vectors. The n-d source vector is
unrolled into a series of one-dimensional vectors, a 1-d
vector.deinterleave is applied to each of them, and the
results are inserted into the n-d result vectors.

From:
```
%0, %1 = vector.deinterleave %arg0 : vector<2x8xi32> -> vector<2x4xi32>
```

To:
```
%cst = arith.constant dense<0> : vector<2x4xi32>
%0 = vector.extract %arg0[0] : vector<8xi32> from vector<2x8xi32>
%res1, %res2 = vector.deinterleave %0 : vector<8xi32> -> vector<4xi32>
%1 = vector.insert %res1, %cst [0] : vector<4xi32> into vector<2x4xi32>
%2 = vector.insert %res2, %cst [0] : vector<4xi32> into vector<2x4xi32>
%3 = vector.extract %arg0[1] : vector<8xi32> from vector<2x8xi32>
%res1_0, %res2_1 = vector.deinterleave %3 : vector<8xi32> -> vector<4xi32>
%4 = vector.insert %res1_0, %1 [1] : vector<4xi32> into vector<2x4xi32>
%5 = vector.insert %res2_1, %2 [1] : vector<4xi32> into vector<2x4xi32>
...etc.
```

Added: 
    mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 77c97b2f1497c..f7e01c7b12e4f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -79,6 +79,73 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
   int64_t targetRank = 1;
 };
 
+/// A one-shot unrolling of vector.deinterleave to the `targetRank`.
+///
+/// Example:
+///
+/// ```mlir
+/// %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
+/// ```
+/// Would be unrolled to:
+/// ```mlir
+/// %result = arith.constant dense<0> : vector<1x2x3x4xi64>
+/// %0 = vector.extract %a[0, 0, 0]                  ─┐
+///        : vector<8xi64> from vector<1x2x3x8xi64>   |
+/// %1, %2 = vector.deinterleave %0                   |
+///        : vector<8xi64> -> vector<4xi64>           | -- Initial deinterleave
+/// %3 = vector.insert %1, %result [0, 0, 0]          |    operation unrolled.
+///        : vector<4xi64> into vector<1x2x3x4xi64>   |
+/// %4 = vector.insert %2, %result [0, 0, 0]          |
+///        : vector<4xi64> into vector<1x2x3x4xi64>   ┘
+/// %5 = vector.extract %a[0, 0, 1]                  ─┐
+///        : vector<8xi64> from vector<1x2x3x8xi64>   |
+/// %6, %7 = vector.deinterleave %5                   |
+///        : vector<8xi64> -> vector<4xi64>           | -- Recursive pattern for
+/// %8 = vector.insert %6, %3 [0, 0, 1]               |    subsequent unrolled
+///        : vector<4xi64> into vector<1x2x3x4xi64>   |    deinterleave
+/// %9 = vector.insert %7, %4 [0, 0, 1]               |    operations. Repeated
+///        : vector<4xi64> into vector<1x2x3x4xi64>   ┘    5x in this case.
+/// ```
+///
+/// Note: If any leading dimension before the `targetRank` is scalable, the
+/// unrolling will stop before the scalable dimension.
+class UnrollDeinterleaveOp final
+    : public OpRewritePattern<vector::DeinterleaveOp> {
+public:
+  UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank) {}
+
+  LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
+      return failure(); // Nothing to iterate over: leave the op untouched.
+
+    auto loc = op.getLoc();
+    Value emptyResult = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    Value evenResult = emptyResult; // Accumulates the `res1` slices.
+    Value oddResult = emptyResult;  // Accumulates the `res2` slices.
+
+    for (auto position : *unrollIterator) {
+      auto extractSrc =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+      auto deinterleave =
+          rewriter.create<vector::DeinterleaveOp>(loc, extractSrc);
+      evenResult = rewriter.create<vector::InsertOp>(
+          loc, deinterleave.getRes1(), evenResult, position);
+      oddResult = rewriter.create<vector::InsertOp>(loc, deinterleave.getRes2(),
+                                                    oddResult, position);
+    }
+    rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
+    return success();
+  }
+
+private:
+  int64_t targetRank = 1;
+};
 /// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
 /// applicable: `sourceType` must be 1D and non-scalable.
 ///
@@ -116,7 +183,8 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
 
 void mlir::vector::populateVectorInterleaveLoweringPatterns(
     RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
-  patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
+  patterns.add<UnrollInterleaveOp, UnrollDeinterleaveOp>(
+      targetRank, patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorInterleaveToShufflePatterns(

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 12121ea0dd70e..54dcf07053906 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2565,6 +2565,22 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
     return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
 }
 
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
+  // CHECK: llvm.shufflevector
+  // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x8xf32>
+  %0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
+  return %0, %1 : vector<2x4xf32>, vector<2x4xf32>
+}
+
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
+    // CHECK: llvm.intr.vector.deinterleave2
+    // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x[8]xf32>
+    %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
+    return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @vector_bitcast_2d

diff  --git a/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
new file mode 100644
index 0000000000000..53f4a8970c794
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
+  // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]]
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0]
+  // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0]
+  // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
+  // CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]]
+  // CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1]
+  // CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1]
+  // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x4xi32>, vector<2x4xi32>
+  %0, %1 = vector.deinterleave %a : vector<2x8xi32> -> vector<2x4xi32>
+  return %0, %1 : vector<2x4xi32>, vector<2x4xi32>
+}
+
+// CHECK-LABEL: @vector_deinterleave_2d_scalable
+// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>)
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
+  // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]]
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0]
+  // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0]
+  // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
+  // CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]]
+  // CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1]
+  // CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1]
+  // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x[4]xi32>, vector<2x[4]xi32>
+  %0, %1 = vector.deinterleave %a : vector<2x[8]xi32> -> vector<2x[4]xi32>
+  return %0, %1 : vector<2x[4]xi32>, vector<2x[4]xi32>
+}
+
+// CHECK-LABEL: @vector_deinterleave_4d
+// CHECK-SAME: %[[SRC:.*]]: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>)
+func.func @vector_deinterleave_4d(%a: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>) {
+    // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0, 0, 0] : vector<8xi64> from vector<1x2x3x8xi64>
+    // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] : vector<8xi64> -> vector<4xi64>
+    // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
+    // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
+    // CHECK-COUNT-5: vector.deinterleave %{{.*}} : vector<8xi64> -> vector<4xi64>
+    %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
+    return %0, %1 : vector<1x2x3x4xi64>, vector<1x2x3x4xi64>
+}
+
+// CHECK-LABEL: @vector_deinterleave_nd_with_scalable_dim
+func.func @vector_deinterleave_nd_with_scalable_dim(
+  %a: vector<1x3x[2]x2x3x8xf16>) -> (vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>) {
+  // The scalable dim blocks unrolling so only the first two dims are unrolled.
+  // CHECK-COUNT-3: vector.deinterleave %{{.*}} : vector<[2]x2x3x8xf16>
+  %0, %1 = vector.deinterleave %a: vector<1x3x[2]x2x3x8xf16> -> vector<1x3x[2]x2x3x4xf16>
+  return %0, %1 : vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.lower_interleave
+    } : !transform.any_op
+    transform.yield
+  }
+}


        


More information about the Mlir-commits mailing list