[Mlir-commits] [mlir] 12dcb89 - [mlir] [linalg] Only promote selected buffers.
Alex Zinenko
llvmlistbot at llvm.org
Tue Apr 21 02:50:19 PDT 2020
Author: Pierre Oechsel
Date: 2020-04-21T11:50:08+02:00
New Revision: 12dcb89dadf4f37f7781ce687ab06b202e1b8ba3
URL: https://github.com/llvm/llvm-project/commit/12dcb89dadf4f37f7781ce687ab06b202e1b8ba3
DIFF: https://github.com/llvm/llvm-project/commit/12dcb89dadf4f37f7781ce687ab06b202e1b8ba3.diff
LOG: [mlir] [linalg] Only promote selected buffers.
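The promotion transformation promotes all input and output buffers of the transformed op, but the user may want to promote only some of them. This adds `promoteSelectedSubviewsLinalgOpAndSetMarker` (and the `PromoteSelectedSubviewsLinalgOp` DRR helper), which promotes only the buffers at the given operand indices and optionally sets a transform marker on the resulting op.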
Differential Revision: https://reviews.llvm.org/D78498
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
mlir/test/Dialect/Linalg/transform-patterns.mlir
mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
index 7fa33e4f2982..2eaed14e8377 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
@@ -114,4 +114,13 @@ def PreconditionPromoteSubviewsLinalgOp : CPred<
 "succeeded(promoteSubviewsLinalgOpPrecondition(op))">;
 def PromoteSubviewsLinalgOp : NativeCodeCall<
 "promoteSubviewsLinalgOp($_builder, op)">;
+
+// Promotes only the subviews feeding the buffers at the positions listed in
+// `operands` and, if `marker` is non-empty, sets it as the linalg transform
+// marker on the newly created op. `marker` is pasted verbatim into a C++
+// string literal, so it must not contain `"` or `\`.
+class PromoteSelectedSubviewsLinalgOp<list<int> operands, string marker=""> :
+ NativeCodeCall<"promoteSelectedSubviewsLinalgOpAndSetMarker($_builder, op, {" #
+ StrJoinInt<operands>.result # "}, \"" # marker # "\")">;
+
 #endif // LINALG_TRANSFORMS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
index c65909ec979e..e7a8925f746b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
@@ -121,6 +121,14 @@ LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op);
 SmallVector<Value, 0> promoteSubviewsLinalgOp(PatternRewriter &rewriter,
 Operation *op);
+/// Similar to `promoteSubviewsLinalgOp` but only tries to promote the views
+/// feeding the input/output buffers whose positions are listed in
+/// `operandIndicesToPromote`.
+/// If `linalgMarker` is non-empty and the transformation succeeds, sets the
+/// attribute `kLinalgTransformMarker` to `linalgMarker` on the new op.
+SmallVector<Value, 0> promoteSelectedSubviewsLinalgOpAndSetMarker(
+ PatternRewriter &rewriter, Operation *op,
+ ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker = "");
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index 5b3618d30a71..e96ee2780252 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -338,6 +338,24 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
 assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
 "DRR failure case must be a precondition");
+ auto linOp = cast<LinalgOp>(op);
+ int64_t numBuffers = linOp.getNumInputsAndOutputBuffers();
+ // Select every input and output buffer, i.e. positions [0, numBuffers).
+ SmallVector<int64_t, 4> allIndices(numBuffers);
+ for (int64_t idx = 0; idx != numBuffers; ++idx)
+ allIndices[idx] = idx;
+ return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, allIndices);
+}
+
+SmallVector<Value, 0> mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker(
+ PatternRewriter &rewriter, Operation *op,
+ ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker) {
+ LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: "
+ << *op << ":\n");
+
+ assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
+ "DRR failure case must be a precondition");
+
 if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
 // TODO(ntv): add a level of indirection to linalg.generic.
 if (convOp.padding())
@@ -348,11 +366,26 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
 assert(linOp.hasBufferSemantics() &&
 "expected linalg op with buffer semantics");
 SetVector<Value> subViews;
- for (auto it : linOp.getInputsAndOutputBuffers())
- if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
+ for (int64_t index : operandIndicesToPromote) {
+ assert(index >= 0 &&
+ index < static_cast<int64_t>(linOp.getNumInputsAndOutputBuffers()) &&
+ "operand index to promote is out of range");
+ if (auto sv =
+ dyn_cast_or_null<SubViewOp>(linOp.getBuffer(index).getDefiningOp()))
 subViews.insert(sv);
+ }
+
+ // `promoteSubviewsLinalgOpPrecondition(op)` never sees
+ // `operandIndicesToPromote`, so it cannot guarantee that a selected buffer
+ // comes from a SubViewOp; diagnose that case precisely here.
+ assert(!subViews.empty() &&
+ "none of the selected buffers is defined by a SubViewOp");
+
 if (!subViews.empty()) {
- promoteSubViewOperands(rewriter, linOp, subViews);
+ auto newOp = promoteSubViewOperands(rewriter, linOp, subViews);
+ if (!linalgMarker.empty())
+ newOp.setAttr(LinalgTransforms::kLinalgTransformMarker,
+ rewriter.getStringAttr(linalgMarker));
 return {};
 }
 llvm_unreachable("DRR failure case must be a precondition");
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 7f76819b0849..3e8230c494c8 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -395,3 +395,49 @@ func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK : linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK : linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK : linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+
+// Only operand 0 is selected by the "_promote_first_view_" pattern.
+func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+ %arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
+ %c2000 = constant 2000 : index
+ %c3000 = constant 3000 : index
+ %c4000 = constant 4000 : index
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = dim %arg0, 0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+ %1 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+ %2 = dim %arg1, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+ loop.for %arg3 = %c0 to %0 step %c2000 {
+ loop.for %arg4 = %c0 to %2 step %c3000 {
+ loop.for %arg5 = %c0 to %1 step %c4000 {
+ %3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] :
+ memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ %4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] :
+ memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
+ memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_first_view_"} :
+ memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>,
+ memref<?x?xf32, offset: ?, strides: [?, ?]>
+ }
+ }
+ }
+ return
+}
+// CHECK-LABEL: func @promote_first_subview_matmul
+// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c2000 {
+// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c3000 {
+// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c4000 {
+// CHECK: %[[s0:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK: %[[s1:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK: %[[s2:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK: %[[a0:.*]] = alloc({{%.*}}) : memref<?xi8>
+// CHECK: %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
+// CHECK: %[[l0:.*]] = subview %[[v0]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map:.*]]>
+// CHECK-NOT: alloc
+// CHECK-NOT: std.view
+// CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
+// CHECK-NOT: linalg.copy
+// CHECK: linalg.matmul(%[[v0]], %[[s1]], %[[s2]]) : memref<?x?xf32>, memref<?x?xf32, #[[map]]>, memref<?x?xf32, #[[map]]>
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
index a55cdbffbdb6..8444f4cc3dc4 100644
--- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -149,4 +149,12 @@ def : Pat<(MatmulOp:$op $_, $_, $_),
 HasLinalgTransformMarker<"_promote_views_">]>>
 )]>;
+// Promote only operand 0 of matmuls marked "_promote_first_view_".
+def : Pat<(MatmulOp:$op $_, $_, $_),
+ (PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">),
+ [(Constraint<And<[
+ PreconditionPromoteSubviewsLinalgOp,
+ HasOperandsOfType<"SubViewOp">,
+ HasLinalgTransformMarker<"_promote_first_view_">]>>)]>;
+
 #endif // TEST_LINALG_TRANSFORMS_PATTERNS
More information about the Mlir-commits
mailing list