[Mlir-commits] [mlir] [mlir][vector] Implement lowering for 1D vector.deinterleave operations (PR #93042)

Mubashar Ahmad llvmlistbot at llvm.org
Tue May 28 08:49:46 PDT 2024


https://github.com/mub-at-arm updated https://github.com/llvm/llvm-project/pull/93042

>From 7dd81f4598702027bac5e4fd31f62c07a74f115a Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Thu, 16 May 2024 12:28:34 +0000
Subject: [PATCH 1/2] [mlir][VectorOps] Add deinterleave operation to vector
 dialect

The deinterleave operation constructs two vectors from a single input
vector. The first result holds the even-indexed elements of the input
and the second holds the odd-indexed elements. This is essentially the
inverse of an interleave operation.

The size of each output is half of the input vector's trailing dimension
in the n-D case, and half of its only dimension in the 1-D case. The
operation cannot be applied to 0-D inputs or to vectors whose (trailing)
dimension has size 1.

The operation supports scalable vectors.

Example:
```mlir
%0, %1 = vector.deinterleave %a
           : vector<[4]xi32> -> vector<[2]xi32>
%2, %3 = vector.deinterleave %b
           : vector<8xi8> -> vector<4xi8>
%4, %5 = vector.deinterleave %c
           : vector<2x8xf32> -> vector<2x4xf32>
%6, %7 = vector.deinterleave %d
           : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
```

>From 960cb0b03c54e453264820067026fd44d74f0536 Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Wed, 22 May 2024 09:01:18 +0000
Subject: [PATCH 2/2] [mlir][vector] Add 1D vector.deinterleave lowering

This patch implements the lowering of vector.deinterleave
for 1D vectors.

For fixed vector types, the operation is lowered to two
llvm shufflevector operations: one for the even-indexed
elements and the other for the odd-indexed elements. A poison
value is used as the second vector operand that
shufflevector requires.

For scalable vectors, the llvm vector.deinterleave2
intrinsic is used for lowering. The intrinsic returns a
struct holding both vectors, so the even and odd results
are obtained by extracting the two fields of that struct.
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 68 ++++++++++++++++++-
 .../VectorToLLVM/vector-to-llvm.mlir          | 22 ++++++
 .../Emulated/test-scalable-deinterleave.mlir  | 30 ++++++++
 .../Dialect/Vector/CPU/test-deinterleave.mlir | 18 +++++
 4 files changed, 136 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/Emulated/test-scalable-deinterleave.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-deinterleave.mlir

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index fe6bcc1c8b667..759d91d1604df 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1761,6 +1761,70 @@ struct VectorInterleaveOpLowering
   }
 };
 
+/// Conversion pattern lowering a rank-1 `vector.deinterleave` to LLVM.
+/// Both fixed-size and scalable vectors are supported.
+struct VectorDeinterleaveOpLowering
+    : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = deinterleaveOp.getResultVectorType();
+    VectorType sourceType = deinterleaveOp.getSourceVectorType();
+    auto loc = deinterleaveOp.getLoc();
+
+    // Note: n-D deinterleave operations should be lowered to 1-D ones before
+    // converting to LLVM; other ranks are reported as a match failure.
+    if (resultType.getRank() != 1)
+      return rewriter.notifyMatchFailure(deinterleaveOp,
+                                         "DeinterleaveOp not rank 1");
+
+    if (resultType.isScalable()) {
+      auto llvmTypeConverter = this->getTypeConverter();
+      auto deinterleaveResults = deinterleaveOp.getResultTypes();
+      auto packedOpResults =
+          llvmTypeConverter->packOperationResults(deinterleaveResults);
+      auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
+          loc, packedOpResults, adaptor.getSource());
+
+      auto evenResult = rewriter.create<LLVM::ExtractValueOp>(
+          loc, intrinsic->getResult(0), 0);
+      auto oddResult = rewriter.create<LLVM::ExtractValueOp>(
+          loc, intrinsic->getResult(0), 1);
+
+      rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
+      return success();
+    }
+    // Lower fixed-size deinterleave to two shufflevectors: one gathers the
+    // even lanes, the other the odd lanes. The vector.deinterleave2 intrinsic
+    // also accepts fixed vectors, but the LangRef recommends shufflevector
+    // for them, see: https://llvm.org/docs/LangRef.html#id889.
+    int64_t resultVectorSize = resultType.getNumElements();
+    SmallVector<int32_t> evenShuffleMask;
+    SmallVector<int32_t> oddShuffleMask;
+
+    evenShuffleMask.reserve(resultVectorSize);
+    oddShuffleMask.reserve(resultVectorSize);
+
+    for (int i = 0; i < sourceType.getNumElements(); ++i) {
+      if (i % 2 == 0)
+        evenShuffleMask.push_back(i);
+      else
+        oddShuffleMask.push_back(i);
+    }
+
+    auto poison = rewriter.create<LLVM::PoisonOp>(loc, sourceType);
+    auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
+        loc, adaptor.getSource(), poison, evenShuffleMask);
+    auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
+        loc, adaptor.getSource(), poison, oddShuffleMask);
+
+    rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1785,8 +1849,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
                VectorSplatOpLowering, VectorSplatNdOpLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
-               MaskedReductionOpConversion, VectorInterleaveOpLowering>(
-      converter);
+               MaskedReductionOpConversion, VectorInterleaveOpLowering,
+               VectorDeinterleaveOpLowering>(converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 439f1e920e392..874483443095a 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2546,3 +2546,25 @@ func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]
   %0 = vector.interleave %a, %b : vector<2x[8]xi16>
   return %0 : vector<2x[16]xi16>
 }
+
+// -----
+
+// CHECK-LABEL: @vector_deinterleave_1d
+// CHECK-SAME:  (%[[SRC:.*]]: vector<4xi32>) -> (vector<2xi32>, vector<2xi32>)
+func.func @vector_deinterleave_1d(%a: vector<4xi32>) -> (vector<2xi32>, vector<2xi32>) {
+  // CHECK: %[[POISON:.*]] = llvm.mlir.poison : vector<4xi32>
+  // CHECK: llvm.shufflevector %[[SRC]], %[[POISON]] [0, 2] : vector<4xi32>
+  // CHECK: llvm.shufflevector %[[SRC]], %[[POISON]] [1, 3] : vector<4xi32>
+  %0, %1 = vector.deinterleave %a : vector<4xi32> -> vector<2xi32>
+  return %0, %1 : vector<2xi32>, vector<2xi32>
+}
+
+// CHECK-LABEL: @vector_deinterleave_1d_scalable
+// CHECK-SAME:  %[[SRC:.*]]: vector<[4]xi32>) -> (vector<[2]xi32>, vector<[2]xi32>)
+func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi32>, vector<[2]xi32>) {
+    // CHECK: %[[RES:.*]] = "llvm.intr.vector.deinterleave2"
+    // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(vector<[2]xi32>, vector<[2]xi32>)>
+    // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.struct<(vector<[2]xi32>, vector<[2]xi32>)>
+    %0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
+    return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/Emulated/test-scalable-deinterleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/Emulated/test-scalable-deinterleave.mlir
new file mode 100644
index 0000000000000..27fd0687a7053
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/Emulated/test-scalable-deinterleave.mlir
@@ -0,0 +1,30 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd -march=aarch64 -mattr=+sve \
+// DEFINE:  -e %{entry_point} -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+func.func @entry() {
+  // Set the vector length to 256-bit (equivalent to vscale=2).
+  // This allows the checks (below) to look at an entire vector.
+  %c256 = arith.constant 256 : i32
+  func.call @setArmVLBits(%c256) : (i32) -> ()
+  func.call @test_deinterleave() : () -> ()
+  return
+}
+
+func.func @test_deinterleave() {
+  %step_vector = llvm.intr.experimental.stepvector : vector<[4]xi8>
+  vector.print %step_vector : vector<[4]xi8>
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 )
+  %v1, %v2 = vector.deinterleave %step_vector : vector<[4]xi8> -> vector<[2]xi8>
+  vector.print %v1 : vector<[2]xi8>
+  vector.print %v2 : vector<[2]xi8>
+  // CHECK: ( 0, 2, 4, 6 )
+  // CHECK: ( 1, 3, 5, 7 )
+  return
+}
+
+func.func private @setArmVLBits(%bits : i32)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-deinterleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-deinterleave.mlir
new file mode 100644
index 0000000000000..4915a3cde124d
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-deinterleave.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+  %v0 = arith.constant dense<[1, 2, 3, 4]> : vector<4xi8>
+  vector.print %v0 : vector<4xi8>
+  // CHECK: ( 1, 2, 3, 4 )
+
+  %v1, %v2 = vector.deinterleave %v0 : vector<4xi8> -> vector<2xi8>
+  vector.print %v1 : vector<2xi8>
+  vector.print %v2 : vector<2xi8>
+  // CHECK: ( 1, 3 )
+  // CHECK: ( 2, 4 )
+
+  return
+}



More information about the Mlir-commits mailing list