[Mlir-commits] [mlir] 61baf2f - [mlir][Vector] Add check of supported reduction kind for ScanOp.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 19 19:42:38 PDT 2022
Author: jacquesguan
Date: 2022-04-20T02:42:19Z
New Revision: 61baf2ffa7071944c00a0642fdb9ff77d9cff0da
URL: https://github.com/llvm/llvm-project/commit/61baf2ffa7071944c00a0642fdb9ff77d9cff0da
DIFF: https://github.com/llvm/llvm-project/commit/61baf2ffa7071944c00a0642fdb9ff77d9cff0da.diff
LOG: [mlir][Vector] Add check of supported reduction kind for ScanOp.
This patch adds a check of the supported reduction kind for ScanOp to avoid using and/or/xor for floating-point types.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D123977
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 701cc2f701fea..c935349310f78 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4588,6 +4588,13 @@ LogicalResult ScanOp::verify() {
return emitOpError("incompatible input/initial value shapes");
}
+ // Verify supported reduction kind.
+ Type eltType = getDestType().getElementType();
+ if (!isSupportedCombiningKind(getKind(), eltType))
+ return emitOpError("unsupported reduction type ")
+ << eltType << " for kind '" << stringifyCombiningKind(getKind())
+ << "'";
+
return success();
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index e3e01b993df36..e9f41eede75fe 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1523,6 +1523,15 @@ func @scan_incompatible_shapes(%arg0: vector<2x3xi32>, %arg1: vector<5xi32>) ->
// -----
+func @scan_unsupported_kind(%arg0: vector<2x3xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32> {
+ // expected-error@+1 {{'vector.scan' op unsupported reduction type 'f32' for kind 'xor'}}
+ %0:2 = vector.scan <xor>, %arg0, %arg1 {inclusive = true, reduction_dim = 0} :
+ vector<2x3xf32>, vector<3xf32>
+ return %0#0 : vector<2x3xf32>
+}
+
+// -----
+
func @invalid_splat(%v : f32) {
// expected-error@+1 {{invalid kind of type specified}}
vector.splat %v : memref<8xf32>
More information about the Mlir-commits
mailing list