[Mlir-commits] [mlir] [mlir][vector] Implement lowering for 1D vector.deinterleave operations (PR #93042)

Mubashar Ahmad llvmlistbot at llvm.org
Wed May 22 07:59:05 PDT 2024


https://github.com/mub-at-arm created https://github.com/llvm/llvm-project/pull/93042

This patch implements the lowering of vector.deinterleave 
for 1D vectors.

For fixed-size vector types, the operation is lowered to two
llvm.shufflevector operations: one selecting the even-indexed
elements and the other selecting the odd-indexed elements. A
poison vector is used as the unused second shufflevector
operand.
    
For scalable vectors, the llvm.intr.vector.deinterleave2
intrinsic is used for lowering. The intrinsic returns a
struct of two vectors, from which the two results are
extracted with llvm.extractvalue.

>From 0987e00444844ce1faf5071d4e565cc20fc9f86d Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Thu, 16 May 2024 12:28:34 +0000
Subject: [PATCH 1/4] [mlir][VectorOps] Add deinterleave operation to vector
 dialect

The deinterleave operation constructs two vectors from a single input
vector. Each new vector is the collection of even and odd elements
from the input, respectively. This is essentially the inverse of an
interleave operation.

The trailing dimension of each output is half the size of the input
vector's trailing dimension (for 1-D inputs, this is the only
dimension). The operation cannot be applied to 0-D inputs or to
vectors whose trailing dimension has an odd size (e.g. 1).

The operation supports scalable vectors.

Example:
```mlir
%0, %1 = vector.deinterleave %a
           : vector<[4]xi32>     // yields vector<[2]xi32>, vector<[2]xi32>
%2, %3 = vector.deinterleave %b
           : vector<8xi8>        // yields vector<4xi8>, vector<4xi8>
%4, %5 = vector.deinterleave %c
           : vector<2x8xf32>     // yields vector<2x4xf32>, vector<2x4xf32>
%6, %7 = vector.deinterleave %d
           : vector<2x4x[6]xf64> // yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64>
```
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 76 +++++++++++++++++++
 mlir/test/Dialect/Vector/invalid.mlir         | 56 ++++++++++++++
 mlir/test/Dialect/Vector/ops.mlir             | 42 ++++++++++
 3 files changed, 174 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 332b5ad08ced9..1e7e0a1715178 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -543,6 +543,82 @@ def Vector_InterleaveOp :
   }];
 }
 
+class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
+  "type of 'input' is double the width of results",
+  "input", result,
+  [{
+    [&]() -> ::mlir::VectorType {
+      auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
+      ::mlir::VectorType::Builder builder(vectorType);
+      auto lastDim = vectorType.getRank() - 1;
+      auto newDimSize = vectorType.getDimSize(lastDim) / 2;;
+      if (newDimSize <= 0)
+         return vectorType; // (invalid input type)
+      return builder.setDim(lastDim, newDimSize);
+    }()
+  }]
+>;
+
+def Vector_DeinterleaveOp :
+  Vector_Op<"deinterleave", [Pure,
+    PredOpTrait<"trailing dimension of input vector must be an even number",
+    CPred<[{
+      [&](){
+        auto srcVec = getSourceVectorType();
+        return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0;
+      }()
+    }]>>,
+    ResultIsHalfSourceVectorType<"res1">,
+    ResultIsHalfSourceVectorType<"res2">,
+    AllTypesMatch<["res1", "res2"]>
+    ]> {
+      let summary = "constructs two vectors by deinterleaving an input vector";
+      let description = [{
+        The deinterleave operation constructs two vectors from a single input
+        vector. Each new vector is the collection of even and odd elements
+        from the input, respectively. This is essentially the inverse of an
+        interleave operation.
+
+        Each output's size is half of the input vector's trailing dimension
+        for the n-D case and only dimension for 1-D cases. It is not possible
+        to conduct the operation on 0-D inputs or vectors where the size of
+        the (trailing) dimension is 1.
+
+        The operation supports scalable vectors.
+
+        Example:
+        ```mlir
+        %0 = vector.deinterleave %a
+                   : vector<[4]xi32>     ; yields vector<[2]xi32>, vector<[2]xi32>
+        %1 = vector.deinterleave %b
+                   : vector<8xi8>        ; yields vector<4xi8>, vector<4xi8>
+        %2 = vector.deinterleave %c
+                   : vector<2x8xf32>     ; yields vector<2x4xf32>, vector<2x4xf32>
+        %3 = vector.deinterleave %d
+                   : vector<2x4x[6]xf64> ; yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64>
+        ```
+      }];
+
+      let arguments = (ins AnyVector:$input);
+      let results = (outs AnyVector:$res1, AnyVector:$res2);
+
+      let assemblyFormat = [{
+        $input attr-dict `:` type($input)
+      }];
+
+      let extraClassDeclaration = [{
+        VectorType getSourceVectorType() {
+          return ::llvm::cast<VectorType>(getInput().getType());
+        }
+        VectorType getResultOneVectorType() {
+          return ::llvm::cast<VectorType>(getRes1().getType());
+        }
+        VectorType getResultTwoVectorType() {
+          return ::llvm::cast<VectorType>(getRes2().getType());
+        }
+      }];
+    }
+
 def Vector_ExtractElementOp :
   Vector_Op<"extractelement", [Pure,
      TypesMatchWith<"result type matches element type of vector operand",
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c9f7e9c6e2fb0..25cacc6fdf93d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1798,3 +1798,59 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>) {
   // expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
   %op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
 }
+
+// -----
+
+func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) {
+  // expected-error @+1 {{'vector.deinterleave' 'input' must be vector of any type values, but got 'vector<f32>'}}
+  %0, %1 = vector.deinterleave %vec : vector<f32> 
+  return
+}
+
+// -----
+
+func.func @deinterleave_one_dim_fail(%vec : vector<1xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that trailing dimension of input vector must be an even number}}
+  %0, %1 = vector.deinterleave %vec : vector<1xf32>
+  return
+}
+
+// -----
+
+func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  %0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<8xf32>, vector<8xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  %0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<4xf32>, vector<2xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  %0, %1 = "vector.deinterleave" (%vec) : (vector<2x3x4xf32>) -> (vector<2x3x4xf32>, vector<2x3x2xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<2x[1]xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 79a80be4f8b20..a6a992f23a4ba 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1116,3 +1116,45 @@ func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>)
   %0 = vector.interleave %a, %b : vector<2x[2]xf64>
   return %0 : vector<2x[4]xf64>
 }
+
+// CHECK-LABEL: @deinterleave_1d
+func.func @deinterleave_1d(%arg: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<4xf32>
+  %0, %1 = vector.deinterleave %arg : vector<4xf32>
+  return %0, %1 : vector<2xf32>, vector<2xf32>
+}
+
+// CHECK-LABEL: @deinterleave_1d_scalable
+func.func @deinterleave_1d_scalable(%arg: vector<[4]xf32>) -> (vector<[2]xf32>, vector<[2]xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32>
+  %0, %1 = vector.deinterleave %arg : vector<[4]xf32>
+  return %0, %1 : vector<[2]xf32>, vector<[2]xf32>
+}
+
+// CHECK-LABEL: @deinterleave_2d
+func.func @deinterleave_2d(%arg: vector<3x4xf32>) -> (vector<3x2xf32>, vector<3x2xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32>
+  %0, %1 = vector.deinterleave %arg : vector<3x4xf32>
+  return %0, %1 : vector<3x2xf32>, vector<3x2xf32>
+}
+
+// CHECK-LABEL: @deinterleave_2d_scalable
+func.func @deinterleave_2d_scalable(%arg: vector<3x[4]xf32>) -> (vector<3x[2]xf32>, vector<3x[2]xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32>
+  %0, %1 = vector.deinterleave %arg : vector<3x[4]xf32>
+  return %0, %1 : vector<3x[2]xf32>, vector<3x[2]xf32>
+}
+
+// CHECK-LABEL: @deinterleave_nd
+func.func @deinterleave_nd(%arg: vector<2x3x4x6xf32>) -> (vector<2x3x4x3xf32>, vector<2x3x4x3xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32>
+  %0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32>
+  return %0, %1 : vector<2x3x4x3xf32>, vector<2x3x4x3xf32>
+}
+
+// CHECK-LABEL: @deinterleave_nd_scalable
+func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32>
+  %0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32>
+  return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
+}
\ No newline at end of file

>From 11dc393466361241223e9312e7dcd7f292fe3d89 Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Thu, 16 May 2024 12:28:34 +0000
Subject: [PATCH 2/4] [mlir][VectorOps] Add deinterleave operation to vector
 dialect

The deinterleave operation constructs two vectors from a single input
vector. Each new vector is the collection of even and odd elements
from the input, respectively. This is essentially the inverse of an
interleave operation.

The trailing dimension of each output is half the size of the input
vector's trailing dimension (for 1-D inputs, this is the only
dimension). The operation cannot be applied to 0-D inputs or to
vectors whose trailing dimension has an odd size (e.g. 1).

The operation supports scalable vectors.

Example:
```mlir
%0, %1 = vector.deinterleave %a
           : vector<[4]xi32> -> vector<[2]xi32>
%2, %3 = vector.deinterleave %b
           : vector<8xi8> -> vector<4xi8>
%4, %5 = vector.deinterleave %c
           : vector<2x8xf32> -> vector<2x4xf32>
%6, %7 = vector.deinterleave %d
           : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
```
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 62 ++++++++++---------
 mlir/test/Dialect/Vector/invalid.mlir         | 20 +++---
 mlir/test/Dialect/Vector/ops.mlir             | 26 ++++----
 3 files changed, 56 insertions(+), 52 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 1e7e0a1715178..bfbb40405c3c1 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -544,8 +544,8 @@ def Vector_InterleaveOp :
 }
 
 class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
-  "type of 'input' is double the width of results",
-  "input", result,
+  "the trailing dimension of the results is half the width of source trailing dimension",
+  "source", result,
   [{
     [&]() -> ::mlir::VectorType {
       auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
@@ -559,63 +559,67 @@ class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
   }]
 >;
 
-def Vector_DeinterleaveOp :
-  Vector_Op<"deinterleave", [Pure,
-    PredOpTrait<"trailing dimension of input vector must be an even number",
+def SourceVectorEvenElementCount : PredOpTrait<
+  "the trailing dimension of the source vector has an even number of elements",
     CPred<[{
       [&](){
         auto srcVec = getSourceVectorType();
         return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0;
       }()
-    }]>>,
+    }]>
+>;
+
+def Vector_DeinterleaveOp :
+  Vector_Op<"deinterleave", [Pure,
+    SourceVectorEvenElementCount,
     ResultIsHalfSourceVectorType<"res1">,
-    ResultIsHalfSourceVectorType<"res2">,
     AllTypesMatch<["res1", "res2"]>
     ]> {
       let summary = "constructs two vectors by deinterleaving an input vector";
       let description = [{
         The deinterleave operation constructs two vectors from a single input
-        vector. Each new vector is the collection of even and odd elements
-        from the input, respectively. This is essentially the inverse of an
-        interleave operation.
+        vector. The first result vector contains the elements from even indexes
+        of the input, and the second contains elements from odd indexes. This is
+        the inverse of a `vector.interleave` operation.
 
-        Each output's size is half of the input vector's trailing dimension
-        for the n-D case and only dimension for 1-D cases. It is not possible
-        to conduct the operation on 0-D inputs or vectors where the size of
-        the (trailing) dimension is 1.
+        Each output's trailing dimension is half of the size of the input
+        vector's trailing dimension. This operation requires the input vector
+        to have a rank > 0 and an even number of elements in its trailing
+        dimension.
 
         The operation supports scalable vectors.
 
         Example:
         ```mlir
-        %0 = vector.deinterleave %a
-                   : vector<[4]xi32>     ; yields vector<[2]xi32>, vector<[2]xi32>
-        %1 = vector.deinterleave %b
-                   : vector<8xi8>        ; yields vector<4xi8>, vector<4xi8>
-        %2 = vector.deinterleave %c
-                   : vector<2x8xf32>     ; yields vector<2x4xf32>, vector<2x4xf32>
-        %3 = vector.deinterleave %d
-                   : vector<2x4x[6]xf64> ; yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64>
+        %0, %1 = vector.deinterleave %a
+                   :vector<8xi8> -> vector<4xi8>
+        %2, %3 = vector.deinterleave %b
+                   : vector<2x8xi8> -> vector<2x4xi8>
+        %4, %5 = vector.deinterleave %b
+                   : vector<2x8x4xi8> -> vector<2x8x2xi8>
+        %6, %7 = vector.deinterleave %c
+                   : vector<[8]xf32> -> vector<[4]xf32>
+        %8, %9 = vector.deinterleave %d
+                   : vector<2x[6]xf64> -> vector<2x[3]xf64>
+        %10, %11 = vector.deinterleave %d
+                   : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
         ```
       }];
 
-      let arguments = (ins AnyVector:$input);
+      let arguments = (ins AnyVector:$source);
       let results = (outs AnyVector:$res1, AnyVector:$res2);
 
       let assemblyFormat = [{
-        $input attr-dict `:` type($input)
+        $source attr-dict `:` type($source) `->` type($res1)
       }];
 
       let extraClassDeclaration = [{
         VectorType getSourceVectorType() {
-          return ::llvm::cast<VectorType>(getInput().getType());
+          return ::llvm::cast<VectorType>(getSource().getType());
         }
-        VectorType getResultOneVectorType() {
+        VectorType getResultVectorType() {
           return ::llvm::cast<VectorType>(getRes1().getType());
         }
-        VectorType getResultTwoVectorType() {
-          return ::llvm::cast<VectorType>(getRes2().getType());
-        }
       }];
     }
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 25cacc6fdf93d..1516f51fe1458 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1802,23 +1802,23 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>) {
 // -----
 
 func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) {
-  // expected-error @+1 {{'vector.deinterleave' 'input' must be vector of any type values, but got 'vector<f32>'}}
-  %0, %1 = vector.deinterleave %vec : vector<f32> 
+  // expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector<f32>}}
+  %0, %1 = vector.deinterleave %vec : vector<f32> -> vector<f32>
   return
 }
 
 // -----
 
 func.func @deinterleave_one_dim_fail(%vec : vector<1xf32>) {
-  // expected-error @+1 {{'vector.deinterleave' op failed to verify that trailing dimension of input vector must be an even number}}
-  %0, %1 = vector.deinterleave %vec : vector<1xf32>
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the source vector has an even number of elements}}
+  %0, %1 = vector.deinterleave %vec : vector<1xf32> -> vector<1xf32>
   return
 }
 
 // -----
 
 func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) {
-  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
   %0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<8xf32>, vector<8xf32>)
   return
 }
@@ -1826,7 +1826,7 @@ func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) {
 // -----
 
 func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) {
-  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
   %0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<4xf32>, vector<2xf32>)
   return
 }
@@ -1834,7 +1834,7 @@ func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) {
 // -----
 
 func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) {
-  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
   %0, %1 = "vector.deinterleave" (%vec) : (vector<2x3x4xf32>) -> (vector<2x3x4xf32>, vector<2x3x2xf32>)
   return
 }
@@ -1842,7 +1842,7 @@ func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) {
 // -----
 
 func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) {
-  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
   %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<2x[1]xf32>)
   return
 }
@@ -1850,7 +1850,7 @@ func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) {
 // -----
 
 func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
-  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
   %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
   return
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index a6a992f23a4ba..9d8101d3eee97 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1119,42 +1119,42 @@ func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>)
 
 // CHECK-LABEL: @deinterleave_1d
 func.func @deinterleave_1d(%arg: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
-  // CHECK: vector.deinterleave %{{.*}} : vector<4xf32>
-  %0, %1 = vector.deinterleave %arg : vector<4xf32>
+  // CHECK: vector.deinterleave %{{.*}} : vector<4xf32> -> vector<2xf32>
+  %0, %1 = vector.deinterleave %arg : vector<4xf32> -> vector<2xf32>
   return %0, %1 : vector<2xf32>, vector<2xf32>
 }
 
 // CHECK-LABEL: @deinterleave_1d_scalable
 func.func @deinterleave_1d_scalable(%arg: vector<[4]xf32>) -> (vector<[2]xf32>, vector<[2]xf32>) {
-  // CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32>
-  %0, %1 = vector.deinterleave %arg : vector<[4]xf32>
+  // CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32> -> vector<[2]xf32>
+  %0, %1 = vector.deinterleave %arg : vector<[4]xf32> -> vector<[2]xf32>
   return %0, %1 : vector<[2]xf32>, vector<[2]xf32>
 }
 
 // CHECK-LABEL: @deinterleave_2d
 func.func @deinterleave_2d(%arg: vector<3x4xf32>) -> (vector<3x2xf32>, vector<3x2xf32>) {
-  // CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32>
-  %0, %1 = vector.deinterleave %arg : vector<3x4xf32>
+  // CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32> -> vector<3x2xf32>
+  %0, %1 = vector.deinterleave %arg : vector<3x4xf32> -> vector<3x2xf32>
   return %0, %1 : vector<3x2xf32>, vector<3x2xf32>
 }
 
 // CHECK-LABEL: @deinterleave_2d_scalable
 func.func @deinterleave_2d_scalable(%arg: vector<3x[4]xf32>) -> (vector<3x[2]xf32>, vector<3x[2]xf32>) {
-  // CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32>
-  %0, %1 = vector.deinterleave %arg : vector<3x[4]xf32>
+  // CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32> -> vector<3x[2]xf32>
+  %0, %1 = vector.deinterleave %arg : vector<3x[4]xf32> -> vector<3x[2]xf32>
   return %0, %1 : vector<3x[2]xf32>, vector<3x[2]xf32>
 }
 
 // CHECK-LABEL: @deinterleave_nd
 func.func @deinterleave_nd(%arg: vector<2x3x4x6xf32>) -> (vector<2x3x4x3xf32>, vector<2x3x4x3xf32>) {
-  // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32>
-  %0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32>
+  // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
+  %0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
   return %0, %1 : vector<2x3x4x3xf32>, vector<2x3x4x3xf32>
 }
 
 // CHECK-LABEL: @deinterleave_nd_scalable
 func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>) {
-  // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32>
-  %0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32>
+  // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
+  %0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
   return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
-}
\ No newline at end of file
+}

>From 3d79b273288712276770ad8a9efcae6b2e8e0260 Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Thu, 16 May 2024 12:28:34 +0000
Subject: [PATCH 3/4] [mlir][VectorOps] Add deinterleave operation to vector
 dialect

The deinterleave operation constructs two vectors from a single input
vector. Each new vector is the collection of even and odd elements
from the input, respectively. This is essentially the inverse of an
interleave operation.

The trailing dimension of each output is half the size of the input
vector's trailing dimension (for 1-D inputs, this is the only
dimension). The operation cannot be applied to 0-D inputs or to
vectors whose trailing dimension has an odd size (e.g. 1).

The operation supports scalable vectors.

Example:
```mlir
%0, %1 = vector.deinterleave %a
           : vector<[4]xi32> -> vector<[2]xi32>
%2, %3 = vector.deinterleave %b
           : vector<8xi8> -> vector<4xi8>
%4, %5 = vector.deinterleave %c
           : vector<2x8xf32> -> vector<2x4xf32>
%6, %7 = vector.deinterleave %d
           : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
```
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 22 +++++++++----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bfbb40405c3c1..2bb7540ef0b0f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -561,12 +561,12 @@ class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
 
 def SourceVectorEvenElementCount : PredOpTrait<
   "the trailing dimension of the source vector has an even number of elements",
-    CPred<[{
-      [&](){
-        auto srcVec = getSourceVectorType();
-        return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0;
-      }()
-    }]>
+  CPred<[{
+    [&](){
+      auto srcVec = getSourceVectorType();
+      return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0;
+    }()
+  }]>
 >;
 
 def Vector_DeinterleaveOp :
@@ -592,16 +592,16 @@ def Vector_DeinterleaveOp :
         Example:
         ```mlir
         %0, %1 = vector.deinterleave %a
-                   :vector<8xi8> -> vector<4xi8>
+                   : vector<8xi8> -> vector<4xi8>
         %2, %3 = vector.deinterleave %b
                    : vector<2x8xi8> -> vector<2x4xi8>
-        %4, %5 = vector.deinterleave %b
+        %4, %5 = vector.deinterleave %c
                    : vector<2x8x4xi8> -> vector<2x8x2xi8>
-        %6, %7 = vector.deinterleave %c
+        %6, %7 = vector.deinterleave %d
                    : vector<[8]xf32> -> vector<[4]xf32>
-        %8, %9 = vector.deinterleave %d
+        %8, %9 = vector.deinterleave %e
                    : vector<2x[6]xf64> -> vector<2x[3]xf64>
-        %10, %11 = vector.deinterleave %d
+        %10, %11 = vector.deinterleave %f
                    : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
         ```
       }];

>From 4449999a62ad09639f209f0764790e1ea9421565 Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Wed, 22 May 2024 09:01:18 +0000
Subject: [PATCH 4/4] [mlir][vector] Implement lowering for 1D
 vector.deinterleave operations

This patch implements the lowering of vector.deinterleave
for 1D vectors.

For fixed-size vector types, the operation is lowered to two
llvm.shufflevector operations: one selecting the even-indexed
elements and the other selecting the odd-indexed elements. A
poison vector is used as the unused second shufflevector
operand.

For scalable vectors, the llvm.intr.vector.deinterleave2
intrinsic is used for lowering. The intrinsic returns a
struct of two vectors, from which the two results are
extracted with llvm.extractvalue.
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 63 ++++++++++++++++++-
 .../VectorToLLVM/vector-to-llvm.mlir          | 22 +++++++
 .../Vector/CPU/ArmSVE/test-deinterleave.mlir  | 18 ++++++
 .../Dialect/Vector/CPU/test-deinterleave.mlir | 18 ++++++
 4 files changed, 120 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-deinterleave.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/test-deinterleave.mlir

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index fe6bcc1c8b667..94a2954f9a247 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1761,6 +1761,66 @@ struct VectorInterleaveOpLowering
   }
 };
 
+/// Conversion pattern for a `vector.deinterleave` on 1-D vectors.
+///
+/// Fixed-size vectors are lowered to two `llvm.shufflevector` ops (one
+/// selecting the even-indexed lanes, one the odd-indexed lanes) whose unused
+/// second operand is a poison vector. Scalable vectors are lowered to the
+/// `llvm.intr.vector.deinterleave2` intrinsic, whose two struct members are
+/// extracted to form the results. n-D deinterleaves are not matched.
+struct VectorDeinterleaveOpLowering
+    : public ConvertOpToLLVMPattern<vector::DeinterleaveOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = deinterleaveOp.getResultVectorType();
+    Location loc = deinterleaveOp.getLoc();
+
+    if (resultType.getRank() != 1)
+      return rewriter.notifyMatchFailure(deinterleaveOp,
+                                         "deinterleaveOp not rank 1");
+
+    if (resultType.isScalable()) {
+      Type packedResultType = getTypeConverter()->packOperationResults(
+          deinterleaveOp.getResultTypes());
+      auto intrinsic = rewriter.create<LLVM::vector_deinterleave2>(
+          loc, packedResultType, adaptor.getSource());
+      Value evenResult = rewriter.create<LLVM::ExtractValueOp>(
+          loc, intrinsic->getResult(0), 0);
+      Value oddResult = rewriter.create<LLVM::ExtractValueOp>(
+          loc, intrinsic->getResult(0), 1);
+      rewriter.replaceOp(deinterleaveOp, ValueRange{evenResult, oddResult});
+      return success();
+    }
+
+    // Use the *converted* source type (e.g. `index` elements are converted to
+    // integers) so that both shuffle operands have the same LLVM type.
+    Value source = adaptor.getSource();
+    auto poison = rewriter.create<LLVM::PoisonOp>(loc, source.getType());
+
+    // Result lane `i` reads source lane `2*i` (even) or `2*i + 1` (odd).
+    int64_t resultVectorSize = resultType.getNumElements();
+    SmallVector<int32_t> evenShuffleMask;
+    SmallVector<int32_t> oddShuffleMask;
+    evenShuffleMask.reserve(resultVectorSize);
+    oddShuffleMask.reserve(resultVectorSize);
+    for (int32_t i = 0; i < resultVectorSize; ++i) {
+      evenShuffleMask.push_back(2 * i);
+      oddShuffleMask.push_back(2 * i + 1);
+    }
+
+    auto evenShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
+        loc, source, poison, evenShuffleMask);
+    auto oddShuffle = rewriter.create<LLVM::ShuffleVectorOp>(
+        loc, source, poison, oddShuffleMask);
+
+    rewriter.replaceOp(deinterleaveOp, ValueRange{evenShuffle, oddShuffle});
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1785,7 +1845,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
                VectorSplatOpLowering, VectorSplatNdOpLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
-               MaskedReductionOpConversion, VectorInterleaveOpLowering>(
+               MaskedReductionOpConversion, VectorInterleaveOpLowering,
+               VectorDeinterleaveOpLowering>(
       converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 439f1e920e392..d1755f0cd3a21 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2546,3 +2546,25 @@ func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]
   %0 = vector.interleave %a, %b : vector<2x[8]xi16>
   return %0 : vector<2x[16]xi16>
 }
+
+// -----
+
+// CHECK-LABEL: @vector_deinterleave_1d
+// CHECK-SAME:  (%[[SRC:.*]]: vector<4xi32>) -> (vector<2xi32>, vector<2xi32>)
+func.func @vector_deinterleave_1d(%a: vector<4xi32>) -> (vector<2xi32>, vector<2xi32>) {
+  // CHECK: %[[POISON:.*]] = llvm.mlir.poison : vector<4xi32>
+  // CHECK: llvm.shufflevector %[[SRC]], %[[POISON]] [0, 2] : vector<4xi32>
+  // CHECK: llvm.shufflevector %[[SRC]], %[[POISON]] [1, 3] : vector<4xi32>
+  %0, %1 = vector.deinterleave %a : vector<4xi32> -> vector<2xi32>
+  return %0, %1 : vector<2xi32>, vector<2xi32>
+}
+
+// CHECK-LABEL: @vector_deinterleave_1d_scalable
+// CHECK-SAME:  (%[[SRC:.*]]: vector<[4]xi32>) -> (vector<[2]xi32>, vector<[2]xi32>)
+func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi32>, vector<[2]xi32>) {
+  // CHECK: %[[RES:.*]] = "llvm.intr.vector.deinterleave2"(%[[SRC]])
+  // CHECK: llvm.extractvalue %[[RES]][0] : !llvm.struct<(vector<[2]xi32>, vector<[2]xi32>)>
+  // CHECK: llvm.extractvalue %[[RES]][1] : !llvm.struct<(vector<[2]xi32>, vector<[2]xi32>)>
+  %0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
+  return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-deinterleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-deinterleave.mlir
new file mode 100644
index 0000000000000..d8cd38ef33037
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-deinterleave.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_c_runner_utils,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+  %step_vector = llvm.intr.experimental.stepvector : vector<[4]xi8>
+  vector.print %step_vector : vector<[4]xi8>
+  // CHECK: ( 0, 1, 2, 3{{.*}} )
+
+  %v1, %v2 = vector.deinterleave %step_vector : vector<[4]xi8> -> vector<[2]xi8>
+  vector.print %v1 : vector<[2]xi8>
+  vector.print %v2 : vector<[2]xi8>
+  // CHECK: ( 0, 2{{.*}} )
+  // CHECK: ( 1, 3{{.*}} )
+
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-deinterleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-deinterleave.mlir
new file mode 100644
index 0000000000000..4915a3cde124d
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-deinterleave.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+  %v0 = arith.constant dense<[1, 2, 3, 4]> : vector<4xi8>
+  vector.print %v0 : vector<4xi8>
+  // CHECK: ( 1, 2, 3, 4 )
+
+  %v1, %v2 = vector.deinterleave %v0 : vector<4xi8> -> vector<2xi8>
+  vector.print %v1 : vector<2xi8>
+  vector.print %v2 : vector<2xi8>
+  // CHECK: ( 1, 3 )
+  // CHECK: ( 2, 4 )
+
+  return
+}



More information about the Mlir-commits mailing list