[PATCH] D75219: [mlir] Add primitive transform pattern to rewrite linalg.copy into vector.broadcast form
JOSE IGNACIO GOMEZ PEREZ via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 26 15:15:41 PST 2020
tetuante created this revision.
Herald added subscribers: llvm-commits, Joonsoo, liufengdb, lucyrfox, mgester, arpith-jacob, nicolasvasilache, antiagainst, shauheen, burmako, jpienaar, rriddle, mehdi_amini.
Herald added a reviewer: nicolasvasilache.
Herald added a project: LLVM.
tetuante retitled this revision from "[mlir] Add primitive transform pattern to rewrite linalg.fill into vector.broadcast form" to "[mlir] Add primitive transform pattern to rewrite linalg.copy into vector.broadcast form".
tetuante added a reviewer: jsetoain.
tetuante added a reviewer: asaadaldien.
tetuante removed subscribers: burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits.
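This diff adds a transformation pattern that rewrites linalg.copy as a
vector.broadcast: the source buffer is loaded as a vector and broadcast into
the destination vector type, which is then stored to the destination buffer.
It uses the same precondition as matmul (the memory must be contiguous).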
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D75219
Files:
mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
Index: mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
===================================================================
--- mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -111,6 +111,12 @@
HasLinalgTransformMarker<"VECTORIZE">,
PreconditionVectorizeLinalgOp
]>>)]>;
+def : Pattern<(CopyOp:$op $_, $_, $_, $_),
+ [(VectorizeLinalgOp)],
+ [(Constraint<And<[
+ HasLinalgTransformMarker<"VECTORIZE">,
+ PreconditionVectorizeLinalgOp
+ ]>>)]>;
def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
[(VectorizeLinalgOp)],
[(Constraint<And<[
Index: mlir/test/Dialect/Linalg/transform-patterns.mlir
===================================================================
--- mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -212,6 +212,14 @@
// CHECK-LABEL: func @test_vectorize_fill
// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
+func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+ linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"}: memref<8x16xf32>,
+ memref<8x16xf32>
+ return
+}
+// CHECK-LABEL: func @test_vectorize_copy
+// CHECK: vector.broadcast {{.*}} : vector<8x16xf32> to vector<8x16xf32>
+
func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
Index: mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -168,7 +168,8 @@
for (Type outputTensorType : linalgOp.getOutputTensorTypes())
if (!outputTensorType.cast<ShapedType>().hasStaticShape())
return failure();
- if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op))
+ if (isa<linalg::MatmulOp>(op) || isa<linalg::FillOp>(op) ||
+ isa<linalg::CopyOp>(op))
return success();
auto genericOp = dyn_cast<linalg::GenericOp>(op);
@@ -210,6 +211,17 @@
auto dstVec = std_load(dstMemrefVec);
auto resVec = vector_broadcast(dstVec, fillOp.value());
std_store(resVec, dstMemrefVec);
+ } else if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
+ // Vectorize copy as a vector.broadcast.
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+ "]: Rewrite linalg.copy as vector.broadcast: "
+ << *op << ":\n");
+ auto dstMemrefVec = vector_type_cast(copyOp.getOutputBuffer(0));
+ auto dstVec = std_load(dstMemrefVec);
+ auto srcMemrefVec = vector_type_cast(copyOp.getInput(0));
+ auto srcVec = std_load(srcMemrefVec);
+ auto resVec = vector_broadcast(dstVec, srcVec);
+ std_store(resVec, dstMemrefVec);
} else {
// Vectorize other ops as vector contraction (currently only matmul).
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D75219.246845.patch
Type: text/x-patch
Size: 3161 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200226/311cea9e/attachment.bin>
More information about the llvm-commits
mailing list