[Mlir-commits] [mlir] bf4d99e - [mlir][vector] Add deinterleave operation to vector dialect (#92409)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 23 01:59:51 PDT 2024


Author: Mubashar Ahmad
Date: 2024-05-23T09:59:47+01:00
New Revision: bf4d99e16789dd711eb61b36ce92b8519f450dd5

URL: https://github.com/llvm/llvm-project/commit/bf4d99e16789dd711eb61b36ce92b8519f450dd5
DIFF: https://github.com/llvm/llvm-project/commit/bf4d99e16789dd711eb61b36ce92b8519f450dd5.diff

LOG: [mlir][vector] Add deinterleave operation to vector dialect (#92409)

The deinterleave operation constructs two vectors from a single input
vector. Along the input's trailing dimension, the first result vector
contains the elements at even indices and the second contains the elements
at odd indices. This is the inverse of a `vector.interleave` operation.

Each output's trailing dimension is half the size of the input
vector's trailing dimension. This operation requires the input vector
to have a rank > 0 and an even number of elements in its trailing
dimension.

The operation supports scalable vectors.

Example:
```mlir
%0, %1 = vector.deinterleave %a
           : vector<8xi8> -> vector<4xi8>
%2, %3 = vector.deinterleave %b
           : vector<2x8xi8> -> vector<2x4xi8>
%4, %5 = vector.deinterleave %c
           : vector<2x8x4xi8> -> vector<2x8x2xi8>
%6, %7 = vector.deinterleave %d
           : vector<[8]xf32> -> vector<[4]xf32>
%8, %9 = vector.deinterleave %e
           : vector<2x[6]xf64> -> vector<2x[3]xf64>
%10, %11 = vector.deinterleave %f
           : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
```

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 332b5ad08ced9..2bb7540ef0b0f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -543,6 +543,86 @@ def Vector_InterleaveOp :
   }];
 }
 
+class ResultIsHalfSourceVectorType<string result> : TypesMatchWith< // result type == source type with the trailing dim halved
+  "the trailing dimension of the results is half the width of source trailing dimension",
+  "source", result, // the transformer below receives the type of `source`
+  [{
+    [&]() -> ::mlir::VectorType {
+      auto vectorType = ::llvm::cast<::mlir::VectorType>($_self);
+      ::mlir::VectorType::Builder builder(vectorType); // starts from the source type; only the last dim is changed below
+      auto lastDim = vectorType.getRank() - 1;
+      auto newDimSize = vectorType.getDimSize(lastDim) / 2;
+      if (newDimSize <= 0) // trailing size too small to halve; the even-count trait reports this case
+        return vectorType; // (invalid input type)
+      return builder.setDim(lastDim, newDimSize);
+    }()
+  }]
+>;
+
+def SourceVectorEvenElementCount : PredOpTrait< // rejects sources whose trailing dimension has odd length
+  "the trailing dimension of the source vector has an even number of elements",
+  CPred<[{
+    [&]() -> bool {
+      ::mlir::VectorType srcType = getSourceVectorType();
+      return srcType.getShape().back() % 2 == 0;
+    }()
+  }]> // relies on the op providing a getSourceVectorType() accessor
+>;
+
+def Vector_DeinterleaveOp :
+  Vector_Op<"deinterleave", [Pure,
+    SourceVectorEvenElementCount,
+    ResultIsHalfSourceVectorType<"res1">,
+    AllTypesMatch<["res1", "res2"]> // lets the assembly format spell out only type($res1)
+  ]> {
+  let summary = "constructs two vectors by deinterleaving an input vector";
+  let description = [{
+    The deinterleave operation constructs two vectors from a single input
+    vector. Along the trailing dimension, the first result contains the
+    elements at even indices of the input and the second contains the elements
+    at odd indices. This is the inverse of a `vector.interleave` operation.
+
+    Each output's trailing dimension is half the size of the input vector's
+    trailing dimension. This operation requires the input vector to have a
+    rank > 0 and an even number of elements in its trailing dimension. The
+    leading dimensions of the results match those of the input.
+
+    The operation supports scalable vectors.
+
+    Example:
+    ```mlir
+    %0, %1 = vector.deinterleave %a
+               : vector<8xi8> -> vector<4xi8>
+    %2, %3 = vector.deinterleave %b
+               : vector<2x8xi8> -> vector<2x4xi8>
+    %4, %5 = vector.deinterleave %c
+               : vector<2x8x4xi8> -> vector<2x8x2xi8>
+    %6, %7 = vector.deinterleave %d
+               : vector<[8]xf32> -> vector<[4]xf32>
+    %8, %9 = vector.deinterleave %e
+               : vector<2x[6]xf64> -> vector<2x[3]xf64>
+    %10, %11 = vector.deinterleave %f
+               : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
+    ```
+  }];
+
+  let arguments = (ins AnyVector:$source); // AnyVector rejects 0-D types such as vector<f32>
+  let results = (outs AnyVector:$res1, AnyVector:$res2);
+
+  let assemblyFormat = [{
+    $source attr-dict `:` type($source) `->` type($res1)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return ::llvm::cast<VectorType>(getSource().getType());
+    }
+    VectorType getResultVectorType() {
+      return ::llvm::cast<VectorType>(getRes1().getType()); // res2 has the same type
+    }
+  }];
+}
+
 def Vector_ExtractElementOp :
   Vector_Op<"extractelement", [Pure,
      TypesMatchWith<"result type matches element type of vector operand",

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c9f7e9c6e2fb0..1516f51fe1458 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1798,3 +1798,59 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>) {
   // expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
   %op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
 }
+
+// -----
+
+func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) { // 0-D source: rejected by the AnyVector operand constraint
+  // expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector<f32>'}}
+  %0, %1 = vector.deinterleave %vec : vector<f32> -> vector<f32>
+  return
+}
+
+// -----
+
+func.func @deinterleave_one_dim_fail(%src : vector<1xf32>) { // odd trailing size (1) cannot be split into two halves
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the source vector has an even number of elements}}
+  %evens, %odds = vector.deinterleave %src : vector<1xf32> -> vector<1xf32>
+  return
+}
+
+// -----
+
+func.func @deinterleave_oversized_output_fail(%src : vector<4xf32>) { // results wider than the source
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
+  %evens, %odds = "vector.deinterleave" (%src) : (vector<4xf32>) -> (vector<8xf32>, vector<8xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_output_dim_size_mismatch(%src : vector<4xf32>) { // res1 keeps the full width instead of halving it
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
+  %evens, %odds = "vector.deinterleave" (%src) : (vector<4xf32>) -> (vector<4xf32>, vector<2xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_n_dim_rank_fail(%src : vector<2x3x4xf32>) { // ranks agree; res1's trailing dim (4) is not half of 4
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
+  %evens, %odds = "vector.deinterleave" (%src) : (vector<2x3x4xf32>) -> (vector<2x3x4xf32>, vector<2x3x2xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_scalable_dim_size_fail(%src : vector<2x[4]xf32>) { // res2's scalable trailing dim differs from res1's
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
+  %evens, %odds = "vector.deinterleave" (%src) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<2x[1]xf32>)
+  return
+}
+
+// -----
+
+func.func @deinterleave_scalable_rank_fail(%src : vector<2x[4]xf32>) { // res2 drops the leading dim, so its type differs from res1's
+  // expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
+  %evens, %odds = "vector.deinterleave" (%src) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
+  return
+}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 79a80be4f8b20..9d8101d3eee97 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1116,3 +1116,45 @@ func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>)
   %0 = vector.interleave %a, %b : vector<2x[2]xf64>
   return %0 : vector<2x[4]xf64>
 }
+
+// CHECK-LABEL: @deinterleave_1d
+func.func @deinterleave_1d(%src: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<4xf32> -> vector<2xf32>
+  %evens, %odds = vector.deinterleave %src : vector<4xf32> -> vector<2xf32>
+  return %evens, %odds : vector<2xf32>, vector<2xf32> // fixed-size 1-D source
+}
+
+// CHECK-LABEL: @deinterleave_1d_scalable
+func.func @deinterleave_1d_scalable(%src: vector<[4]xf32>) -> (vector<[2]xf32>, vector<[2]xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32> -> vector<[2]xf32>
+  %evens, %odds = vector.deinterleave %src : vector<[4]xf32> -> vector<[2]xf32>
+  return %evens, %odds : vector<[2]xf32>, vector<[2]xf32> // scalable 1-D source
+}
+
+// CHECK-LABEL: @deinterleave_2d
+func.func @deinterleave_2d(%src: vector<3x4xf32>) -> (vector<3x2xf32>, vector<3x2xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32> -> vector<3x2xf32>
+  %evens, %odds = vector.deinterleave %src : vector<3x4xf32> -> vector<3x2xf32>
+  return %evens, %odds : vector<3x2xf32>, vector<3x2xf32> // leading dim is preserved
+}
+
+// CHECK-LABEL: @deinterleave_2d_scalable
+func.func @deinterleave_2d_scalable(%src: vector<3x[4]xf32>) -> (vector<3x[2]xf32>, vector<3x[2]xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32> -> vector<3x[2]xf32>
+  %evens, %odds = vector.deinterleave %src : vector<3x[4]xf32> -> vector<3x[2]xf32>
+  return %evens, %odds : vector<3x[2]xf32>, vector<3x[2]xf32> // scalable trailing dim stays scalable
+}
+
+// CHECK-LABEL: @deinterleave_nd
+func.func @deinterleave_nd(%src: vector<2x3x4x6xf32>) -> (vector<2x3x4x3xf32>, vector<2x3x4x3xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
+  %evens, %odds = vector.deinterleave %src : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
+  return %evens, %odds : vector<2x3x4x3xf32>, vector<2x3x4x3xf32> // only the trailing dim is halved
+}
+
+// CHECK-LABEL: @deinterleave_nd_scalable
+func.func @deinterleave_nd_scalable(%src: vector<2x3x4x[6]xf32>) -> (vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>) {
+  // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
+  %evens, %odds = vector.deinterleave %src : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
+  return %evens, %odds : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32> // n-D with scalable trailing dim
+}


        


More information about the Mlir-commits mailing list