[Mlir-commits] [mlir] [mlir][vector] Add n-d deinterleave lowering (PR #94237)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 3 09:00:28 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Mubashar Ahmad (mub-at-arm)

<details>
<summary>Changes</summary>

This patch implements the lowering for vector
deinterleave for vector of n-dimensions. Process
involves unrolling the n-d vector to a series
of one-dimensional vectors. The deinterleave
operation is then used on these vectors.

From:
```
%0, %1 = vector.deinterleave %a : vector<2x[4]xi8> -> vector<2x[2]xi8>
```

To:
```
%2 = llvm.extractvalue %0[0] : !llvm.array<2 x vector<8xf32>>
%3 = llvm.mlir.poison : vector<8xf32>
%4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector<8xf32>
%5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector<8xf32>
%6 = llvm.insertvalue %4, %1[0] : !llvm.array<2 x vector<4xf32>>
%7 = llvm.insertvalue %5, %1[0] : !llvm.array<2 x vector<4xf32>>
%8 = llvm.extractvalue %0[1] : !llvm.array<2 x vector<8xf32>>
...etc.
```

---
Full diff: https://github.com/llvm/llvm-project/pull/94237.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp (+37) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+36) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 77c97b2f1497c..557837426d855 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -79,6 +79,42 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
   int64_t targetRank = 1;
 };
 
+class UnrollDeinterleaveOp final : public OpRewritePattern<vector::DeinterleaveOp> {
+public:
+  UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank) {};
+
+  LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
+      return failure();
+
+    auto loc = op.getLoc();
+    Value evenResult = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    Value oddResult = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+
+    for (auto position : *unrollIterator) {
+      auto extractSrc = rewriter.create<vector::ExtractOp>(
+        loc, op.getSource(), position);
+      auto deinterleave = rewriter.create<vector::DeinterleaveOp>(
+        loc, extractSrc);
+      evenResult = rewriter.create<vector::InsertOp>(
+        loc, deinterleave.getRes1(), evenResult, position);
+      oddResult = rewriter.create<vector::InsertOp>(
+        loc, deinterleave.getRes2(), oddResult, position);
+    }
+    rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
+    return success();
+  }
+
+private:
+  int64_t targetRank = 1;
+};
 /// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
 /// applicable: `sourceType` must be 1D and non-scalable.
 ///
@@ -117,6 +153,7 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
 void mlir::vector::populateVectorInterleaveLoweringPatterns(
     RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
   patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
+  patterns.add<UnrollDeinterleaveOp>(targetRank, patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorInterleaveToShufflePatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 245edb6789d30..21f4872bb2cd9 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2564,3 +2564,39 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
     %0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
     return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
 }
+
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
+   // CHECK: %[[EXTRACT_ONE:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<8xf32>> 
+   // CHECK: %[[POISON_ONE:.*]] = llvm.mlir.poison : vector<8xf32>
+   // CHECK: %[[SHUFFLE_A:.*]] = llvm.shufflevector %[[EXTRACT_ONE]], %[[POISON_ONE]] [0, 2, 4, 6] : vector<8xf32> 
+   // CHECK: %[[SHUFFLE_B:.*]] = llvm.shufflevector %[[EXTRACT_ONE]], %[[POISON_ONE]] [1, 3, 5, 7] : vector<8xf32> 
+   // CHECK: %[[INSERT_A:.*]] = llvm.insertvalue %[[SHUFFLE_A]], %{{.*}}[0] : !llvm.array<2 x vector<4xf32>> 
+   // CHECK: %[[INSERT_B:.*]] = llvm.insertvalue %[[SHUFFLE_B]], %{{.*}}[0] : !llvm.array<2 x vector<4xf32>> 
+   // CHECK: %[[EXTRACT_TWO:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<8xf32>> 
+   // CHECK: %[[POISON_TWO:.*]] = llvm.mlir.poison : vector<8xf32>
+   // CHECK: %[[SHUFFLE_C:.*]] = llvm.shufflevector %[[EXTRACT_TWO]], %[[POISON_TWO]] [0, 2, 4, 6] : vector<8xf32> 
+   // CHECK: %[[SHUFFLE_D:.*]] = llvm.shufflevector %[[EXTRACT_TWO]], %[[POISON_TWO]] [1, 3, 5, 7] : vector<8xf32> 
+   // CHECK: %[[INSERT_C:.*]] = llvm.insertvalue %[[SHUFFLE_C]], %[[INSERT_A]][1] : !llvm.array<2 x vector<4xf32>> 
+   // CHECK: %[[INSERT_D:.*]] = llvm.insertvalue %[[SHUFFLE_D]], %[[INSERT_B]][1] : !llvm.array<2 x vector<4xf32>> 
+  %0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
+  return %0, %1 : vector<2x4xf32>, vector<2x4xf32>
+}
+
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
+    // CHECK: %[[EXTRACT_A:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[8]xf32>>
+    // CHECK: %[[VECTOR_ONE:.*]] = "llvm.intr.vector.deinterleave2"(%[[EXTRACT_ONE]]) : (vector<[8]xf32>) -> !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+    // CHECK: %[[EXTRACT_B:.*]] = llvm.extractvalue %[[VECTOR_ONE]][0] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)> 
+    // CHECK: %[[EXTRACT_C:.*]] = llvm.extractvalue %[[VECTOR_ONE]][1] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)> 
+    // CHECK: %[[INSERT_A:.*]] = llvm.insertvalue %[[EXTRACT_B]], %{{.*}}[0] : !llvm.array<2 x vector<[4]xf32>> 
+    // CHECK: %[[INSERT_B:.*]] = llvm.insertvalue %[[EXTRACT_C]], %{{.*}}[0] : !llvm.array<2 x vector<[4]xf32>> 
+    // CHECK: %[[EXTRACT_D:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[8]xf32>> 
+    // CHECK: %[[VECTOR_TWO:.*]] = "llvm.intr.vector.deinterleave2"(%[[EXTRACT_D]]) : (vector<[8]xf32>) -> !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+    // CHECK: %[[EXTRACT_E:.*]] = llvm.extractvalue %[[VECTOR_TWO]][0] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)> 
+    // CHECK: %[[EXTRACT_F:.*]] = llvm.extractvalue %[[VECTOR_TWO]][1] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)> 
+    // CHECK: %[[INSERT_C:.*]] = llvm.insertvalue %[[EXTRACT_E]], %[[INSERT_A]][1] : !llvm.array<2 x vector<[4]xf32>>
+    // CHECK: %[[INSERT_D:.*]] = llvm.insertvalue %[[EXTRACT_F]], %[[INSERT_B]][1] : !llvm.array<2 x vector<[4]xf32>>
+    %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
+    return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/94237


More information about the Mlir-commits mailing list