[Mlir-commits] [mlir] e0dc3db - [mlir][Linalg] NFC - Cleanup explicitly instantiated patterns 1/n - LinalgToStandard.cpp

Nicolas Vasilache llvmlistbot at llvm.org
Fri Oct 9 12:42:35 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-09T19:41:41Z
New Revision: e0dc3dba3bd1db450391d7fda040d4fcc830e5e3

URL: https://github.com/llvm/llvm-project/commit/e0dc3dba3bd1db450391d7fda040d4fcc830e5e3
DIFF: https://github.com/llvm/llvm-project/commit/e0dc3dba3bd1db450391d7fda040d4fcc830e5e3.diff

LOG: [mlir][Linalg] NFC - Cleanup explicitly instantiated patterns 1/n - LinalgToStandard.cpp

This revision belongs to a series of patches that reduce reliance of Linalg transformations on templated rewrite and conversion patterns.
Instead, this uses a single catch-all RewritePattern (constructed with MatchAnyOpTypeTag) for the vast majority of cases and dispatches internally.

Differential Revision: https://reviews.llvm.org/D89133

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
index 6585eaf35ef6..08b3981d0b67 100644
--- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -12,15 +12,10 @@
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
-class MLIRContext;
 class ModuleOp;
 template <typename T>
 class OperationPass;
 
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void populateLinalgToStandardConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx);
-
 /// Create a pass to convert Linalg operations to the Standard dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 9c8197c45ec8..2332f516c44a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -502,8 +502,9 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
         getIteratorTypesAttrName(), getSymbolSourceAttrName()
       };
     }
-    StringRef getLibraryCallName() {
-      return library_call().hasValue() ? library_call().getValue() : "";
+    std::string getLibraryCallName() {
+      return library_call().hasValue() ?
+        library_call()->str() : "op_has_no_registered_library_name";
     }
     llvm::Optional<unsigned> getSymbolSource() {
       auto ss = symbol_source();

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 614fd8d2a7de..dbb89c73954b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -594,6 +594,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
                llvm::all_of(this->getOperation()->getResults(), isTensorType);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the name registered for this op when lowering to an external
+        library call.
+      }],
+      /*retTy=*/"std::string",
+      /*methodName=*/"getLibraryCallName",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getLibraryCallName();
+      }]
+    >,
 
     //===------------------------------------------------------------------===//
     // Other static interface methods.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 395db396dadc..7512f69608a4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -347,9 +347,7 @@ struct LinalgTilingOptions {
   /// values must not fold away when tiling. Otherwise, use a more robust
   /// `tileSizeComputationFunction`.
   LinalgTilingOptions &setTileSizes(SmallVector<Value, 4> ts) {
-    tileSizeComputationFunction = [=](OpBuilder &, Operation *) {
-      return ts;
-    };
+    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
     return *this;
   }
   /// Convenience function to set the `tileSizeComputationFunction` to a
@@ -749,6 +747,56 @@ class ConvOpVectorization : public OpRewritePattern<ConvOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
+//===----------------------------------------------------------------------===//
+// Patterns to convert a LinalgOp to std.call @external library implementation.
+//===----------------------------------------------------------------------===//
+/// Create a new call to the type-canonicalized
+/// `LinalgOp::getLibraryCallName()` function. The implementation of the
+/// function can be either in the same module or in an externally linked
+/// library. This is a generic entry point for all LinalgOps, except for
+/// CopyOp and IndexedGenericOp, for which more specialized patterns exist.
+class LinalgOpToLibraryCallRewrite : public RewritePattern {
+public:
+  LinalgOpToLibraryCallRewrite()
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Rewrite pattern specialization for CopyOp; applies only when both input and
+/// output permutations are left unspecified or are the identity.
+class CopyOpToLibraryCallRewrite : public OpRewritePattern<CopyOp> {
+public:
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(CopyOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Rewrite CopyOp with permutations into a sequence of TransposeOp and
+/// permutation-free CopyOp. This interplays with TransposeOpConversion and
+/// CopyOpToLibraryCallRewrite to create a path to the LLVM dialect.
+class CopyTransposeRewrite : public OpRewritePattern<CopyOp> {
+public:
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(CopyOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Rewrite pattern specialization for IndexedGenericOp, with special handling
+/// for the extra leading index operands.
+class IndexedGenericOpToLibraryCallRewrite
+    : public OpRewritePattern<IndexedGenericOp> {
+public:
+  using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(IndexedGenericOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Populate the given list with patterns that convert from Linalg to Standard.
+void populateLinalgToStandardConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index ffb56138a795..d64e6f9947c7 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -11,6 +11,7 @@
 #include "../PassDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 
@@ -21,10 +22,15 @@ using namespace mlir::linalg;
 /// generated CallOp. MemRefTypes have their layout canonicalized since the
 /// information is not used in signature generation.
 /// Note that static size information is not modified.
-template <typename LinalgOp>
 static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
   SmallVector<Type, 4> result;
   result.reserve(op->getNumOperands());
+  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
+    auto *ctx = op->getContext();
+    auto numLoops = indexedGenericOp.getNumLoops();
+    result.reserve(op->getNumOperands() + numLoops);
+    result.assign(numLoops, IndexType::get(ctx));
+  }
   for (auto type : op->getOperandTypes()) {
     // The underlying descriptor type (e.g. LLVM) does not have layout
     // information. Canonicalizing the type at the level of std when going into
@@ -37,21 +43,8 @@ static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
   return result;
 }
 
-template <>
-SmallVector<Type, 4> extractOperandTypes<IndexedGenericOp>(Operation *op) {
-  auto *ctx = op->getContext();
-  auto indexedGenericOp = cast<IndexedGenericOp>(op);
-  auto numLoops = indexedGenericOp.getNumLoops();
-
-  SmallVector<Type, 4> result(numLoops, IndexType::get(ctx));
-  auto canonicalizedOperands = extractOperandTypes<LinalgOp>(op);
-  result.append(canonicalizedOperands.begin(), canonicalizedOperands.end());
-  return result;
-}
-
 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
 // If the library function does not exist, insert a declaration.
-template <typename LinalgOp>
 static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
                                                  PatternRewriter &rewriter) {
   auto linalgOp = cast<LinalgOp>(op);
@@ -68,7 +61,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
     return fnNameAttr;
   }
 
-  SmallVector<Type, 4> inputTypes(extractOperandTypes<LinalgOp>(op));
+  SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
   assert(op->getNumResults() == 0 &&
          "Library call for linalg operation can be generated only for ops that "
          "have void return types");
@@ -87,9 +80,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
   return fnNameAttr;
 }
 
-namespace {
-
-SmallVector<Value, 4>
+static SmallVector<Value, 4>
 createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
                                       ValueRange operands) {
   SmallVector<Value, 4> res;
@@ -107,154 +98,101 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
   return res;
 }
 
-// LinalgOpConversion<LinalgOp> creates a new call to the type-canonicalized
-// `LinalgOp::getLibraryCallName()` function.
-// The implementation of the function can be either in the same module or in an
-// externally linked library.
-template <typename LinalgOp>
-class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
-public:
-  using OpRewritePattern<LinalgOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(LinalgOp op,
-                                PatternRewriter &rewriter) const override {
-    auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
-    if (!libraryCallName)
-      return failure();
-
-    rewriter.replaceOpWithNewOp<mlir::CallOp>(
-        op, libraryCallName.getValue(), TypeRange(),
-        createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
-                                              op.getOperands()));
-    return success();
-  }
-};
-
-/// Conversion pattern specialization for CopyOp. This kicks in when both input
-/// and output permutations are left unspecified or are the identity.
-template <>
-class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
-public:
-  using OpRewritePattern<CopyOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(CopyOp op,
-                                PatternRewriter &rewriter) const override {
-    auto inputPerm = op.inputPermutation();
-    if (inputPerm.hasValue() && !inputPerm->isIdentity())
-      return failure();
-    auto outputPerm = op.outputPermutation();
-    if (outputPerm.hasValue() && !outputPerm->isIdentity())
-      return failure();
-
-    auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
-    if (!libraryCallName)
-      return failure();
-
-    rewriter.replaceOpWithNewOp<mlir::CallOp>(
-        op, libraryCallName.getValue(), TypeRange(),
-        createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
-                                              op.getOperands()));
-    return success();
-  }
-};
-
-/// Conversion pattern specialization for IndexedGenericOp.
-template <>
-class LinalgOpConversion<IndexedGenericOp>
-    : public OpRewritePattern<IndexedGenericOp> {
-public:
-  using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IndexedGenericOp op,
-                                PatternRewriter &rewriter) const override {
-    auto libraryCallName =
-        getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
-    if (!libraryCallName)
-      return failure();
-
-    // TODO: Use induction variables values instead of zeros, when
-    // IndexedGenericOp is tiled.
-    auto zero = rewriter.create<mlir::ConstantOp>(
-        op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
-    auto indexedGenericOp = cast<IndexedGenericOp>(op);
-    auto numLoops = indexedGenericOp.getNumLoops();
-    SmallVector<Value, 4> operands;
-    operands.reserve(numLoops + op.getNumOperands());
-    for (unsigned i = 0; i < numLoops; ++i)
-      operands.push_back(zero);
-    for (auto operand : op.getOperands())
-      operands.push_back(operand);
-    rewriter.replaceOpWithNewOp<mlir::CallOp>(
-        op, libraryCallName.getValue(), TypeRange(),
-        createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
-    return success();
-  }
-};
-
-/// A non-conversion rewrite pattern kicks in to convert CopyOp with
-/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
-/// This interplays together with TransposeOpConversion and
-/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
-class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
-public:
-  using OpRewritePattern<CopyOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(CopyOp op,
-                                PatternRewriter &rewriter) const override {
-    Value in = op.input(), out = op.output();
+LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  // Only LinalgOps without a specialized pattern go through this catch-all:
+  // CopyOp and IndexedGenericOp have dedicated rewrites.
+  if (!isa<LinalgOp>(op) || isa<CopyOp, IndexedGenericOp>(op))
+    return failure();
+
+  auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
+  if (!libraryCallName)
+    return failure();
+  rewriter.replaceOpWithNewOp<mlir::CallOp>(
+      op, libraryCallName.getValue(), TypeRange(),
+      createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
+                                            op->getOperands()));
+  return success();
+}
 
-    // If either inputPerm or outputPerm are non-identities, insert transposes.
-    auto inputPerm = op.inputPermutation();
-    if (inputPerm.hasValue() && !inputPerm->isIdentity())
-      in = rewriter.create<TransposeOp>(op.getLoc(), in,
-                                        AffineMapAttr::get(*inputPerm));
-    auto outputPerm = op.outputPermutation();
-    if (outputPerm.hasValue() && !outputPerm->isIdentity())
-      out = rewriter.create<TransposeOp>(op.getLoc(), out,
-                                         AffineMapAttr::get(*outputPerm));
+LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite(
+    CopyOp op, PatternRewriter &rewriter) const {
+  // Permuted copies are left to CopyTransposeRewrite.
+  auto isPermuted = [](auto perm) {
+    return perm.hasValue() && !perm->isIdentity();
+  };
+  if (isPermuted(op.inputPermutation()) || isPermuted(op.outputPermutation()))
+    return failure();
+
+  // Resolve (declaring it if needed) the external library function.
+  FlatSymbolRefAttr callee = getLibraryCallSymbolRef(op, rewriter);
+  if (!callee)
+    return failure();
+  auto callOperands = createTypeCanonicalizedMemRefOperands(
+      rewriter, op.getLoc(), op.getOperands());
+  rewriter.replaceOpWithNewOp<mlir::CallOp>(op, callee.getValue(),
+                                            TypeRange(), callOperands);
+  return success();
+}
 
-    // If nothing was transposed, fail and let the conversion kick in.
-    if (in == op.input() && out == op.output())
-      return failure();
+LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
+    CopyOp op, PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  // Wraps `v` in a TransposeOp when `perm` is present and not the identity;
+  // otherwise returns `v` untouched.
+  auto maybeTranspose = [&](Value v, auto perm) -> Value {
+    if (!perm.hasValue() || perm->isIdentity())
+      return v;
+    return rewriter.create<TransposeOp>(loc, v, AffineMapAttr::get(*perm));
+  };
+  Value in = maybeTranspose(op.input(), op.inputPermutation());
+  Value out = maybeTranspose(op.output(), op.outputPermutation());
+
+  // No transpose was materialized: defer to CopyOpToLibraryCallRewrite.
+  bool unchanged = in == op.input() && out == op.output();
+  if (unchanged)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
+  return success();
+}
 
-    rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
-    return success();
-  }
-};
-} // namespace
+LogicalResult
+mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
+    IndexedGenericOp op, PatternRewriter &rewriter) const {
+  auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
+  if (!libraryCallName)
+    return failure();
+
+  // The callee signature (see extractOperandTypes) takes one leading `index`
+  // argument per loop, followed by the op's own operands.
+  // TODO: Use induction variables values instead of zeros, when
+  // IndexedGenericOp is tiled.
+  Value zero = rewriter.create<mlir::ConstantOp>(
+      op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
+  // Assemble: `numLoops` index zeros, then the original operands.
+  auto numLoops = op.getNumLoops();
+  SmallVector<Value, 4> operands;
+  operands.reserve(numLoops + op.getNumOperands());
+  operands.append(numLoops, zero);
+  operands.append(op.getOperands().begin(), op.getOperands().end());
+  rewriter.replaceOpWithNewOp<mlir::CallOp>(
+      op, libraryCallName.getValue(), TypeRange(),
+      createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
+  return success();
+}
 
 /// Populate the given list with patterns that convert from Linalg to Standard.
-void mlir::populateLinalgToStandardConversionPatterns(
+void mlir::linalg::populateLinalgToStandardConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
   // TODO: ConvOp conversion needs to export a descriptor with relevant
   // attribute values such as kernel striding and dilation.
   // clang-format off
   patterns.insert<
-      CopyTransposeConversion,
-      LinalgOpConversion<ConvOp>,
-      LinalgOpConversion<PoolingMaxOp>,
-      LinalgOpConversion<PoolingMinOp>,
-      LinalgOpConversion<PoolingSumOp>,
-      LinalgOpConversion<CopyOp>,
-      LinalgOpConversion<FillOp>,
-      LinalgOpConversion<GenericOp>,
-      LinalgOpConversion<IndexedGenericOp>>(ctx);
-  // TODO: collect all auto-generated named ops with a tblgen directive.
-  patterns.insert<
-      LinalgOpConversion<DotOp>,
-      LinalgOpConversion<BatchMatmulOp>,
-      LinalgOpConversion<MatvecOp>,
-      LinalgOpConversion<VecmatOp>,
-      LinalgOpConversion<MatmulOp>,
-      LinalgOpConversion<ConvWOp>,
-      LinalgOpConversion<ConvNWCOp>,
-      LinalgOpConversion<ConvNCWOp>,
-      LinalgOpConversion<ConvHWOp>,
-      LinalgOpConversion<ConvNHWCOp>,
-      LinalgOpConversion<ConvNCHWOp>,
-      LinalgOpConversion<ConvDHWOp>,
-      LinalgOpConversion<ConvNDHWCOp>,
-      LinalgOpConversion<ConvNCDHWOp>>(ctx);
+      CopyOpToLibraryCallRewrite,
+      CopyTransposeRewrite,
+      IndexedGenericOpToLibraryCallRewrite>(ctx);
+  patterns.insert<LinalgOpToLibraryCallRewrite>(); // Matches any op; ctor takes no ctx.
   // clang-format on
 }
 


        


More information about the Mlir-commits mailing list