[Mlir-commits] [mlir] 8e2b373 - [mlir][Vector] Add some missing tests for `broadcast` and `splat`

Nicolas Vasilache llvmlistbot at llvm.org
Fri Dec 3 00:52:55 PST 2021


Author: Michal Terepeta
Date: 2021-12-03T08:52:51Z
New Revision: 8e2b3733967296a838ea9861e362fb4d322d165e

URL: https://github.com/llvm/llvm-project/commit/8e2b3733967296a838ea9861e362fb4d322d165e
DIFF: https://github.com/llvm/llvm-project/commit/8e2b3733967296a838ea9861e362fb4d322d165e.diff

LOG: [mlir][Vector] Add some missing tests for `broadcast` and `splat`

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114853

Added: 
    

Modified: 
    mlir/test/Dialect/Standard/ops.mlir
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index c3b40be816b54..64322f066b354 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -62,3 +62,10 @@ func @constant_complex_f64() -> complex<f64> {
   %result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
   return %result : complex<f64>
 }
+
+// CHECK-LABEL: func @vector_splat_0d(
+func @vector_splat_0d(%a: f32) -> vector<f32> {
+  // CHECK: splat %{{.*}} : vector<f32>
+  %0 = splat %a : vector<f32>
+  return %0 : vector<f32>
+}

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c550a0818809f..7902976cce59c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -16,6 +16,13 @@ func @broadcast_rank_too_high(%arg0: vector<4x4xf32>) {
 
 // -----
 
+func @broadcast_rank_too_high_0d(%arg0: vector<1xf32>) {
+  // expected-error@+1 {{'vector.broadcast' op source rank higher than destination rank}}
+  %1 = vector.broadcast %arg0 : vector<1xf32> to vector<f32>
+}
+
+// -----
+
 func @broadcast_dim1_mismatch(%arg0: vector<7xf32>) {
   // expected-error@+1 {{'vector.broadcast' op dimension mismatch (7 vs. 3)}}
   %1 = vector.broadcast %arg0 : vector<7xf32> to vector<3xf32>
@@ -79,7 +86,7 @@ func @extract_element(%arg0: vector<f32>) {
 }
 
 // -----
- 
+
 func @extract_element(%arg0: vector<4xf32>) {
   %c = arith.constant 3 : i32
   // expected-error@+1 {{expected position for 1-D vector}}

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 576924e1addff..11bc141556e32 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -149,16 +149,20 @@ func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
 }
 
 // CHECK-LABEL: @vector_broadcast
-func @vector_broadcast(%a: f32, %b: vector<16xf32>, %c: vector<1x16xf32>, %d: vector<8x1xf32>) -> vector<8x16xf32> {
+func @vector_broadcast(%a: f32, %b: vector<f32>, %c: vector<16xf32>, %d: vector<1x16xf32>, %e: vector<8x1xf32>) -> vector<8x16xf32> {
+  // CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
+  %0 = vector.broadcast %a : f32 to vector<f32>
+  // CHECK: vector.broadcast %{{.*}} : vector<f32> to vector<4xf32>
+  %1 = vector.broadcast %b : vector<f32> to vector<4xf32>
   // CHECK: vector.broadcast %{{.*}} : f32 to vector<16xf32>
-  %0 = vector.broadcast %a : f32 to vector<16xf32>
+  %2 = vector.broadcast %a : f32 to vector<16xf32>
   // CHECK-NEXT: vector.broadcast %{{.*}} : vector<16xf32> to vector<8x16xf32>
-  %1 = vector.broadcast %b : vector<16xf32> to vector<8x16xf32>
+  %3 = vector.broadcast %c : vector<16xf32> to vector<8x16xf32>
   // CHECK-NEXT: vector.broadcast %{{.*}} : vector<1x16xf32> to vector<8x16xf32>
-  %2 = vector.broadcast %c : vector<1x16xf32> to vector<8x16xf32>
+  %4 = vector.broadcast %d : vector<1x16xf32> to vector<8x16xf32>
   // CHECK-NEXT: vector.broadcast %{{.*}} : vector<8x1xf32> to vector<8x16xf32>
-  %3 = vector.broadcast %d : vector<8x1xf32> to vector<8x16xf32>
-  return %3 : vector<8x16xf32>
+  %5 = vector.broadcast %e : vector<8x1xf32> to vector<8x16xf32>
+  return %4 : vector<8x16xf32>
 }
 
 // CHECK-LABEL: @shuffle1D


        


More information about the Mlir-commits mailing list