[Mlir-commits] [mlir] ecca785 - [mlir][Linalg] Side effects interface for Linalg ops

Nicolas Vasilache llvmlistbot at llvm.org
Thu Nov 5 01:03:51 PST 2020


Author: Nicolas Vasilache
Date: 2020-11-05T09:00:28Z
New Revision: ecca7852d9d75aba859a3d8d001bfb2dda1345db

URL: https://github.com/llvm/llvm-project/commit/ecca7852d9d75aba859a3d8d001bfb2dda1345db
DIFF: https://github.com/llvm/llvm-project/commit/ecca7852d9d75aba859a3d8d001bfb2dda1345db.diff

LOG: [mlir][Linalg] Side effects interface for Linalg ops

The LinalgDependenceGraph and alias analysis provide the necessary analysis for the Linalg fusion on buffers case.

However this is not enough for linalg on tensors which require proper memory effects to play nicely with DCE and other transformations.
This revision adds side effects to Linalg ops that were previously missing and has 2 consequences:
1. one example in the copy removal pass now fails since the linalg.generic op has side effects and the pass does not perform alias analysis / distinguish between reads and writes.
2. a few examples in fusion-tensor.mlir need to return the resulting tensor otherwise DCE automatically kicks in as part of greedy pattern application.

Differential Revision: https://reviews.llvm.org/D90762

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/fusion-tensor.mlir
    mlir/test/Transforms/copy-removal.mlir
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 6338dae1af08..29ce9efc2e98 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -17,6 +17,7 @@
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
 include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td"
 include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // The Linalg `NInputs` trait provides the API for ops that are known
 // to have a specified number of inputs, all passed as operands.
@@ -43,13 +44,14 @@ def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">;
 // first operands. These may be optionally followed by non-view operands
 // depending on the specific Linalg op.
 class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
-  : Op<Linalg_Dialect, mnemonic,
-       !listconcat(props, [LinalgStructuredInterface])> {
-}
+  : Op<Linalg_Dialect, mnemonic, !listconcat(props, [
+       LinalgStructuredInterface])> {}
 
 class LinalgStructured_Op<string mnemonic, list<OpTrait> props>
   : LinalgStructuredBase_Op<mnemonic,
-       !listconcat(props, [StructuredOpTraits])> {
+       !listconcat(props, [
+         StructuredOpTraits,
+         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>])> {
   code libraryCallName = [{
     std::string getLibraryCallName() {
       return generateLibraryCallName(getOperation());
@@ -480,8 +482,9 @@ class LinalgOperandOfRank<int rank>: Type<
   >>;
 
 class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
-    NamedStructuredOpTrait,
     AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    NamedStructuredOpTrait,
     SingleBlockImplicitTerminator<"YieldOp">]> {
   let arguments = (ins Variadic<AnyShaped>:$inputs,
                        Variadic<AnyMemRef>:$output_buffers,

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 121961d7393a..6a2d6f3d7ac8 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -313,6 +313,40 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
+static void getGenericEffectsImpl(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects,
+    ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) {
+  for (Value value : results) {
+    effects.emplace_back(MemoryEffects::Allocate::get(), value,
+                         SideEffects::DefaultResource::get());
+  }
+  for (Value value : inputBuffers) {
+    effects.emplace_back(MemoryEffects::Read::get(), value,
+                         SideEffects::DefaultResource::get());
+  }
+  for (Value value : outputBuffers) {
+    effects.emplace_back(MemoryEffects::Read::get(), value,
+                         SideEffects::DefaultResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), value,
+                         SideEffects::DefaultResource::get());
+  }
+}
+
+void GenericOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getGenericEffectsImpl(effects, getOperation()->getResults(),
+                        getInputBuffers(), getOutputBuffers());
+}
+
+void IndexedGenericOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getGenericEffectsImpl(effects, getOperation()->getResults(),
+                        getInputBuffers(), getOutputBuffers());
+}
+
 namespace {
 template <typename GenericOpType>
 struct BlockArgsVerifier {
@@ -1039,6 +1073,13 @@ static LogicalResult verify(linalg::YieldOp op) {
 
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
+void FillOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Write::get(), output(),
+                       SideEffects::DefaultResource::get());
+}
+
 static LogicalResult verify(FillOp op) {
   auto viewType = op.getOutputShapedType(0);
   auto fillType = op.value().getType();
@@ -1047,6 +1088,15 @@ static LogicalResult verify(FillOp op) {
   return success();
 }
 
+void CopyOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), input(),
+                       SideEffects::DefaultResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), output(),
+                       SideEffects::DefaultResource::get());
+}
+
 static LogicalResult verify(CopyOp op) {
   auto outputViewType = op.getOutputShapedType(0);
   auto inputViewType = op.getInputShapedType(0);
@@ -1093,6 +1143,17 @@ static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
   return success();
 }
 
+void ConvOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), input(),
+                       SideEffects::DefaultResource::get());
+  effects.emplace_back(MemoryEffects::Read::get(), filter(),
+                       SideEffects::DefaultResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), output(),
+                       SideEffects::DefaultResource::get());
+}
+
 static LogicalResult verify(ConvOp op) {
   auto oType = op.output().getType().cast<MemRefType>();
   auto fType = op.filter().getType().cast<MemRefType>();
@@ -1142,6 +1203,16 @@ static LogicalResult verifySingleInputPoolingOp(PoolingOp op) {
   return success();
 }
 
+#define DEFINE_POOLING_OP_GET_EFFECTS(OP_NAME)                                 \
+  void OP_NAME::getEffects(                                                    \
+      SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>      \
+          &effects) {                                                          \
+    effects.emplace_back(MemoryEffects::Read::get(), input(),                  \
+                         SideEffects::DefaultResource::get());                 \
+    effects.emplace_back(MemoryEffects::Write::get(), output(),                \
+                         SideEffects::DefaultResource::get());                 \
+  }
+
 static LogicalResult verify(PoolingMaxOp op) {
   return verifySingleInputPoolingOp(op);
 }
@@ -1152,6 +1223,10 @@ static LogicalResult verify(PoolingSumOp op) {
   return verifySingleInputPoolingOp(op);
 }
 
+DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp);
+DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp);
+DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp);
+
 namespace {
 struct EraseDeadLinalgOp;
 struct FoldTensorCastOp;
@@ -1472,7 +1547,8 @@ static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
   p.printOptionalAttrDict(op.getAttrs(),
                           /*elidedAttrs=*/{"operand_segment_sizes"});
 
-  // Printing is shared with generic ops, except for the region and attributes.
+  // Printing is shared with generic ops, except for the region and
+  // attributes.
   printCommonStructuredOpParts(p, op);
 
   // Results printing.
@@ -1586,4 +1662,5 @@ CANONICALIZERS_AND_FOLDERS(FillOp)
 CANONICALIZERS_AND_FOLDERS(GenericOp)
 CANONICALIZERS_AND_FOLDERS(IndexedGenericOp)
 
-// All named ops canonicalizers and folders are auto-generated in the .cpp.inc.
+// All named ops canonicalizers and folders are auto-generated in the
+// .cpp.inc.

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index d68aff4d270c..a3d0db64c5e4 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -334,3 +334,20 @@ func @tensor_cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf3
 
   return %1: tensor<3x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @linalg_effects(
+//  CHECK-SAME:     %[[A:[a-z0-9]*]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[B:[a-z0-9]*]]: memref<?x?xf32>
+//  CHECK-SAME:     %[[C:[a-z0-9]*]]: tensor<?x?xf32>
+func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?xf32>) {
+  // CHECK-NOT:   %{{.*}} = linalg.matmul
+  %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
+                    init(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  // CHECK-NOT:   %{{.*}} = linalg.matmul
+  linalg.matmul ins(%a, %c : tensor<?x?xf32>, tensor<?x?xf32>)
+               outs(%b : memref<?x?xf32>)
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 54d8bef9caf3..1fd71b031ab3 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -251,7 +251,7 @@ func @indexed_generic_op_zero_dim_constant_fusion
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
-                                           %arg1: tensor<?x?xi32>) {
+                                           %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
     %0 = linalg.generic {
       indexing_maps = [#map0, #map0, #map0],
       iterator_types = ["parallel", "parallel"] }
@@ -271,7 +271,7 @@ func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
       %5 = subi %4, %3 : i32
       linalg.yield %5 : i32
     } -> tensor<?x?xi32>
-  return
+  return %1 : tensor<?x?xi32>
 }
 //   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @generic_op_indexed_generic_op_fusion
@@ -294,7 +294,7 @@ func @generic_op_indexed_generic_op_fusion(%arg0: tensor<?x?xi32>,
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
-                                           %arg1: tensor<?x?xi32>) {
+                                           %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
   %0 = linalg.indexed_generic {
     indexing_maps = [#map0, #map0],
     iterator_types = ["parallel", "parallel"] }
@@ -314,7 +314,7 @@ func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
     %10 = addi %arg2, %arg3 : i32
     linalg.yield %10 : i32
   } -> tensor<?x?xi32>
-  return
+  return %1 : tensor<?x?xi32>
 }
 //   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @indexed_generic_op_generic_op_fusion
@@ -338,7 +338,7 @@ func @indexed_generic_op_generic_op_fusion(%arg0: tensor<?x?xi32>,
 // The indices of the first indexed_generic op are swapped after fusion.
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
-func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) {
+func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
     %0 = linalg.indexed_generic {
       indexing_maps = [#map0, #map0],
       iterator_types = ["parallel", "parallel"] }
@@ -361,7 +361,7 @@ func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) {
       %5 = subi %4, %3 : i32
       linalg.yield %5 : i32
     } -> tensor<?x?xi32>
-  return
+  return %1 : tensor<?x?xi32>
 }
 //   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @indexed_generic_op_fusion
@@ -420,4 +420,4 @@ func @scalar_indexed_generic_fusion
 //  CHECK-SAME:     ins(%[[ARG1]] : tensor<i32>)
 //       CHECK:     extract_element %[[ARG0]]
 //       CHECK:     linalg.yield
-//       CHECK   return %[[T0]]
\ No newline at end of file
+//       CHECK   return %[[T0]]

diff  --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir
index 1a8ea02023b8..ebd737b32aca 100644
--- a/mlir/test/Transforms/copy-removal.mlir
+++ b/mlir/test/Transforms/copy-removal.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -copy-removal -split-input-file %s
+//| FileCheck %s
 
 // All linalg copies except the linalg.copy(%1, %9) must be removed since the
 // defining operation of %1 and its DeallocOp have been defined in another block.
@@ -145,9 +146,20 @@ func @test_with_temp_usage_before_copy() -> memref<5xf32> {
 // -----
 
 // It is legal to remove the copy operation that %temp has a usage after the copy
-// operation. The allocation of %temp and the deallocation of %ret should be also
+// operation. The allocation of %temp and the deallocation of %ret could be also
 // removed.
 
+// However the following pattern is not handled by copy removal.
+//   %from = alloc()
+//   %to = alloc()
+//   copy(%from, %to)
+//   read_from(%from) + write_to(%something_else)
+//   dealloc(%from)
+//   return %to
+// In particular, linalg.generic is a memoryEffectOp between copy and dealloc.
+// Since no alias analysis is performed and no distinction is made between reads
+// and writes, the linalg.generic with effects blocks copy removal.
+
 #map0 = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: func @test_with_temp_usage_after_copy
@@ -170,10 +182,11 @@ func @test_with_temp_usage_after_copy() -> memref<5xf32> {
 }
 // CHECK-NEXT: %[[ret:.*]] = alloc()
 // CHECK-NEXT: %[[res:.*]] = alloc()
-// CHECK-NOT: %{{.*}} = alloc()
-// CHECK-NOT: linalg.copy
-// CHECK-NOT: dealloc %[[ret]]
-// CHECK: return %[[ret]]
+// CHECK-NEXT: %[[temp:.*]] = alloc()
+// CHECK-NEXT: linalg.copy(%[[ret]], %[[temp]])
+// CHECK-NEXT: linalg.generic
+//      CHECK: dealloc %[[ret]]
+//      CHECK: return %[[temp]]
 
 // -----
 

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 9183f3a85b48..528fae883d19 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -2,8 +2,9 @@
 // RUN: mlir-linalg-ods-gen %s -gen-impl=1 | FileCheck %s --check-prefix=IMPL
 
 // ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1", [
-//  ODS-NEXT:   NamedStructuredOpTrait
 //  ODS-NEXT:   AttrSizedOperandSegments
+//  ODS-NEXT:   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+//  ODS-NEXT:   NamedStructuredOpTrait
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
 // IMPL-LABEL:  ArrayAttr Test1Op::iterator_types() {
@@ -26,8 +27,9 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
 }
 
 // ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2", [
-//  ODS-NEXT:   NamedStructuredOpTrait
 //  ODS-NEXT:   AttrSizedOperandSegments
+//  ODS-NEXT:   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+//  ODS-NEXT:   NamedStructuredOpTrait
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
 // IMPL-LABEL:  ArrayAttr Test2Op::iterator_types() {
@@ -50,8 +52,9 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
 }
 
 // ODS-LABEL: def Test3Op : LinalgStructuredBase_Op<"test3", [
-//  ODS-NEXT:   NamedStructuredOpTrait
 //  ODS-NEXT:   AttrSizedOperandSegments
+//  ODS-NEXT:   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+//  ODS-NEXT:   NamedStructuredOpTrait
 //  ODS-NEXT:   SingleBlockImplicitTerminator<"YieldOp">
 //
 // IMPL-LABEL:  ArrayAttr Test3Op::iterator_types() {

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 4e7fa3ba7a34..e7e5ef8901b8 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1451,8 +1451,9 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
                         StringRef linalgOpName,
                         ComprehensionParsingState &state) {
   const char *header = R"FMT(  def {0} : LinalgStructuredBase_Op<"{1}", [
-    NamedStructuredOpTrait,
     AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    NamedStructuredOpTrait,
     SingleBlockImplicitTerminator<"YieldOp">]> {
       let arguments = (ins Variadic<AnyShaped>:$inputs,
                            Variadic<AnyMemRef>:$output_buffers,
@@ -1589,6 +1590,11 @@ void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
     LogicalResult {0}::fold(ArrayRef<Attribute>,
                             SmallVectorImpl<OpFoldResult> &) {{
       return foldMemRefCast(*this);
+    }
+    void {0}::getEffects(SmallVectorImpl<
+        SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
+      getGenericEffectsImpl(effects,
+        getOperation()->getResults(), getInputBuffers(), getOutputBuffers());
     })FMT";
   os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
 }


        


More information about the Mlir-commits mailing list