[Mlir-commits] [mlir] e0dc3db - [mlir][Linalg] NFC - Cleanup explicitly instantiated patterns 1/n - LinalgToStandard.cpp
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Oct 9 12:42:35 PDT 2020
Author: Nicolas Vasilache
Date: 2020-10-09T19:41:41Z
New Revision: e0dc3dba3bd1db450391d7fda040d4fcc830e5e3
URL: https://github.com/llvm/llvm-project/commit/e0dc3dba3bd1db450391d7fda040d4fcc830e5e3
DIFF: https://github.com/llvm/llvm-project/commit/e0dc3dba3bd1db450391d7fda040d4fcc830e5e3.diff
LOG: [mlir][Linalg] NFC - Cleanup explicitly instantiated patterns 1/n - LinalgToStandard.cpp
This revision belongs to a series of patches that reduce the reliance of Linalg transformations on templated rewrite and conversion patterns.
Instead, this uses a MatchAnyOpTypeTag pattern for the vast majority of cases and dispatches internally.
Differential Revision: https://reviews.llvm.org/D89133
Added:
Modified:
mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
index 6585eaf35ef6..08b3981d0b67 100644
--- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -12,15 +12,10 @@
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
-class MLIRContext;
class ModuleOp;
template <typename T>
class OperationPass;
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void populateLinalgToStandardConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
-
/// Create a pass to convert Linalg operations to the Standard dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 9c8197c45ec8..2332f516c44a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -502,8 +502,9 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
getIteratorTypesAttrName(), getSymbolSourceAttrName()
};
}
- StringRef getLibraryCallName() {
- return library_call().hasValue() ? library_call().getValue() : "";
+ std::string getLibraryCallName() { // Placeholder name if library_call is unset.
+ return library_call().hasValue() ?
+ library_call()->str() : "op_has_no_registered_library_name";
}
llvm::Optional<unsigned> getSymbolSource() {
auto ss = symbol_source();
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 614fd8d2a7de..dbb89c73954b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -594,6 +594,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
llvm::all_of(this->getOperation()->getResults(), isTensorType);
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the name registered for this op when lowering to an external
+ library call. By default, forwards to the op's own getLibraryCallName().
+ }],
+ /*retTy=*/"std::string",
+ /*methodName=*/"getLibraryCallName",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return $_op.getLibraryCallName();
+ }]
+ >,
//===------------------------------------------------------------------===//
// Other static interface methods.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 395db396dadc..7512f69608a4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -347,9 +347,7 @@ struct LinalgTilingOptions {
/// values must not fold away when tiling. Otherwise, use a more robust
/// `tileSizeComputationFunction`.
LinalgTilingOptions &setTileSizes(SmallVector<Value, 4> ts) {
- tileSizeComputationFunction = [=](OpBuilder &, Operation *) {
- return ts;
- };
+ tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
return *this;
}
/// Convenience function to set the `tileSizeComputationFunction` to a
@@ -749,6 +747,56 @@ class ConvOpVectorization : public OpRewritePattern<ConvOp> {
PatternRewriter &rewriter) const override;
};
+//===----------------------------------------------------------------------===//
+// Patterns to convert a LinalgOp to std.call @external library implementation.
+//===----------------------------------------------------------------------===//
+/// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
+/// function. The implementation of the function can be either in the same
+/// module or in an externally linked library.
+/// This is a generic entry point for all LinalgOp, except for CopyOp and
+/// IndexedGenericOp, for which more specialized patterns are provided.
+class LinalgOpToLibraryCallRewrite : public RewritePattern {
+public:
+ LinalgOpToLibraryCallRewrite()
+ : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Rewrite pattern specialization for CopyOp, kicks in when both input and
+/// output permutations are left unspecified or are the identity.
+class CopyOpToLibraryCallRewrite : public OpRewritePattern<CopyOp> {
+public:
+ using OpRewritePattern<CopyOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(CopyOp op,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Rewrite CopyOp with permutations into a sequence of TransposeOp and
+/// permutation-free CopyOp. This interplays with TransposeOpConversion and
+/// CopyOpToLibraryCallRewrite to create a path to the LLVM dialect.
+class CopyTransposeRewrite : public OpRewritePattern<CopyOp> {
+public:
+ using OpRewritePattern<CopyOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(CopyOp op,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Rewrite pattern specialization for IndexedGenericOp, has special handling
+/// for the extra leading index operands (one per loop).
+class IndexedGenericOpToLibraryCallRewrite
+ : public OpRewritePattern<IndexedGenericOp> {
+public:
+ using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(IndexedGenericOp op,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Populate the given list with patterns that convert from Linalg to Standard.
+void populateLinalgToStandardConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx);
+
//===----------------------------------------------------------------------===//
// Support for staged pattern application.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index ffb56138a795..d64e6f9947c7 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -11,6 +11,7 @@
#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -21,10 +22,15 @@ using namespace mlir::linalg;
/// generated CallOp. MemRefTypes have their layout canonicalized since the
/// information is not used in signature generation.
/// Note that static size information is not modified.
-template <typename LinalgOp>
static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
SmallVector<Type, 4> result;
result.reserve(op->getNumOperands());
+ if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
+ auto *ctx = op->getContext();
+ auto numLoops = indexedGenericOp.getNumLoops();
+ result.reserve(op->getNumOperands() + numLoops);
+ result.assign(numLoops, IndexType::get(ctx)); // One leading index per loop.
+ }
for (auto type : op->getOperandTypes()) {
// The underlying descriptor type (e.g. LLVM) does not have layout
// information. Canonicalizing the type at the level of std when going into
@@ -37,21 +43,8 @@ static SmallVector<Type, 4> extractOperandTypes(Operation *op) {
return result;
}
-template <>
-SmallVector<Type, 4> extractOperandTypes<IndexedGenericOp>(Operation *op) {
- auto *ctx = op->getContext();
- auto indexedGenericOp = cast<IndexedGenericOp>(op);
- auto numLoops = indexedGenericOp.getNumLoops();
-
- SmallVector<Type, 4> result(numLoops, IndexType::get(ctx));
- auto canonicalizedOperands = extractOperandTypes<LinalgOp>(op);
- result.append(canonicalizedOperands.begin(), canonicalizedOperands.end());
- return result;
-}
-
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
// If the library function does not exist, insert a declaration.
-template <typename LinalgOp>
static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
PatternRewriter &rewriter) {
auto linalgOp = cast<LinalgOp>(op);
@@ -68,7 +61,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
return fnNameAttr;
}
- SmallVector<Type, 4> inputTypes(extractOperandTypes<LinalgOp>(op));
+ SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
assert(op->getNumResults() == 0 &&
"Library call for linalg operation can be generated only for ops that "
"have void return types");
@@ -87,9 +80,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
return fnNameAttr;
}
-namespace {
-
-SmallVector<Value, 4>
+static SmallVector<Value, 4>
createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
ValueRange operands) {
SmallVector<Value, 4> res;
@@ -107,154 +98,101 @@ createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc,
return res;
}
-// LinalgOpConversion<LinalgOp> creates a new call to the type-canonicalized
-// `LinalgOp::getLibraryCallName()` function.
-// The implementation of the function can be either in the same module or in an
-// externally linked library.
-template <typename LinalgOp>
-class LinalgOpConversion : public OpRewritePattern<LinalgOp> {
-public:
- using OpRewritePattern<LinalgOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(LinalgOp op,
- PatternRewriter &rewriter) const override {
- auto libraryCallName = getLibraryCallSymbolRef<LinalgOp>(op, rewriter);
- if (!libraryCallName)
- return failure();
-
- rewriter.replaceOpWithNewOp<mlir::CallOp>(
- op, libraryCallName.getValue(), TypeRange(),
- createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
- op.getOperands()));
- return success();
- }
-};
-
-/// Conversion pattern specialization for CopyOp. This kicks in when both input
-/// and output permutations are left unspecified or are the identity.
-template <>
-class LinalgOpConversion<CopyOp> : public OpRewritePattern<CopyOp> {
-public:
- using OpRewritePattern<CopyOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(CopyOp op,
- PatternRewriter &rewriter) const override {
- auto inputPerm = op.inputPermutation();
- if (inputPerm.hasValue() && !inputPerm->isIdentity())
- return failure();
- auto outputPerm = op.outputPermutation();
- if (outputPerm.hasValue() && !outputPerm->isIdentity())
- return failure();
-
- auto libraryCallName = getLibraryCallSymbolRef<CopyOp>(op, rewriter);
- if (!libraryCallName)
- return failure();
-
- rewriter.replaceOpWithNewOp<mlir::CallOp>(
- op, libraryCallName.getValue(), TypeRange(),
- createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
- op.getOperands()));
- return success();
- }
-};
-
-/// Conversion pattern specialization for IndexedGenericOp.
-template <>
-class LinalgOpConversion<IndexedGenericOp>
- : public OpRewritePattern<IndexedGenericOp> {
-public:
- using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IndexedGenericOp op,
- PatternRewriter &rewriter) const override {
- auto libraryCallName =
- getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
- if (!libraryCallName)
- return failure();
-
- // TODO: Use induction variables values instead of zeros, when
- // IndexedGenericOp is tiled.
- auto zero = rewriter.create<mlir::ConstantOp>(
- op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
- auto indexedGenericOp = cast<IndexedGenericOp>(op);
- auto numLoops = indexedGenericOp.getNumLoops();
- SmallVector<Value, 4> operands;
- operands.reserve(numLoops + op.getNumOperands());
- for (unsigned i = 0; i < numLoops; ++i)
- operands.push_back(zero);
- for (auto operand : op.getOperands())
- operands.push_back(operand);
- rewriter.replaceOpWithNewOp<mlir::CallOp>(
- op, libraryCallName.getValue(), TypeRange(),
- createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
- return success();
- }
-};
-
-/// A non-conversion rewrite pattern kicks in to convert CopyOp with
-/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
-/// This interplays together with TransposeOpConversion and
-/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
-class CopyTransposeConversion : public OpRewritePattern<CopyOp> {
-public:
- using OpRewritePattern<CopyOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(CopyOp op,
- PatternRewriter &rewriter) const override {
- Value in = op.input(), out = op.output();
+LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ // Only LinalgOps that have no specialized pattern go through this one.
+ if (!isa<LinalgOp>(op) || isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
+ return failure();
+
+ auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
+ if (!libraryCallName)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<mlir::CallOp>(
+ op, libraryCallName.getValue(), TypeRange(),
+ createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(),
+ op->getOperands()));
+ return success();
+}
- // If either inputPerm or outputPerm are non-identities, insert transposes.
- auto inputPerm = op.inputPermutation();
- if (inputPerm.hasValue() && !inputPerm->isIdentity())
- in = rewriter.create<TransposeOp>(op.getLoc(), in,
- AffineMapAttr::get(*inputPerm));
- auto outputPerm = op.outputPermutation();
- if (outputPerm.hasValue() && !outputPerm->isIdentity())
- out = rewriter.create<TransposeOp>(op.getLoc(), out,
- AffineMapAttr::get(*outputPerm));
+LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite(
+ CopyOp op, PatternRewriter &rewriter) const {
+ auto inputPerm = op.inputPermutation();
+ if (inputPerm.hasValue() && !inputPerm->isIdentity())
+ return failure();
+ auto outputPerm = op.outputPermutation();
+ if (outputPerm.hasValue() && !outputPerm->isIdentity())
+ return failure();
+
+ auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
+ if (!libraryCallName)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<mlir::CallOp>(
+ op, libraryCallName.getValue(), TypeRange(),
+ createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(),
+ op.getOperands()));
+ return success();
+}
- // If nothing was transposed, fail and let the conversion kick in.
- if (in == op.input() && out == op.output())
- return failure();
+LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
+ CopyOp op, PatternRewriter &rewriter) const {
+ Value in = op.input(), out = op.output();
+
+ // If either inputPerm or outputPerm is not the identity, insert transposes.
+ auto inputPerm = op.inputPermutation();
+ if (inputPerm.hasValue() && !inputPerm->isIdentity())
+ in = rewriter.create<TransposeOp>(op.getLoc(), in,
+ AffineMapAttr::get(*inputPerm));
+ auto outputPerm = op.outputPermutation();
+ if (outputPerm.hasValue() && !outputPerm->isIdentity())
+ out = rewriter.create<TransposeOp>(op.getLoc(), out,
+ AffineMapAttr::get(*outputPerm));
+
+ // If nothing was transposed, fail and let CopyOpToLibraryCallRewrite kick in.
+ if (in == op.input() && out == op.output())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
+ return success();
+}
- rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
- return success();
- }
-};
-} // namespace
+LogicalResult
+mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
+ IndexedGenericOp op, PatternRewriter &rewriter) const {
+ auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
+ if (!libraryCallName)
+ return failure();
+
+ // TODO: Use induction variable values instead of zeros, when
+ // IndexedGenericOp is tiled.
+ auto zero = rewriter.create<mlir::ConstantOp>(
+ op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
+ auto indexedGenericOp = cast<IndexedGenericOp>(op); // NOTE(review): redundant, op is already an IndexedGenericOp.
+ auto numLoops = indexedGenericOp.getNumLoops();
+ SmallVector<Value, 4> operands;
+ operands.reserve(numLoops + op.getNumOperands());
+ for (unsigned i = 0; i < numLoops; ++i)
+ operands.push_back(zero);
+ for (auto operand : op.getOperands())
+ operands.push_back(operand);
+ rewriter.replaceOpWithNewOp<mlir::CallOp>(
+ op, libraryCallName.getValue(), TypeRange(),
+ createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands));
+ return success();
+}
/// Populate the given list with patterns that convert from Linalg to Standard.
-void mlir::populateLinalgToStandardConversionPatterns(
+void mlir::linalg::populateLinalgToStandardConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
// TODO: ConvOp conversion needs to export a descriptor with relevant
// attribute values such as kernel striding and dilation.
// clang-format off
patterns.insert<
- CopyTransposeConversion,
- LinalgOpConversion<ConvOp>,
- LinalgOpConversion<PoolingMaxOp>,
- LinalgOpConversion<PoolingMinOp>,
- LinalgOpConversion<PoolingSumOp>,
- LinalgOpConversion<CopyOp>,
- LinalgOpConversion<FillOp>,
- LinalgOpConversion<GenericOp>,
- LinalgOpConversion<IndexedGenericOp>>(ctx);
- // TODO: collect all auto-generated named ops with a tblgen directive.
- patterns.insert<
- LinalgOpConversion<DotOp>,
- LinalgOpConversion<BatchMatmulOp>,
- LinalgOpConversion<MatvecOp>,
- LinalgOpConversion<VecmatOp>,
- LinalgOpConversion<MatmulOp>,
- LinalgOpConversion<ConvWOp>,
- LinalgOpConversion<ConvNWCOp>,
- LinalgOpConversion<ConvNCWOp>,
- LinalgOpConversion<ConvHWOp>,
- LinalgOpConversion<ConvNHWCOp>,
- LinalgOpConversion<ConvNCHWOp>,
- LinalgOpConversion<ConvDHWOp>,
- LinalgOpConversion<ConvNDHWCOp>,
- LinalgOpConversion<ConvNCDHWOp>>(ctx);
+ CopyOpToLibraryCallRewrite,
+ CopyTransposeRewrite,
+ IndexedGenericOpToLibraryCallRewrite>(ctx);
+ patterns.insert<LinalgOpToLibraryCallRewrite>(); // Catch-all for the remaining LinalgOps.
// clang-format on
}
More information about the Mlir-commits
mailing list