[Mlir-commits] [mlir] 66e2708 - [mlir:Linalg] Populate LinalgOp patterns on LinalgDialect as opposed to each op

River Riddle llvmlistbot at llvm.org
Mon Jun 14 11:20:55 PDT 2021


Author: River Riddle
Date: 2021-06-14T11:20:15-07:00
New Revision: 66e27082054bf8f106ce63d5cb7d6c67e628feff

URL: https://github.com/llvm/llvm-project/commit/66e27082054bf8f106ce63d5cb7d6c67e628feff
DIFF: https://github.com/llvm/llvm-project/commit/66e27082054bf8f106ce63d5cb7d6c67e628feff.diff

LOG: [mlir:Linalg] Populate LinalgOp patterns on LinalgDialect as opposed to each op

Interface patterns are unique in that they get added to every operation that also implements that interface, given that they aren't tied to individual operations. When the same interface pattern gets added to multiple operations (such as the current behavior with Linalg), a reference to each of these patterns is added to every op (meaning that an operation will now have N references to effectively the same pattern). This revision fixes this problematic behavior in Linalg, and can bring upwards of a 25% reduction in compile time in Linalg-based workloads.

Differential Revision: https://reviews.llvm.org/D104160

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index bc9acee6fd6e7..07a378b39d066 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -35,6 +35,7 @@ def Linalg_Dialect : Dialect {
   let dependentDialects = [
     "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
   ];
+  let hasCanonicalizer = 1;
   let hasOperationAttrVerify = 1;
   let extraClassDeclaration = [{
     /// Attribute name used to to memoize indexing maps for named ops.

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 2b70572ac1e1c..54dda4c4fcb2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -178,7 +178,6 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
   }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
   let skipDefaultBuilders = 1;
 }
 
@@ -230,7 +229,6 @@ def FillOp : LinalgStructured_Op<"fill", []> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 /// A base class for pooling operation such as conv. The arguments must contain
@@ -427,7 +425,6 @@ def ConvOp : PoolingBase_Op<"conv", []> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 // Only support buffer semantics.
@@ -490,7 +487,6 @@ class SingleInputPoolingBase_Op<string mnemonic>
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> {
@@ -673,7 +669,6 @@ def GenericOp : GenericOpBase<"generic"> {
   let verifier = [{ return ::verify(*this); }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 /// GenericOp with Indexing (i.e. multi-for style in which the region is passed

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6eef6b0b48efc..8114a22ec50e6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2787,11 +2787,6 @@ DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp)
 DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp)
 DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp)
 
-namespace {
-struct EraseDeadLinalgOp;
-struct FoldTensorCastOp;
-} // namespace
-
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc"
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
 
@@ -3374,25 +3369,29 @@ struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
 };
 } // namespace
 
-#define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
-  void XXX::getCanonicalizationPatterns(RewritePatternSet &results,            \
-                                        MLIRContext *context) {                \
-    results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,        \
-                RemoveIdentityLinalgOps>(context);                             \
-  }                                                                            \
-                                                                               \
+#define LINALGOP_FOLDERS(XXX)                                                  \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
                           SmallVectorImpl<OpFoldResult> &) {                   \
     return foldMemRefCast(*this);                                              \
   }
 
-CANONICALIZERS_AND_FOLDERS(ConvOp)
-CANONICALIZERS_AND_FOLDERS(PoolingMaxOp)
-CANONICALIZERS_AND_FOLDERS(PoolingMinOp)
-CANONICALIZERS_AND_FOLDERS(PoolingSumOp)
-CANONICALIZERS_AND_FOLDERS(CopyOp)
-CANONICALIZERS_AND_FOLDERS(FillOp)
-CANONICALIZERS_AND_FOLDERS(GenericOp)
+LINALGOP_FOLDERS(ConvOp)
+LINALGOP_FOLDERS(PoolingMaxOp)
+LINALGOP_FOLDERS(PoolingMinOp)
+LINALGOP_FOLDERS(PoolingSumOp)
+LINALGOP_FOLDERS(CopyOp)
+LINALGOP_FOLDERS(FillOp)
+LINALGOP_FOLDERS(GenericOp)
 
 // All named ops canonicalizers and folders are auto-generated in the
 // .cpp.inc.
+
+//===----------------------------------------------------------------------===//
+// LinalgDialect
+//===----------------------------------------------------------------------===//
+
+void LinalgDialect::getCanonicalizationPatterns(
+    RewritePatternSet &results) const {
+  results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,
+              RemoveIdentityLinalgOps>(getContext());
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index f65a0fa1772d4..5e76361324a08 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1405,6 +1405,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
   TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
+  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
+      patterns);
 }
 
 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index b46ac20ec6b82..ab80be520c9c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -414,6 +414,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
+  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
   CanonicalizationPatternList<
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 9e86bf3227c1a..faa2835d589e7 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1959,7 +1959,6 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
       }];
       let hasFolder = 1;
-      let hasCanonicalizer = 1;
 
       let extraClassDeclaration = structuredOpsBaseDecls # [{{
         // Auto-generated.
@@ -2094,13 +2093,7 @@ void TCParser::printReferenceIterators(llvm::raw_ostream &os,
 
 void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
                                              StringRef cppOpName) {
-  const char *canonicalizersAndFoldersFmt = R"FMT(
-    void {0}::getCanonicalizationPatterns(
-        RewritePatternSet &results,
-        MLIRContext *context) {{
-      results.add<EraseDeadLinalgOp>(context);
-      results.add<FoldTensorCastOp>(context);
-    }
+  const char *foldersFmt = R"FMT(
     LogicalResult {0}::fold(ArrayRef<Attribute>,
                             SmallVectorImpl<OpFoldResult> &) {{
       return foldMemRefCast(*this);
@@ -2112,7 +2105,7 @@ void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
       getGenericEffectsImpl(effects,
         getOperation()->getResults(), inputBuffers, outputBuffers);
     })FMT";
-  os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
+  os << llvm::formatv(foldersFmt, cppOpName);
 }
 
 // Prints methods for querying whether the current named op has attributes that

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 312724c94b680..c907caef649de 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -503,7 +503,6 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([
       return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
     }];
     let hasFolder = 1;
-    let hasCanonicalizer = 1;
 
     let extraClassDeclaration = structuredOpsBaseDecls # [{{
       // Auto-generated.
@@ -535,16 +534,10 @@ ArrayAttr {0}::iterator_types() {
 }
 )FMT";
 
-// Implementations of getCanonicalizationPatterns, fold and getEffects.
+// Implementations of fold and getEffects.
 // Parameters:
 // {0}: Class name
-const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT(
-void {0}::getCanonicalizationPatterns(
-    RewritePatternSet &results,
-    MLIRContext *context) {{
-  results.add<EraseDeadLinalgOp>(context);
-  results.add<FoldTensorCastOp>(context);
-}
+const char structuredOpFoldersFormat[] = R"FMT(
 LogicalResult {0}::fold(ArrayRef<Attribute>,
                         SmallVectorImpl<OpFoldResult> &) {{
   return foldMemRefCast(*this);
@@ -880,7 +873,7 @@ void {0}::regionBuilder(
   }
 
   // Canonicalizers and folders.
-  os << llvm::formatv(structuredOpCanonicalizersAndFoldersFormat, className);
+  os << llvm::formatv(structuredOpFoldersFormat, className);
 
   return success();
 }


        


More information about the Mlir-commits mailing list