[Mlir-commits] [mlir] 9915477 - [mlir] Add an additional check to vectorizeStaticLinalgOpPrecondition.
Adrian Kuegel
llvmlistbot at llvm.org
Thu Jun 23 01:24:18 PDT 2022
Author: Adrian Kuegel
Date: 2022-06-23T10:24:04+02:00
New Revision: 991547703a1909392eccc1080778117e522db9f7
URL: https://github.com/llvm/llvm-project/commit/991547703a1909392eccc1080778117e522db9f7
DIFF: https://github.com/llvm/llvm-project/commit/991547703a1909392eccc1080778117e522db9f7.diff
LOG: [mlir] Add an additional check to vectorizeStaticLinalgOpPrecondition.
We need to make sure that the types used in the body are valid element types
for VectorType.
Differential Revision: https://reviews.llvm.org/D128336
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index efbf051645610..6c3d25df12d8d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -552,6 +552,19 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
}
static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
+ // All types in the body should be a supported element type for VectorType.
+ for (Operation &innerOp : op->getRegion(0).front()) {
+ if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
+ return !VectorType::isValidElementType(type);
+ })) {
+ return failure();
+ }
+ if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
+ return !VectorType::isValidElementType(type);
+ })) {
+ return failure();
+ }
+ }
if (isElementwise(op))
return success();
// TODO: isaConvolutionOpInterface that can also infer from generic features.
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 99617d50684ed..dbd09576cb76b 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -207,6 +207,23 @@ func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
// -----
+// CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types
+func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomplex<f32>>, %arg0 : complex<f32>) {
+ // CHECK-NOT: vector.broadcast
+ // CHECK-NOT: vector.transfer_write
+ linalg.generic {
+ indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : complex<f32>)
+ outs(%A: memref<8x16xcomplex<f32>>) {
+ ^bb(%0: complex<f32>, %1: complex<f32>) :
+ linalg.yield %0 : complex<f32>
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @test_vectorize_fill
func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>
More information about the Mlir-commits mailing list