[Mlir-commits] [mlir] 9915477 - [mlir] Add an additional check to vectorizeStaticLinalgOpPrecondition.

Adrian Kuegel llvmlistbot at llvm.org
Thu Jun 23 01:24:18 PDT 2022


Author: Adrian Kuegel
Date: 2022-06-23T10:24:04+02:00
New Revision: 991547703a1909392eccc1080778117e522db9f7

URL: https://github.com/llvm/llvm-project/commit/991547703a1909392eccc1080778117e522db9f7
DIFF: https://github.com/llvm/llvm-project/commit/991547703a1909392eccc1080778117e522db9f7.diff

LOG: [mlir] Add an additional check to vectorizeStaticLinalgOpPrecondition.

We need to make sure that the types used in the body are valid element types
for VectorType.

Differential Revision: https://reviews.llvm.org/D128336

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index efbf051645610..6c3d25df12d8d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -552,6 +552,19 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
+  // All types in the body should be a supported element type for VectorType.
+  for (Operation &innerOp : op->getRegion(0).front()) {
+    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
+          return !VectorType::isValidElementType(type);
+        })) {
+      return failure();
+    }
+    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
+          return !VectorType::isValidElementType(type);
+        })) {
+      return failure();
+    }
+  }
   if (isElementwise(op))
     return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 99617d50684ed..dbd09576cb76b 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -207,6 +207,23 @@ func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
 
 // -----
 
+// CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types
+func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomplex<f32>>, %arg0 : complex<f32>) {
+  // CHECK-NOT: vector.broadcast
+  // CHECK-NOT: vector.transfer_write
+  linalg.generic {
+    indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
+    iterator_types = ["parallel", "parallel"]}
+   ins(%arg0 : complex<f32>)
+  outs(%A: memref<8x16xcomplex<f32>>) {
+    ^bb(%0: complex<f32>, %1: complex<f32>) :
+      linalg.yield %0 : complex<f32>
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @test_vectorize_fill
 func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
   //       CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>


        


More information about the Mlir-commits mailing list