[Mlir-commits] [mlir] d453d73 - [mlir][vector] add proper verification to vector.print operation

Aart Bik llvmlistbot at llvm.org
Mon Feb 6 14:10:16 PST 2023


Author: Aart Bik
Date: 2023-02-06T14:10:07-08:00
New Revision: d453d73d0d04c452e33d9d96743710dbb9ce5b09

URL: https://github.com/llvm/llvm-project/commit/d453d73d0d04c452e33d9d96743710dbb9ce5b09
DIFF: https://github.com/llvm/llvm-project/commit/d453d73d0d04c452e33d9d96743710dbb9ce5b09.diff

LOG: [mlir][vector] add proper verification to vector.print operation

Rationale:
Only proper vectors and scalars of floating-point or integral types
are actually lowered to calls into the light-weight output library.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D143423

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/test/Dialect/Vector/invalid.mlir
    mlir/test/Dialect/Vector/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3dc3123ccb8e3..4d63c3682b9eb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2440,7 +2440,11 @@ def Vector_TransposeOp :
 }
 
 def Vector_PrintOp :
-  Vector_Op<"print", []>, Arguments<(ins AnyType:$source)> {
+  Vector_Op<"print", []>,
+  Arguments<(ins Type<Or<[
+    AnyVectorOfAnyRank.predicate,
+    AnyInteger.predicate, Index.predicate, AnyFloat.predicate
+  ]>>:$source)> {
   let summary = "print operation (for testing and debugging)";
   let description = [{
     Prints the source vector (or scalar) to stdout in human readable

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5132fa8368996..adb524e3b7e0d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -987,6 +987,14 @@ func.func @print_no_result(%arg0 : f32) -> i32 {
 
 // -----
 
+func.func private @print_needs_vector(%arg0: tensor<8xf32>) {
+  // expected-error @+1 {{op operand #0 must be , but got 'tensor<8xf32>'}}
+  vector.print %arg0 : tensor<8xf32>
+  return
+}
+
+// -----
+
 func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index

diff  --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 06d0903a5284c..53a836466b5ad 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -401,13 +401,20 @@ func.func @constant_vector_mask() {
   return
 }
 
-// CHECK-LABEL: @vector_print
-func.func @vector_print(%arg0: vector<8x4xf32>) {
+// CHECK-LABEL: @vector_print_on_vector
+func.func @vector_print_on_vector(%arg0: vector<8x4xf32>) {
   // CHECK: vector.print %{{.*}} : vector<8x4xf32>
   vector.print %arg0 : vector<8x4xf32>
   return
 }
 
+// CHECK-LABEL: @vector_print_on_scalar
+func.func @vector_print_on_scalar(%arg0: i64) {
+  // CHECK: vector.print %{{.*}} : i64
+  vector.print %arg0 : i64
+  return
+}
+
 // CHECK-LABEL: @reshape
 func.func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
   // CHECK:      %[[C2:.*]] = arith.constant 2 : index


        


More information about the Mlir-commits mailing list