[Mlir-commits] [mlir] d5cbaa0 - [Linalg] Don't create complex vectors when vectorizing copies
Benjamin Kramer
llvmlistbot at llvm.org
Thu Jan 19 15:36:47 PST 2023
Author: Benjamin Kramer
Date: 2023-01-20T00:34:30+01:00
New Revision: d5cbaa047004335a29dc3bcaf6aaa1c26fc27f36
URL: https://github.com/llvm/llvm-project/commit/d5cbaa047004335a29dc3bcaf6aaa1c26fc27f36
DIFF: https://github.com/llvm/llvm-project/commit/d5cbaa047004335a29dc3bcaf6aaa1c26fc27f36.diff
LOG: [Linalg] Don't create complex vectors when vectorizing copies
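vector<complex<...>> is currently not a valid vector type, so vectorizing
a copy (or pad) of complex elements now bails out instead of trying to
build one. This is a reduced version of https://reviews.llvm.org/D141578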
Differential Revision: https://reviews.llvm.org/D142131
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2bc0c2a0d7770..d81496ed0f911 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1109,10 +1109,14 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
- auto readType =
- VectorType::get(srcType.getShape(), getElementTypeOrSelf(srcType));
- auto writeType =
- VectorType::get(dstType.getShape(), getElementTypeOrSelf(dstType));
+ auto srcElementType = getElementTypeOrSelf(srcType);
+ auto dstElementType = getElementTypeOrSelf(dstType);
+ if (!VectorType::isValidElementType(srcElementType) ||
+ !VectorType::isValidElementType(dstElementType))
+ return failure();
+
+ auto readType = VectorType::get(srcType.getShape(), srcElementType);
+ auto writeType = VectorType::get(dstType.getShape(), dstElementType);
Location loc = copyOp->getLoc();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
@@ -1173,6 +1177,8 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
tensor::PadOp padOp, Value dest) {
auto sourceType = padOp.getSourceType();
auto resultType = padOp.getResultType();
+ if (!VectorType::isValidElementType(sourceType.getElementType()))
+ return failure();
// Copy cannot be vectorized if pad value is non-constant and source shape
// is dynamic. In case of a dynamic source shape, padding must be appended
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index d25ffe74841da..4f34eccec3a93 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -359,6 +359,23 @@ transform.sequence failures(propagate) {
%1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
%2 = transform.structured.vectorize %1
}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy_complex
+// CHECK-NOT: vector<
+func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) {
+ memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>>
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["memref.copy"]} in %arg1
+ %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+ %2 = transform.structured.vectorize %1
+}
+
// -----
// CHECK-LABEL: func @test_vectorize_trailing_index
@@ -806,6 +823,25 @@ transform.sequence failures(propagate) {
%2 = transform.structured.vectorize %1 { vectorize_padding }
}
+// -----
+
+// CHECK-LABEL: func @pad_static_complex(
+// CHECK-NOT: vector<
+func.func @pad_static_complex(%arg0: tensor<2x5x2xcomplex<f32>>, %pad_value: complex<f32>) -> tensor<2x6x4xcomplex<f32>> {
+ %0 = tensor.pad %arg0 low[0, 0, 2] high[0, 1, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index):
+ tensor.yield %pad_value : complex<f32>
+ } : tensor<2x5x2xcomplex<f32>> to tensor<2x6x4xcomplex<f32>>
+ return %0 : tensor<2x6x4xcomplex<f32>>
+}
+
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+ %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+ %2 = transform.structured.vectorize %1 { vectorize_padding }
+}
// -----
More information about the Mlir-commits
mailing list