[Mlir-commits] [mlir] a1b9fb2 - [mlir][linalg] Add vectorization transform for CopyOp

Thomas Raoux llvmlistbot at llvm.org
Wed Jul 22 12:41:19 PDT 2020


Author: Thomas Raoux
Date: 2020-07-22T12:40:42-07:00
New Revision: a1b9fb220f6d71be3dde450c1695c92a7579af57

URL: https://github.com/llvm/llvm-project/commit/a1b9fb220f6d71be3dde450c1695c92a7579af57
DIFF: https://github.com/llvm/llvm-project/commit/a1b9fb220f6d71be3dde450c1695c92a7579af57.diff

LOG: [mlir][linalg] Add vectorization transform for CopyOp

CopyOp gets vectorized to vector.transfer_read followed by vector.transfer_write.

Differential Revision: https://reviews.llvm.org/D83739

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 8e5da6ae539d..23d89c21e6e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -96,7 +96,7 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
 
-  if (isa<linalg::FillOp>(op))
+  if (isa<linalg::FillOp, linalg::CopyOp>(op))
     return success();
 
   return isContraction(op);
@@ -119,12 +119,6 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
     return;
   }
 
-  assert(succeeded(isContraction(op)) && "Expected contraction");
-
-  // Vectorize other ops as vector contraction.
-  // TODO: interface.
-  LLVM_DEBUG(dbgs() << dbgPref
-                    << "Rewrite linalg op as vector.contract: " << *op);
   // In the case of 0-D memrefs, return null and special case to scalar load or
   // store later.
   auto extractVectorTypeFromScalarView = [](Value v) {
@@ -133,6 +127,49 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
                ? VectorType()
                : VectorType::get(mt.getShape(), mt.getElementType());
   };
+
+  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
+    // Vectorize copy as a vector.transfer_read+vector.transfer_write.
+    LLVM_DEBUG(dbgs() << dbgPref
+                      << "Rewrite linalg.copy as vector.transfer_read + "
+                         "vector.transfer_write: "
+                      << *op);
+    Value zero = std_constant_index(0);
+    Value viewInput = copyOp.input();
+    Value viewOutput = copyOp.output();
+    Value vector;
+    if (VectorType inputType = extractVectorTypeFromScalarView(viewInput)) {
+      SmallVector<Value, 4> indicesInput(inputType.getRank(), zero);
+      if (copyOp.inputPermutation())
+        vector = vector_transfer_read(
+            extractVectorTypeFromScalarView(viewInput), viewInput, indicesInput,
+            copyOp.inputPermutation().getValue());
+      else
+        vector =
+            vector_transfer_read(extractVectorTypeFromScalarView(viewInput),
+                                 viewInput, indicesInput);
+    } else {
+      vector = std_load(viewInput).value;
+    }
+    if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
+      SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
+      if (copyOp.outputPermutation())
+        vector_transfer_write(vector, viewOutput, indicesOutput,
+                              copyOp.outputPermutation().getValue());
+      else
+        vector_transfer_write(vector, viewOutput, indicesOutput);
+    } else {
+      std_store(vector, viewOutput);
+    }
+    return;
+  }
+
+  assert(succeeded(isContraction(op)) && "Expected contraction");
+
+  // Vectorize other ops as vector contraction.
+  // TODO: interface.
+  LLVM_DEBUG(dbgs() << dbgPref
+                    << "Rewrite linalg op as vector.contract: " << *op);
   auto linalgOp = cast<linalg::LinalgOp>(op);
   Value viewA = linalgOp.getInput(0);
   Value viewB = linalgOp.getInput(1);

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 819b3b764137..3f7d16497253 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -152,6 +152,23 @@ func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
 // CHECK-LABEL: func @test_vectorize_fill
 //       CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
 
+func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+  linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<8x16xf32>, memref<8x16xf32>
+  return
+}
+// CHECK-LABEL: func @test_vectorize_copy
+//       CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
+//       CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+
+func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
+  linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} :  memref<f32>, memref<f32>
+  return
+}
+// CHECK-LABEL: func @test_vectorize_copy_scalar
+//       CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
+//       CHECK: store %[[V]], {{.*}} : memref<f32>
+
+
 #matmul_accesses = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 4fb378c5ab8a..e356eb72fa42 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -144,6 +144,7 @@ static void applyPatterns(FuncOp funcOp) {
   //===--------------------------------------------------------------------===//
   patterns.insert<LinalgVectorizationPattern<MatmulOp>,
                   LinalgVectorizationPattern<FillOp>,
+                  LinalgVectorizationPattern<CopyOp>,
                   LinalgVectorizationPattern<GenericOp>>(
       ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
 


        


More information about the Mlir-commits mailing list