[Mlir-commits] [mlir] d5cbaa0 - [Linalg] Don't create complex vectors when vectorizing copies

Benjamin Kramer llvmlistbot at llvm.org
Thu Jan 19 15:36:47 PST 2023


Author: Benjamin Kramer
Date: 2023-01-20T00:34:30+01:00
New Revision: d5cbaa047004335a29dc3bcaf6aaa1c26fc27f36

URL: https://github.com/llvm/llvm-project/commit/d5cbaa047004335a29dc3bcaf6aaa1c26fc27f36
DIFF: https://github.com/llvm/llvm-project/commit/d5cbaa047004335a29dc3bcaf6aaa1c26fc27f36.diff

LOG: [Linalg] Don't create complex vectors when vectorizing copies

vector<complex<...>> is currently not a valid type, so bail out of copy and
pad vectorization instead of creating such vectors. This is a reduced version
of https://reviews.llvm.org/D141578

Differential Revision: https://reviews.llvm.org/D142131
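
For illustration (not part of this commit): without the isValidElementType
guard, vectorizing a copy between complex-typed memrefs would attempt to build
IR roughly like the sketch below, which the verifier rejects because
complex<f32> is not a valid vector element type. The shapes, SSA names and the
%pad placeholder are made up for the example.

    %c0 = arith.constant 0 : index
    // Invalid: complex<f32> is not a supported vector element type.
    %v = vector.transfer_read %A[%c0, %c0], %pad
        : memref<8x16xcomplex<f32>>, vector<8x16xcomplex<f32>>
    vector.transfer_write %v, %B[%c0, %c0]
        : vector<8x16xcomplex<f32>>, memref<8x16xcomplex<f32>>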

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2bc0c2a0d7770..d81496ed0f911 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1109,10 +1109,14 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
     return failure();
 
-  auto readType =
-      VectorType::get(srcType.getShape(), getElementTypeOrSelf(srcType));
-  auto writeType =
-      VectorType::get(dstType.getShape(), getElementTypeOrSelf(dstType));
+  auto srcElementType = getElementTypeOrSelf(srcType);
+  auto dstElementType = getElementTypeOrSelf(dstType);
+  if (!VectorType::isValidElementType(srcElementType) ||
+      !VectorType::isValidElementType(dstElementType))
+    return failure();
+
+  auto readType = VectorType::get(srcType.getShape(), srcElementType);
+  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
 
   Location loc = copyOp->getLoc();
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
@@ -1173,6 +1177,8 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
                                         tensor::PadOp padOp, Value dest) {
     auto sourceType = padOp.getSourceType();
     auto resultType = padOp.getResultType();
+    if (!VectorType::isValidElementType(sourceType.getElementType()))
+      return failure();
 
     // Copy cannot be vectorized if pad value is non-constant and source shape
     // is dynamic. In case of a dynamic source shape, padding must be appended

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d25ffe74841da..4f34eccec3a93 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -359,6 +359,23 @@ transform.sequence failures(propagate) {
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
   %2 = transform.structured.vectorize %1
 }
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy_complex
+// CHECK-NOT: vector<
+func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) {
+  memref.copy %A, %B :  memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["memref.copy"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_trailing_index
@@ -806,6 +823,25 @@ transform.sequence failures(propagate) {
   %2 = transform.structured.vectorize %1  { vectorize_padding }
 }
 
+// -----
+
+// CHECK-LABEL: func @pad_static_complex(
+//   CHECK-NOT:   vector<
+func.func @pad_static_complex(%arg0: tensor<2x5x2xcomplex<f32>>, %pad_value: complex<f32>) -> tensor<2x6x4xcomplex<f32>> {
+  %0 = tensor.pad %arg0 low[0, 0, 2] high[0, 1, 0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      tensor.yield %pad_value : complex<f32>
+    } : tensor<2x5x2xcomplex<f32>> to tensor<2x6x4xcomplex<f32>>
+  return %0 : tensor<2x6x4xcomplex<f32>>
+}
+
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1  { vectorize_padding }
+}
 
 // -----
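
For contrast (also not part of the commit), here is roughly what copy
vectorization produces when the element type is a valid vector element type
such as f32; with this patch, the complex-typed copy and pad above are simply
left untouched instead of being rewritten. Names and the exact padding
constant below are approximate.

    func.func @vectorize_copy_f32(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
      memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32>
      return
    }

    // After vectorization (approximate output):
    //   %c0 = arith.constant 0 : index
    //   %cst = arith.constant 0.000000e+00 : f32
    //   %v = vector.transfer_read %A[%c0, %c0], %cst : memref<8x16xf32>, vector<8x16xf32>
    //   vector.transfer_write %v, %B[%c0, %c0] : vector<8x16xf32>, memref<8x16xf32>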
 


        


More information about the Mlir-commits mailing list