[Mlir-commits] [mlir] 6e81eae - [mlir][Vector] Support 0-D vectors in TransposeOp
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Aug 26 03:40:31 PDT 2022
Author: Nicolas Vasilache
Date: 2022-08-26T03:40:21-07:00
New Revision: 6e81eae2f767df99c46e2296a74a00e28716ccae
URL: https://github.com/llvm/llvm-project/commit/6e81eae2f767df99c46e2296a74a00e28716ccae
DIFF: https://github.com/llvm/llvm-project/commit/6e81eae2f767df99c46e2296a74a00e28716ccae.diff
LOG: [mlir][Vector] Support 0-D vectors in TransposeOp
Co-authored-by: Michal Terepeta <michalt at google.com>
Reviewed-by: ftynse
Differential Revision: https://reviews.llvm.org/D115743
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 033c29bca60f6..aa6624f07a2dc 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2229,12 +2229,13 @@ def Vector_TransposeOp :
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
PredOpTrait<"operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>,
- Results<(outs AnyVector:$result)> {
+ Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
+ Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "vector transpose operation";
let description = [{
Takes a n-D vector and returns the transposed n-D vector defined by
- the permutation of ranks in the n-sized integer array attribute.
+ the permutation of ranks in the n-sized integer array attribute (in case
+ of 0-D vectors the array attribute must be empty).
In the operation
```mlir
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 40e4022461d04..828fc22f18346 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1760,6 +1760,8 @@ func.func @create_mask_1d(%a : index) -> vector<4xi1> {
// CHECK: %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
// CHECK: return %[[result]] : vector<4xi1>
+// -----
+
func.func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> {
%v = vector.create_mask %a : vector<[4]xi1>
return %v: vector<[4]xi1>
@@ -1776,6 +1778,17 @@ func.func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> {
// -----
+func.func @transpose_0d(%arg0: vector<f32>) -> vector<f32> {
+ %0 = vector.transpose %arg0, [] : vector<f32> to vector<f32>
+ return %0 : vector<f32>
+}
+
+// CHECK-LABEL: func @transpose_0d
+// CHECK-SAME: %[[A:.*]]: vector<f32>
+// CHECK: return %[[A]] : vector<f32>
+
+// -----
+
func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
: vector<16xf32> -> vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 9a7e6f4979a39..fa2516466ad50 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1145,11 +1145,25 @@ func.func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>, %acc: vector<16xf3
// -----
+func.func @transpose_rank_mismatch_0d(%arg0: vector<f32>) {
+ // expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
+ %0 = vector.transpose %arg0, [] : vector<f32> to vector<100xf32>
+}
+
+// -----
+
func.func @transpose_rank_mismatch(%arg0: vector<4x16x11xf32>) {
// expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
%0 = vector.transpose %arg0, [2, 1, 0] : vector<4x16x11xf32> to vector<100xf32>
}
+// -----
+
+func.func @transpose_length_mismatch_0d(%arg0: vector<f32>) {
+ // expected-error@+1 {{'vector.transpose' op transposition length mismatch: 1}}
+ %0 = vector.transpose %arg0, [1] : vector<f32> to vector<f32>
+}
+
// -----
func.func @transpose_length_mismatch(%arg0: vector<4x4xf32>) {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4c3e322131715..e4e260a37bb13 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -570,6 +570,22 @@ func.func @transpose_int(%arg0: vector<11x7x3x2xi32>) -> vector<2x11x7x3xi32> {
return %0 : vector<2x11x7x3xi32>
}
+// CHECK-LABEL: @transpose_fp_0d
+func.func @transpose_fp_0d(%arg0: vector<f32>) -> vector<f32> {
+ // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [] : vector<f32> to vector<f32>
+ %0 = vector.transpose %arg0, [] : vector<f32> to vector<f32>
+ // CHECK: return %[[X]] : vector<f32>
+ return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @transpose_int_0d
+func.func @transpose_int_0d(%arg0: vector<i32>) -> vector<i32> {
+ // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [] : vector<i32> to vector<i32>
+ %0 = vector.transpose %arg0, [] : vector<i32> to vector<i32>
+ // CHECK: return %[[X]] : vector<i32>
+ return %0 : vector<i32>
+}
+
// CHECK-LABEL: @flat_transpose_fp
func.func @flat_transpose_fp(%arg0: vector<16xf32>) -> vector<16xf32> {
// CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
index a29ab10ddbd2c..8a100dcc91b35 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -120,6 +120,13 @@ func.func @fma_0d(%four: vector<f32>) {
return
}
+func.func @transpose_0d(%arg: vector<i32>) {
+ %1 = vector.transpose %arg, [] : vector<i32> to vector<i32>
+ // CHECK: ( 42 )
+ vector.print %1: vector<i32>
+ return
+}
+
func.func @entry() {
%0 = arith.constant 42.0 : f32
%1 = arith.constant dense<0.0> : vector<f32>
@@ -151,6 +158,8 @@ func.func @entry() {
%5 = arith.constant dense<4.0> : vector<f32>
call @fma_0d(%5) : (vector<f32>) -> ()
+ %6 = arith.constant dense<42> : vector<i32>
+ call @transpose_0d(%6) : (vector<i32>) -> ()
return
}
More information about the Mlir-commits
mailing list