[Mlir-commits] [mlir] [mlir][VectorOps] Add deinterleave operation to vector dialect (PR #92409)

Mubashar Ahmad llvmlistbot at llvm.org
Thu May 16 07:53:21 PDT 2024


https://github.com/mub-at-arm created https://github.com/llvm/llvm-project/pull/92409

The deinterleave operation constructs two vectors from a single input
vector. The first result contains the even-indexed elements and the
second the odd-indexed elements of the input's trailing dimension. This
is essentially the inverse of an interleave operation.
    
The trailing dimension of each result is half the size of the input's
trailing dimension (the only dimension, for 1-D inputs); all other
dimensions are unchanged. 0-D inputs are not supported, and the size of
the input's trailing dimension must be even.
    
The operation supports scalable vectors.
    
Example:
```mlir
%0 = vector.deinterleave %a
                   : vector<[4]xi32>     ; yields vector<[2]xi32>, vector<[2]xi32>
%1 = vector.deinterleave %b
                   : vector<8xi8>        ; yields vector<4xi8>, vector<4xi8>
%2 = vector.deinterleave %c
                   : vector<2x8xf32>     ; yields vector<2x4xf32>, vector<2x4xf32>
%3 = vector.deinterleave %d
                   : vector<2x4x[6]xf64> ; yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64>
 ```

>From 9c1ffb2987fb378a6a22f6f1500da287f1cd774d Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Mon, 13 May 2024 15:32:21 +0000
Subject: [PATCH 1/3] [mlir][llvm] Add llvm.vector.deinterleave2 intrinsic

Add the llvm.vector.deinterleave2 intrinsic to the MLIR LLVM
dialect. The intrinsic takes a single vector and returns two
vectors: the first holds the even-indexed elements of the input
and the second the odd-indexed elements. It is the inverse of
llvm.vector.interleave2.
---
 .../include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td |  6 ++++++
 mlir/test/Dialect/LLVMIR/roundtrip.mlir             |  7 +++++++
 mlir/test/Target/LLVMIR/Import/intrinsic.ll         |  9 +++++++++
 mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir      | 13 +++++++++++++
 4 files changed, 35 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 759cbe6c15647..97213dccfac07 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -1074,6 +1074,12 @@ def LLVM_vector_interleave2
         ]>,
         Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
 
+def LLVM_vector_deinterleave2
+    : LLVM_IntrOp<"vector.deinterleave2",
+        /*overloadedResults=*/[], /*overloadedOperands=*/[0],
+        /*traits=*/[Pure], /*numResults=*/2>,
+        Arguments<(ins LLVM_AnyVector:$vec)>;
+
 //
 // LLVM Vector Predication operations.
 //
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 3b94db389f549..d1bf3368a996d 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -349,6 +349,13 @@ func.func @vector_interleave2(%vec1: vector<[4]xf16>, %vec2 : vector<[4]xf16>) {
   return
 }
 
+// CHECK-LABEL: @vector_deinterleave2
+func.func @vector_deinterleave2(%vec: vector<[8]xf16>) {
+  // CHECK: = "llvm.intr.vector.deinterleave2" ({{.*}}) : (vector<[8]xf16) -> !llvm.struct<(vector<[4]xf16>, vector<[4]xf16>)>
+  %0 = "llvm.intr.vector.deinterleave2"(%vec) : (vector<[8]xf16>) -> !llvm.struct<(vector<[4]xf16>, vector<[4]xf16>)>
+  return
+}
+
 // CHECK-LABEL: @alloca
 func.func @alloca(%size : i64) {
   // CHECK: llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index bf6847a32ff4f..28d9aef2432d8 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -786,6 +786,15 @@ define void @vector_extract(<vscale x 4 x float> %0) {
   ret void
 }
 
+; CHECK-LABEL: llvm.func @vector_deinterleave2
+define void @vector_deinterleave2(<4 x double> %0, <vscale x 8 x i32> %1) {
+  ; llvm.intr.vector.deinterleave2 %{{.*}} : (vector<4xf64>) -> !llvm.struct<(vector<2xf64>, vector<2xf64>)>
+  %3 = call { <2 x double>, <2 x double> } @llvm.vector.deinterleave2.v4f64(<4 x double> %0);
+  ; llvm.intr.vector.deinterleave2 %{{.*}} : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+  %4 = call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> %1);
+  ret void
+}
+
 ; CHECK-LABEL:  llvm.func @vector_predication_intrinsics
 define void @vector_predication_intrinsics(<8 x i32> %0, <8 x i32> %1, <8 x float> %2, <8 x float> %3, <8 x i64> %4, <8 x double> %5, <8 x ptr> %6, i32 %7, float %8, ptr %9, ptr %10, <8 x i1> %11, i32 %12) {
   ; CHECK: "llvm.intr.vp.add"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index db5184a63d983..238c3e4263cb0 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -942,6 +942,17 @@ llvm.func @vector_insert_extract(%f256: vector<8xi32>, %f128: vector<4xi32>,
   llvm.return
 }
 
+// CHECK-LABEL: @vector_deinterleave2
+llvm.func @vector_deinterleave2(%vec1: vector<4xf64>, %vec2: vector<[8]xi32>) {
+  // CHECK: call { <2 x double>, <2 x double> } @llvm.vector.deinterleave2.v4f64(<4 x double> %{{.*}})
+  %0 = "llvm.intr.vector.deinterleave2" (%vec1) :
+              (vector<4xf64>) -> !llvm.struct<(vector<2xf64>, vector<2xf64>)>
+  // CHECK: call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> %{{.*}})
+  %1 = "llvm.intr.vector.deinterleave2" (%vec2) :
+              (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+  llvm.return
+}
+
 // CHECK-LABEL: @lifetime
 llvm.func @lifetime(%p: !llvm.ptr) {
   // CHECK: call void @llvm.lifetime.start
@@ -1148,6 +1159,8 @@ llvm.func @experimental_constrained_fptrunc(%s: f64, %v: vector<4xf32>) {
 // CHECK-DAG: declare <8 x i32> @llvm.vector.extract.v8i32.nxv4i32(<vscale x 4 x i32>, i64 immarg)
 // CHECK-DAG: declare <4 x i32> @llvm.vector.extract.v4i32.nxv4i32(<vscale x 4 x i32>, i64 immarg)
 // CHECK-DAG: declare <2 x i32> @llvm.vector.extract.v2i32.v8i32(<8 x i32>, i64 immarg)
+// CHECK-DAG: declare { <2 x double>, <2 x double> } @llvm.vector.deinterleave2.v4f64(<4 x double>)
+// CHECK-DAG: declare { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32>)
 // CHECK-DAG: declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture)
 // CHECK-DAG: declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture)
 // CHECK-DAG: declare ptr @llvm.invariant.start.p0(i64 immarg, ptr nocapture)

>From 106b9c4686b1068f56651bbae40123af9251f028 Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Mon, 13 May 2024 15:32:21 +0000
Subject: [PATCH 2/3] [mlir][llvm] Add llvm.vector.deinterleave2 intrinsic

Add the llvm.vector.deinterleave2 intrinsic to the MLIR LLVM
dialect. The intrinsic takes a single vector and returns two
vectors: the first holds the even-indexed elements of the input
and the second the odd-indexed elements. It is the inverse of
llvm.vector.interleave2.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 4 ++--
 mlir/test/Dialect/LLVMIR/roundtrip.mlir              | 2 +-
 mlir/test/Target/LLVMIR/Import/intrinsic.ll          | 4 ++--
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 97213dccfac07..bd347d0cf6308 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -1075,9 +1075,9 @@ def LLVM_vector_interleave2
         Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
 
 def LLVM_vector_deinterleave2
-    : LLVM_IntrOp<"vector.deinterleave2",
+    : LLVM_OneResultIntrOp<"vector.deinterleave2",
         /*overloadedResults=*/[], /*overloadedOperands=*/[0],
-        /*traits=*/[Pure], /*numResults=*/2>,
+        /*traits=*/[Pure]>,
         Arguments<(ins LLVM_AnyVector:$vec)>;
 
 //
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index d1bf3368a996d..410122df1c14d 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -351,7 +351,7 @@ func.func @vector_interleave2(%vec1: vector<[4]xf16>, %vec2 : vector<[4]xf16>) {
 
 // CHECK-LABEL: @vector_deinterleave2
 func.func @vector_deinterleave2(%vec: vector<[8]xf16>) {
-  // CHECK: = "llvm.intr.vector.deinterleave2" ({{.*}}) : (vector<[8]xf16) -> !llvm.struct<(vector<[4]xf16>, vector<[4]xf16>)>
+  // CHECK: = "llvm.intr.vector.deinterleave2"({{.*}}) : (vector<[8]xf16>) -> !llvm.struct<(vector<[4]xf16>, vector<[4]xf16>)>
   %0 = "llvm.intr.vector.deinterleave2"(%vec) : (vector<[8]xf16>) -> !llvm.struct<(vector<[4]xf16>, vector<[4]xf16>)>
   return
 }
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 28d9aef2432d8..e43024ff868e5 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -788,9 +788,9 @@ define void @vector_extract(<vscale x 4 x float> %0) {
 
 ; CHECK-LABEL: llvm.func @vector_deinterleave2
 define void @vector_deinterleave2(<4 x double> %0, <vscale x 8 x i32> %1) {
-  ; llvm.intr.vector.deinterleave2 %{{.*}} : (vector<4xf64>) -> !llvm.struct<(vector<2xf64>, vector<2xf64>)>
+  ; CHECK: "llvm.intr.vector.deinterleave2"(%{{.*}}) : (vector<4xf64>) -> !llvm.struct<(vector<2xf64>, vector<2xf64>)>
   %3 = call { <2 x double>, <2 x double> } @llvm.vector.deinterleave2.v4f64(<4 x double> %0);
-  ; llvm.intr.vector.deinterleave2 %{{.*}} : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+  ; CHECK: "llvm.intr.vector.deinterleave2"(%{{.*}}) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
   %4 = call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> %1);
   ret void
 }

>From 1f43c32cf699a6bb956db25f595fd900154bd661 Mon Sep 17 00:00:00 2001
From: "Mubashar.Ahmad at arm.com" <mubashar.ahmad at arm.com>
Date: Thu, 16 May 2024 12:28:34 +0000
Subject: [PATCH 3/3] [mlir][VectorOps] Add deinterleave operation to vector
 dialect

The deinterleave operation constructs two vectors from a single input
vector. The first result contains the even-indexed elements and the
second the odd-indexed elements of the input's trailing dimension. This
is essentially the inverse of an interleave operation.

The trailing dimension of each result is half the size of the input's
trailing dimension (the only dimension, for 1-D inputs); all other
dimensions are unchanged. 0-D inputs are not supported, and the size of
the input's trailing dimension must be even.

The operation supports scalable vectors.

Example:
```mlir
%0 = vector.deinterleave %a
           : vector<[4]xi32>     ; yields vector<[2]xi32>, vector<[2]xi32>
%1 = vector.deinterleave %b
           : vector<8xi8>        ; yields vector<4xi8>, vector<4xi8>
%2 = vector.deinterleave %c
           : vector<2x8xf32>     ; yields vector<2x4xf32>, vector<2x4xf32>
%3 = vector.deinterleave %d
           : vector<2x4x[6]xf64> ; yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64>
```
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 76 +++++++++++++++++++
 mlir/test/Dialect/Vector/invalid.mlir         | 56 ++++++++++++++
 mlir/test/Dialect/Vector/ops.mlir             | 42 ++++++++++
 3 files changed, 174 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 332b5ad08ced9..1e7e0a1715178 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -543,6 +543,82 @@ def Vector_InterleaveOp :
   }];
 }
 
+class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
+  "type of 'input' is double the width of results",
+  "input", result,
+  [{
+    [&]() -> ::mlir::VectorType {
+      auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
+      ::mlir::VectorType::Builder builder(vectorType);
+      auto lastDim = vectorType.getRank() - 1;
+      auto newDimSize = lastDim < 0 ? 0 : vectorType.getDimSize(lastDim) / 2;
+      if (newDimSize <= 0)
+        return vectorType; // (invalid input: 0-D or trailing dim < 2)
+      return builder.setDim(lastDim, newDimSize);
+    }()
+  }]
+>;
+
+def Vector_DeinterleaveOp :
+  Vector_Op<"deinterleave", [Pure,
+    PredOpTrait<"trailing dimension of input vector must be an even number",
+    CPred<[{
+      [&](){
+        auto srcVec = getSourceVectorType();
+        return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0;
+      }()
+    }]>>,
+    ResultIsHalfSourceVectorType<"res1">,
+    ResultIsHalfSourceVectorType<"res2">,
+    AllTypesMatch<["res1", "res2"]>
+    ]> {
+  let summary = "constructs two vectors by deinterleaving an input vector";
+  let description = [{
+    The deinterleave operation constructs two vectors from a single input
+    vector. The first result contains the even-indexed elements and the
+    second the odd-indexed elements of the input's trailing dimension. This
+    is essentially the inverse of an interleave operation.
+
+    The trailing dimension of each result is half the size of the input's
+    trailing dimension (the only dimension, for 1-D inputs); all other
+    dimensions are unchanged. 0-D inputs are not supported, and the size of
+    the input's trailing dimension must be even.
+
+    The operation supports scalable vectors.
+
+    Example:
+    ```mlir
+    %0 = vector.deinterleave %a
+           : vector<[4]xi32>     ; yields vector<[2]xi32>, vector<[2]xi32>
+    %1 = vector.deinterleave %b
+           : vector<8xi8>        ; yields vector<4xi8>, vector<4xi8>
+    %2 = vector.deinterleave %c
+           : vector<2x8xf32>     ; yields vector<2x4xf32>, vector<2x4xf32>
+    %3 = vector.deinterleave %d
+           : vector<2x4x[6]xf64> ; yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64>
+    ```
+  }];
+
+  let arguments = (ins AnyVector:$input);
+  let results = (outs AnyVector:$res1, AnyVector:$res2);
+
+  let assemblyFormat = [{
+    $input attr-dict `:` type($input)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return ::llvm::cast<VectorType>(getInput().getType());
+    }
+    VectorType getResultOneVectorType() {
+      return ::llvm::cast<VectorType>(getRes1().getType());
+    }
+    VectorType getResultTwoVectorType() {
+      return ::llvm::cast<VectorType>(getRes2().getType());
+    }
+  }];
+}
+
 def Vector_ExtractElementOp :
   Vector_Op<"extractelement", [Pure,
      TypesMatchWith<"result type matches element type of vector operand",
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c9f7e9c6e2fb0..25cacc6fdf93d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1798,3 +1798,59 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>) {
   // expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
   %op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
 }
+
+// -----
+
+func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) {
+  // expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector<f32>'}}
+  %0, %1 = vector.deinterleave %vec : vector<f32>
+  return
+}
+
+// -----
+
+func.func @deinterleave_one_dim_fail(%vec : vector<1xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that trailing dimension of input vector must be an even number}}
+  %0, %1 = vector.deinterleave %vec : vector<1xf32>
+  return
+}
+
+// -----
+
+func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  %0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<8xf32>, vector<8xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  %0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<4xf32>, vector<2xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_n_dim_size_mismatch_fail(%vec : vector<2x3x4xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  %0, %1 = "vector.deinterleave" (%vec) : (vector<2x3x4xf32>) -> (vector<2x3x4xf32>, vector<2x3x2xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<2x[1]xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
+  %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
+  return
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 79a80be4f8b20..a6a992f23a4ba 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1116,3 +1116,45 @@ func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>)
   %0 = vector.interleave %a, %b : vector<2x[2]xf64>
   return %0 : vector<2x[4]xf64>
 }
+
+// CHECK-LABEL: @deinterleave_1d
+func.func @deinterleave_1d(%arg: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<4xf32>
+  %0, %1 = vector.deinterleave %arg : vector<4xf32>
+  return %0, %1 : vector<2xf32>, vector<2xf32>
+}
+
+// CHECK-LABEL: @deinterleave_1d_scalable
+func.func @deinterleave_1d_scalable(%arg: vector<[4]xf32>) -> (vector<[2]xf32>, vector<[2]xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32>
+  %0, %1 = vector.deinterleave %arg : vector<[4]xf32>
+  return %0, %1 : vector<[2]xf32>, vector<[2]xf32>
+}
+
+// CHECK-LABEL: @deinterleave_2d
+func.func @deinterleave_2d(%arg: vector<3x4xf32>) -> (vector<3x2xf32>, vector<3x2xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32>
+  %0, %1 = vector.deinterleave %arg : vector<3x4xf32>
+  return %0, %1 : vector<3x2xf32>, vector<3x2xf32>
+}
+
+// CHECK-LABEL: @deinterleave_2d_scalable
+func.func @deinterleave_2d_scalable(%arg: vector<3x[4]xf32>) -> (vector<3x[2]xf32>, vector<3x[2]xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32>
+  %0, %1 = vector.deinterleave %arg : vector<3x[4]xf32>
+  return %0, %1 : vector<3x[2]xf32>, vector<3x[2]xf32>
+}
+
+// CHECK-LABEL: @deinterleave_nd
+func.func @deinterleave_nd(%arg: vector<2x3x4x6xf32>) -> (vector<2x3x4x3xf32>, vector<2x3x4x3xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32>
+  %0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32>
+  return %0, %1 : vector<2x3x4x3xf32>, vector<2x3x4x3xf32>
+}
+
+// CHECK-LABEL: @deinterleave_nd_scalable
+func.func @deinterleave_nd_scalable(%arg: vector<2x3x4x[6]xf32>) -> (vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32>
+  %0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32>
+  return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
+}



More information about the Mlir-commits mailing list