[Mlir-commits] [mlir] 12dcb89 - [mlir] [linalg] Only promote selected buffers.

Alex Zinenko llvmlistbot at llvm.org
Tue Apr 21 02:50:19 PDT 2020


Author: Pierre Oechsel
Date: 2020-04-21T11:50:08+02:00
New Revision: 12dcb89dadf4f37f7781ce687ab06b202e1b8ba3

URL: https://github.com/llvm/llvm-project/commit/12dcb89dadf4f37f7781ce687ab06b202e1b8ba3
DIFF: https://github.com/llvm/llvm-project/commit/12dcb89dadf4f37f7781ce687ab06b202e1b8ba3.diff

LOG: [mlir] [linalg] Only promote selected buffers.

The promotion transformation promotes all input and output buffers of the transformed op. The user might want to promote only some of these buffers.

Differential Revision: https://reviews.llvm.org/D78498

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
    mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
    mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
index 7fa33e4f2982..2eaed14e8377 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
@@ -114,4 +114,9 @@ def PreconditionPromoteSubviewsLinalgOp : CPred<
   "succeeded(promoteSubviewsLinalgOpPrecondition(op))">;
 def PromoteSubviewsLinalgOp : NativeCodeCall<
   "promoteSubviewsLinalgOp($_builder, op)">;
+// Promote only the operands listed in `operands`; apply `marker` if non-empty.
+class PromoteSelectedSubviewsLinalgOp<list<int> operands, string marker=""> :
+  NativeCodeCall<"promoteSelectedSubviewsLinalgOpAndSetMarker($_builder, op, {" #
+    StrJoinInt<operands>.result # "}, \"" # marker # "\")">;
+
 #endif // LINALG_TRANSFORMS

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
index c65909ec979e..e7a8925f746b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
@@ -121,6 +121,14 @@ LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op);
 SmallVector<Value, 0> promoteSubviewsLinalgOp(PatternRewriter &rewriter,
                                               Operation *op);
 
+/// Similar to `promoteSubviewsLinalgOp` but only tries to promote the views
+/// of the operands listed in `operandIndicesToPromote`; listed operands that
+/// are not produced by a `SubViewOp` are left untouched. If `linalgMarker` is
+/// non-empty and the transformation succeeds, the `kLinalgTransformMarker`
+/// attribute of the promoted op is set to `linalgMarker`.
+SmallVector<Value, 0> promoteSelectedSubviewsLinalgOpAndSetMarker(
+    PatternRewriter &rewriter, Operation *op,
+    ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker = "");
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
index 5b3618d30a71..e96ee2780252 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -338,6 +338,24 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
   assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
          "DRR failure case must be a precondition");
 
+  LinalgOp linOp = cast<LinalgOp>(op);
+  SmallVector<int64_t, 4> toPromote;
+  int64_t nBuffers = linOp.getNumInputsAndOutputBuffers();
+  toPromote.reserve(nBuffers);
+  for (int64_t i = 0; i < nBuffers; ++i)
+    toPromote.push_back(i);
+  return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, toPromote);
+}
+
+SmallVector<Value, 0> mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker(
+    PatternRewriter &rewriter, Operation *op,
+    ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker) {
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: "
+                    << *op << ":\n");
+
+  assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
+         "DRR failure case must be a precondition");
+
   if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
     // TODO(ntv): add a level of indirection to linalg.generic.
     if (convOp.padding())
@@ -348,11 +366,16 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
   assert(linOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   SetVector<Value> subViews;
-  for (auto it : linOp.getInputsAndOutputBuffers())
-    if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
+  for (int64_t index : operandIndicesToPromote)
+    if (auto sv =
+            dyn_cast_or_null<SubViewOp>(linOp.getBuffer(index).getDefiningOp()))
       subViews.insert(sv);
+
   if (!subViews.empty()) {
-    promoteSubViewOperands(rewriter, linOp, subViews);
+    auto newOp = promoteSubViewOperands(rewriter, linOp, subViews);
+    if (!linalgMarker.empty())
+      newOp.setAttr(LinalgTransforms::kLinalgTransformMarker,
+                    rewriter.getStringAttr(linalgMarker));
     return {};
   }
   llvm_unreachable("DRR failure case must be a precondition");

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 7f76819b0849..3e8230c494c8 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -395,3 +395,53 @@ func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
 // CHECK      :         linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK      :         linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
 // CHECK      :         linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
+// Checks that the "_promote_first_view_" pattern promotes only operand 0.
+func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                                   %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+                                   %arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
+  %c2000 = constant 2000 : index
+  %c3000 = constant 3000 : index
+  %c4000 = constant 4000 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = dim %arg0, 0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %1 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %2 = dim %arg1, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
+  loop.for %arg3 = %c0 to %0 step %c2000 {
+    loop.for %arg4 = %c0 to %2 step %c3000 {
+      loop.for %arg5 = %c0 to %1 step %c4000 {
+        %3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] :
+             memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+        %4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] :
+             memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+        %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
+             memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+        linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_first_view_"} :
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>,
+                      memref<?x?xf32, offset: ?, strides: [?, ?]>
+      }
+    }
+  }
+  return
+}
+// CHECK-LABEL: func @promote_first_subview_matmul
+// CHECK:   loop.for {{.*}} = %c0 to {{.*}} step %c2000 {
+// CHECK:     loop.for {{.*}} = %c0 to {{.*}} step %c3000 {
+// CHECK:       loop.for {{.*}} = %c0 to {{.*}} step %c4000 {
+// CHECK:         %[[s0:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK:         %[[s1:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK:         %[[s2:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
+// CHECK:         %[[a0:.*]] = alloc({{%.*}}) : memref<?xi8>
+// CHECK:         %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
+// CHECK:         %[[l0:.*]] = subview %[[v0]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map:.*]]>
+// CHECK-NOT:     %[[a1:.*]] = alloc({{%.*}}) : memref<?xi8>
+// CHECK-NOT:     %[[v1:.*]] = std.view %[[a1]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
+// CHECK-NOT:     %[[l1:.*]] = subview %[[v1]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map]]>
+// CHECK-NOT:     %[[a2:.*]] = alloc({{%.*}}) : memref<?xi8>
+// CHECK-NOT:     %[[v2:.*]] = std.view %[[a2]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
+// CHECK-NOT:     %[[l2:.*]] = subview %[[v2]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map]]>
+// CHECK:         linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
+// CHECK-NOT:     linalg.copy(%[[s1]], {{%.*}}) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
+// CHECK-NOT:     linalg.copy(%[[s2]], {{%.*}}) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
+// CHECK:         linalg.matmul(%[[v0]], %[[s1]], %[[s2]]) : memref<?x?xf32>, memref<?x?xf32, #[[map]]>, memref<?x?xf32, #[[map]]>

diff  --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
index a55cdbffbdb6..8444f4cc3dc4 100644
--- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -149,4 +149,12 @@ def : Pat<(MatmulOp:$op $_, $_, $_),
               HasLinalgTransformMarker<"_promote_views_">]>>
            )]>;
 
+def : Pat<(MatmulOp:$op $_, $_, $_),  // Promote only operand 0.
+          (PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">),
+          [(Constraint<And<[
+              PreconditionPromoteSubviewsLinalgOp,
+              HasOperandsOfType<"SubViewOp">,
+              HasLinalgTransformMarker<"_promote_first_view_">]>>
+           )]>;
+
 #endif // TEST_LINALG_TRANSFORMS_PATTERNS


        


More information about the Mlir-commits mailing list