[Mlir-commits] [mlir] d453d73 - [mlir][vector] add proper verification to vector.print operation
Aart Bik
llvmlistbot at llvm.org
Mon Feb 6 14:10:16 PST 2023
Author: Aart Bik
Date: 2023-02-06T14:10:07-08:00
New Revision: d453d73d0d04c452e33d9d96743710dbb9ce5b09
URL: https://github.com/llvm/llvm-project/commit/d453d73d0d04c452e33d9d96743710dbb9ce5b09
DIFF: https://github.com/llvm/llvm-project/commit/d453d73d0d04c452e33d9d96743710dbb9ce5b09.diff
LOG: [mlir][vector] add proper verification to vector.print operation
Rationale:
Only actual vector types and scalars of integer, index, or floating-point type
are actually lowered to calls into the lightweight output library.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D143423
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3dc3123ccb8e3..4d63c3682b9eb 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2440,7 +2440,11 @@ def Vector_TransposeOp :
}
def Vector_PrintOp :
- Vector_Op<"print", []>, Arguments<(ins AnyType:$source)> {
+ Vector_Op<"print", []>,
+ Arguments<(ins Type<Or<[
+ AnyVectorOfAnyRank.predicate,
+ AnyInteger.predicate, Index.predicate, AnyFloat.predicate
+ ]>>:$source)> {
let summary = "print operation (for testing and debugging)";
let description = [{
Prints the source vector (or scalar) to stdout in human readable
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5132fa8368996..adb524e3b7e0d 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -987,6 +987,14 @@ func.func @print_no_result(%arg0 : f32) -> i32 {
// -----
+func.func private @print_needs_vector(%arg0: tensor<8xf32>) {  // a tensor is neither a vector nor an integer/index/float scalar, so verification must reject it
+ // expected-error @+1 {{op operand #0 must be , but got 'tensor<8xf32>'}}
+ vector.print %arg0 : tensor<8xf32>
+ return  // NOTE(review): the empty text after "must be" presumably comes from the operand constraint being an unnamed Type<Or<[...]>> with no description; confirm
+}
+
+// -----
+
func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 06d0903a5284c..53a836466b5ad 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -401,13 +401,20 @@ func.func @constant_vector_mask() {
return
}
-// CHECK-LABEL: @vector_print
-func.func @vector_print(%arg0: vector<8x4xf32>) {
+// CHECK-LABEL: @vector_print_on_vector
+func.func @vector_print_on_vector(%arg0: vector<8x4xf32>) {  // round-trip check: a vector operand satisfies vector.print's operand constraint
// CHECK: vector.print %{{.*}} : vector<8x4xf32>
vector.print %arg0 : vector<8x4xf32>
return
}
+// CHECK-LABEL: @vector_print_on_scalar
+func.func @vector_print_on_scalar(%arg0: i64) {  // round-trip check: an integer scalar operand is also accepted by vector.print
+ // CHECK: vector.print %{{.*}} : i64
+ vector.print %arg0 : i64
+ return
+}
+
// CHECK-LABEL: @reshape
func.func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
// CHECK: %[[C2:.*]] = arith.constant 2 : index
More information about the Mlir-commits
mailing list