[Mlir-commits] [mlir] a1b9fb2 - [mlir][linalg] Add vectorization transform for CopyOp
Thomas Raoux
llvmlistbot at llvm.org
Wed Jul 22 12:41:19 PDT 2020
Author: Thomas Raoux
Date: 2020-07-22T12:40:42-07:00
New Revision: a1b9fb220f6d71be3dde450c1695c92a7579af57
URL: https://github.com/llvm/llvm-project/commit/a1b9fb220f6d71be3dde450c1695c92a7579af57
DIFF: https://github.com/llvm/llvm-project/commit/a1b9fb220f6d71be3dde450c1695c92a7579af57.diff
LOG: [mlir][linalg] Add vectorization transform for CopyOp
CopyOp gets vectorized to a vector.transfer_read followed by a vector.transfer_write.
Differential Revision: https://reviews.llvm.org/D83739
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8e5da6ae539d..23d89c21e6e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -96,7 +96,7 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure();
- if (isa<linalg::FillOp>(op))
+ if (isa<linalg::FillOp, linalg::CopyOp>(op))
return success();
return isContraction(op);
@@ -119,12 +119,6 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
return;
}
- assert(succeeded(isContraction(op)) && "Expected contraction");
-
- // Vectorize other ops as vector contraction.
- // TODO: interface.
- LLVM_DEBUG(dbgs() << dbgPref
- << "Rewrite linalg op as vector.contract: " << *op);
// In the case of 0-D memrefs, return null and special case to scalar load or
// store later.
auto extractVectorTypeFromScalarView = [](Value v) {
@@ -133,6 +127,49 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
? VectorType()
: VectorType::get(mt.getShape(), mt.getElementType());
};
+
+ if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
+ // Vectorize copy as a vector.transfer_read+vector.transfer_write.
+ LLVM_DEBUG(dbgs() << dbgPref
+ << "Rewrite linalg.copy as vector.transfer_read + "
+ "vector.transfer_write: "
+ << *op);
+ Value zero = std_constant_index(0);
+ Value viewInput = copyOp.input();
+ Value viewOutput = copyOp.output();
+ Value vector;
+ if (VectorType inputType = extractVectorTypeFromScalarView(viewInput)) {
+ SmallVector<Value, 4> indicesInput(inputType.getRank(), zero);
+ if (copyOp.inputPermutation())
+ vector = vector_transfer_read(
+ extractVectorTypeFromScalarView(viewInput), viewInput, indicesInput,
+ copyOp.inputPermutation().getValue());
+ else
+ vector =
+ vector_transfer_read(extractVectorTypeFromScalarView(viewInput),
+ viewInput, indicesInput);
+ } else {
+ vector = std_load(viewInput).value;
+ }
+ if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
+ SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
+ if (copyOp.outputPermutation())
+ vector_transfer_write(vector, viewOutput, indicesOutput,
+ copyOp.outputPermutation().getValue());
+ else
+ vector_transfer_write(vector, viewOutput, indicesOutput);
+ } else {
+ std_store(vector, viewOutput);
+ }
+ return;
+ }
+
+ assert(succeeded(isContraction(op)) && "Expected contraction");
+
+ // Vectorize other ops as vector contraction.
+ // TODO: interface.
+ LLVM_DEBUG(dbgs() << dbgPref
+ << "Rewrite linalg op as vector.contract: " << *op);
auto linalgOp = cast<linalg::LinalgOp>(op);
Value viewA = linalgOp.getInput(0);
Value viewB = linalgOp.getInput(1);
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 819b3b764137..3f7d16497253 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -152,6 +152,23 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
// CHECK-LABEL: func @test_vectorize_fill
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+ linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, memref<8x16xf32>
+ return
+}
+// CHECK-LABEL: func @test_vectorize_copy
+// CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
+// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+
+func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
+ linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} : memref<f32>, memref<f32>
+ return
+}
+// CHECK-LABEL: func @test_vectorize_copy_scalar
+// CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
+// CHECK: store %[[V]], {{.*}} : memref<f32>
+
+
#matmul_accesses = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 4fb378c5ab8a..e356eb72fa42 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -144,6 +144,7 @@ static void applyPatterns(FuncOp funcOp) {
//===--------------------------------------------------------------------===//
patterns.insert<LinalgVectorizationPattern<MatmulOp>,
LinalgVectorizationPattern<FillOp>,
+ LinalgVectorizationPattern<CopyOp>,
LinalgVectorizationPattern<GenericOp>>(
ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
More information about the Mlir-commits
mailing list