[Mlir-commits] [mlir] 8e2b373 - [mlir][Vector] Add some missing tests for `broadcast` and `splat`
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Dec 3 00:52:55 PST 2021
Author: Michal Terepeta
Date: 2021-12-03T08:52:51Z
New Revision: 8e2b3733967296a838ea9861e362fb4d322d165e
URL: https://github.com/llvm/llvm-project/commit/8e2b3733967296a838ea9861e362fb4d322d165e
DIFF: https://github.com/llvm/llvm-project/commit/8e2b3733967296a838ea9861e362fb4d322d165e.diff
LOG: [mlir][Vector] Add some missing tests for `broadcast` and `splat`
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D114853
Added:
Modified:
mlir/test/Dialect/Standard/ops.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index c3b40be816b54..64322f066b354 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -62,3 +62,10 @@ func @constant_complex_f64() -> complex<f64> {
%result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
return %result : complex<f64>
}
+
+// CHECK-LABEL: func @vector_splat_0d(
+func @vector_splat_0d(%a: f32) -> vector<f32> {
+ // CHECK: splat %{{.*}} : vector<f32>
+ %0 = splat %a : vector<f32>
+ return %0 : vector<f32>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c550a0818809f..7902976cce59c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -16,6 +16,13 @@ func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) {
// -----
+func @broadcast_rank_too_high_0d(%arg0: vector<1xf32>) {
+ // expected-error @+1 {{'vector.broadcast' op source rank higher than destination rank}}
+ %1 = vector.broadcast %arg0 : vector<1xf32> to vector<f32>
+}
+
+// -----
+
func @broadcast_dim1_mismatch(%arg0: vector<7xf32>) {
// expected-error @+1 {{'vector.broadcast' op dimension mismatch (7 vs. 3)}}
%1 = vector.broadcast %arg0 : vector<7xf32> to vector<3xf32>
@@ -79,7 +86,7 @@ func @extract_element(%arg0: vector<f32>) {
}
// -----
-
+
func @extract_element(%arg0: vector<4xf32>) {
%c = arith.constant 3 : i32
// expected-error @+1 {{expected position for 1-D vector}}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 576924e1addff..11bc141556e32 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -149,16 +149,20 @@ func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
}
// CHECK-LABEL: @vector_broadcast
-func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
+func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>) -> vector<8x16xf32> {
+ // CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
+ %0 = vector.broadcast %a : f32 to vector<f32>
+ // CHECK: vector.broadcast %{{.*}} : vector<f32> to vector<4xf32>
+ %1 = vector.broadcast %b : vector<f32> to vector<4xf32>
// CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
- %0 = vector.broadcast %a : f32 to vector<16xf32>
+ %2 = vector.broadcast %a : f32 to vector<16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
- %1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
+ %3 = vector.broadcast %c : vector<16xf32> to vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
- %2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32>
+ %4 = vector.broadcast %d : vector<1x16xf32> to vector<8x16xf32>
// CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
- %3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32>
- return %3 : vector<8x16xf32>
+ %5 = vector.broadcast %e : vector<8x1xf32> to vector<8x16xf32>
+ return %4 : vector<8x16xf32>
}
// CHECK-LABEL: @shuffle1D
More information about the Mlir-commits mailing list