[PATCH] D75219: [mlir] Add primitive transform pattern to rewrite linalg.copy into vector.broadcast form

JOSE IGNACIO GOMEZ PEREZ via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 26 15:15:41 PST 2020


tetuante created this revision.
Herald added subscribers: llvm-commits, Joonsoo, liufengdb, lucyrfox, mgester, arpith-jacob, nicolasvasilache, antiagainst, shauheen, burmako, jpienaar, rriddle, mehdi_amini.
Herald added a reviewer: nicolasvasilache.
Herald added a project: LLVM.
tetuante retitled this revision from "[mlir] Add primitive transform pattern to rewrite linalg.fill into vector.broadcast form" to "[mlir] Add primitive transform pattern to rewrite linalg.copy into vector.broadcast form".
tetuante added a reviewer: jsetoain.
tetuante added a reviewer: asaadaldien.
tetuante removed subscribers: burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits.

This diff adds a transformation pattern to rewrite linalg.copy as
a vector.broadcast of the source vector into the destination vector.
It uses the same precondition as matmul (memory must be contiguous).


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D75219

Files:
  mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
  mlir/test/Dialect/Linalg/transform-patterns.mlir
  mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td


Index: mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
===================================================================
--- mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -111,6 +111,12 @@
                 HasLinalgTransformMarker<"VECTORIZE">,
                 PreconditionVectorizeLinalgOp
                ]>>)]>;
+def : Pattern<(CopyOp:$op $_, $_, $_, $_),
+              [(VectorizeLinalgOp)],
+              [(Constraint<And<[
+                HasLinalgTransformMarker<"VECTORIZE">,
+                PreconditionVectorizeLinalgOp
+               ]>>)]>;
 def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
               [(VectorizeLinalgOp)],
               [(Constraint<And<[
Index: mlir/test/Dialect/Linalg/transform-patterns.mlir
===================================================================
--- mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -212,6 +212,14 @@
 // CHECK-LABEL: func @test_vectorize_fill
 //       CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
 
+func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+  linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"}: memref<8x16xf32>,
+                              memref<8x16xf32>
+  return
+}
+// CHECK-LABEL: func @test_vectorize_copy
+//       CHECK: vector.broadcast {{.*}} : vector<8x16xf32> to vector<8x16xf32>
+
 func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
           %d = mulf %a, %b: f32
           %e = addf %c, %d: f32
Index: mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -168,7 +168,8 @@
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
-  if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
+  if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op) ||
+      isa<linalg::CopyOp>(op))
     return success();
 
   auto genericOp = dyn_cast<linalg::GenericOp>(op);
@@ -210,6 +211,17 @@
     auto dstVec = std_load(dstMemrefVec);
     auto resVec = vector_broadcast(dstVec, fillOp.value());
     std_store(resVec, dstMemrefVec);
+  } else if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
+    // Vectorize copy as a vector.broadcast.
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                         "]: Rewrite linalg.copy as vector.broadcast: "
+                      << *op << ":\n");
+    auto dstMemrefVec = vector_type_cast(copyOp.getOutputBuffer(0));
+    auto dstVec = std_load(dstMemrefVec);
+    auto srcMemrefVec = vector_type_cast(copyOp.getInput(0));
+    auto srcVec = std_load(srcMemrefVec);
+    auto resVec = vector_broadcast(dstVec, srcVec);
+    std_store(resVec, dstMemrefVec);
   } else {
     // Vectorize other ops as vector contraction (currently only matmul).
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D75219.246845.patch
Type: text/x-patch
Size: 3161 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200226/311cea9e/attachment.bin>


More information about the llvm-commits mailing list