[Mlir-commits] [mlir] [mlir][vector] Add n-d deinterleave lowering (PR #94237)

Mubashar Ahmad llvmlistbot at llvm.org
Wed Jun 5 11:54:03 PDT 2024


https://github.com/mub-at-arm updated https://github.com/llvm/llvm-project/pull/94237

>From b0649b595ed603117682a9f43bed5a552eb6bf76 Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Mon, 3 Jun 2024 15:43:35 +0000
Subject: [PATCH] [mlir][vector] Add n-d deinterleave lowering

This patch implements the lowering of vector.deinterleave
for n-dimensional vectors. The process involves
unrolling the n-d vector into a series of
one-dimensional vectors; the deinterleave
operation is then applied to each of these vectors.

From:
```
%0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
```

To:
```
%2 = llvm.extractvalue %0[0] : !llvm.array<2 x vector<8xf32>>
%3 = llvm.mlir.poison : vector<8xf32>
%4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector<8xf32>
%5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector<8xf32>
%6 = llvm.insertvalue %4, %1[0] : !llvm.array<2 x vector<4xf32>>
%7 = llvm.insertvalue %5, %1[0] : !llvm.array<2 x vector<4xf32>>
%8 = llvm.extractvalue %0[1] : !llvm.array<2 x vector<8xf32>>
...etc.
```
---
 .../Transforms/LowerVectorInterleave.cpp      | 62 ++++++++++++++++-
 .../VectorToLLVM/vector-to-llvm.mlir          | 16 +++++
 ...ctor-deinterleave-lowering-transforms.mlir | 68 +++++++++++++++++++
 3 files changed, 145 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 77c97b2f1497c..4de7d084534b2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -79,6 +79,65 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
   int64_t targetRank = 1;
 };
 
+/// A one-shot unrolling of vector.deinterleave to the `targetRank`.
+///
+/// Example:
+///
+/// ```mlir
+/// vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
+/// ```
+/// Would be unrolled to:
+/// ```mlir
+/// %zero = arith.constant dense<0> : vector<1x2x3x4xi64>
+/// %0 = vector.extract %a[0, 0, 0]                 ─┐
+///        : vector<8xi64> from vector<1x2x3x8xi64>  |  - Repeated 6x for
+/// %1, %2 = vector.deinterleave %0                  |    all leading positions
+///        : vector<8xi64> -> vector<4xi64>          |
+/// %3 = vector.insert %1, %zero [0, 0, 0]           |
+///        : vector<4xi64> into vector<1x2x3x4xi64>  |
+/// %4 = vector.insert %2, %zero [0, 0, 0]           |
+///        : vector<4xi64> into vector<1x2x3x4xi64>  ┘
+/// ```
+///
+/// Note: If any leading dimension before the `targetRank` is scalable the
+/// unrolling will stop before the scalable dimension.
+class UnrollDeinterleaveOp final
+    : public OpRewritePattern<vector::DeinterleaveOp> {
+public:
+  UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank) {}
+
+  LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
+      return failure();
+
+    auto loc = op.getLoc();
+    Value evenResult = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    Value oddResult = evenResult; // Both chains start from one zero vector.
+
+    for (auto position : *unrollIterator) {
+      auto extractSrc =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+      auto deinterleave =
+          rewriter.create<vector::DeinterleaveOp>(loc, extractSrc);
+      evenResult = rewriter.create<vector::InsertOp>(
+          loc, deinterleave.getRes1(), evenResult, position);
+      oddResult = rewriter.create<vector::InsertOp>(loc, deinterleave.getRes2(),
+                                                    oddResult, position);
+    }
+    rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
+    return success();
+  }
+
+private:
+  int64_t targetRank = 1;
+};
+
 /// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
 /// applicable: `sourceType` must be 1D and non-scalable.
 ///
@@ -116,7 +175,8 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
 
 void mlir::vector::populateVectorInterleaveLoweringPatterns(
     RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
-  patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
+  patterns.add<UnrollInterleaveOp, UnrollDeinterleaveOp>(
+      targetRank, /*context=*/patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorInterleaveToShufflePatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 12121ea0dd70e..54dcf07053906 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2565,6 +2565,22 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
     return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
 }
 
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
+  // CHECK: llvm.shufflevector
+  // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x8xf32>
+  %even, %odd = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
+  return %even, %odd : vector<2x4xf32>, vector<2x4xf32>
+}
+
+// CHECK-LABEL: @vector_deinterleave_2d_scalable
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
+  // CHECK: llvm.intr.vector.deinterleave2
+  // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x[8]xf32>
+  %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
+  return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @vector_bitcast_2d
diff --git a/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
new file mode 100644
index 0000000000000..b3335863043b0
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
+  // CHECK: %[[EVEN_0:.*]], %[[ODD_0:.*]] = vector.deinterleave %[[SRC_0]]
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[EVEN_0]], %[[CST]] [0]
+  // CHECK: %[[RES_1:.*]] = vector.insert %[[ODD_0]], %[[CST]] [0]
+  // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
+  // CHECK: %[[EVEN_1:.*]], %[[ODD_1:.*]] = vector.deinterleave %[[SRC_1]]
+  // CHECK: %[[RES_2:.*]] = vector.insert %[[EVEN_1]], %[[RES_0]] [1]
+  // CHECK: %[[RES_3:.*]] = vector.insert %[[ODD_1]], %[[RES_1]] [1]
+  // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x4xi32>, vector<2x4xi32>
+  %0, %1 = vector.deinterleave %a : vector<2x8xi32> -> vector<2x4xi32>
+  return %0, %1 : vector<2x4xi32>, vector<2x4xi32>
+}
+
+// CHECK-LABEL: @vector_deinterleave_2d_scalable
+// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>)
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
+  // CHECK: %[[EVEN_0:.*]], %[[ODD_0:.*]] = vector.deinterleave %[[SRC_0]]
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[EVEN_0]], %[[CST]] [0]
+  // CHECK: %[[RES_1:.*]] = vector.insert %[[ODD_0]], %[[CST]] [0]
+  // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
+  // CHECK: %[[EVEN_1:.*]], %[[ODD_1:.*]] = vector.deinterleave %[[SRC_1]]
+  // CHECK: %[[RES_2:.*]] = vector.insert %[[EVEN_1]], %[[RES_0]] [1]
+  // CHECK: %[[RES_3:.*]] = vector.insert %[[ODD_1]], %[[RES_1]] [1]
+  // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x[4]xi32>, vector<2x[4]xi32>
+  %0, %1 = vector.deinterleave %a : vector<2x[8]xi32> -> vector<2x[4]xi32>
+  return %0, %1 : vector<2x[4]xi32>, vector<2x[4]xi32>
+}
+
+// CHECK-LABEL: @vector_deinterleave_4d
+// CHECK-SAME: %[[SRC:.*]]: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>)
+func.func @vector_deinterleave_4d(%a: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>) {
+  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0, 0, 0] : vector<8xi64> from vector<1x2x3x8xi64>
+  // CHECK: %[[EVEN_0:.*]], %[[ODD_0:.*]] = vector.deinterleave %[[SRC_0]] : vector<8xi64> -> vector<4xi64>
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[EVEN_0]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
+  // CHECK: %[[RES_1:.*]] = vector.insert %[[ODD_0]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
+  // CHECK-COUNT-5: vector.deinterleave %{{.*}} : vector<8xi64> -> vector<4xi64>
+  %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
+  return %0, %1 : vector<1x2x3x4xi64>, vector<1x2x3x4xi64>
+}
+
+// CHECK-LABEL: @vector_deinterleave_nd_with_scalable_dim
+func.func @vector_deinterleave_nd_with_scalable_dim(
+  %a: vector<1x3x[2]x2x3x8xf16>) -> (vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>) {
+  // Unrolling halts at the scalable [2] dim, so only the leading 1x3 dims are unrolled.
+  // CHECK-COUNT-3: vector.deinterleave %{{.*}} : vector<[2]x2x3x8xf16>
+  %even, %odd = vector.deinterleave %a : vector<1x3x[2]x2x3x8xf16> -> vector<1x3x[2]x2x3x4xf16>
+  return %even, %odd : vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.vector.lower_interleave
+    } : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list