[Mlir-commits] [mlir] 30e130c - [mlir] Move some linalg patterns around.

Sean Silva llvmlistbot at llvm.org
Fri Oct 30 13:49:13 PDT 2020


Author: Sean Silva
Date: 2020-10-30T13:48:03-07:00
New Revision: 30e130c3edb2381ac099556d4976f379dfaa4f66

URL: https://github.com/llvm/llvm-project/commit/30e130c3edb2381ac099556d4976f379dfaa4f66
DIFF: https://github.com/llvm/llvm-project/commit/30e130c3edb2381ac099556d4976f379dfaa4f66.diff

LOG: [mlir] Move some linalg patterns around.

The bufferization patterns are moved into the .cpp file, which is the
preferred location in this codebase when the patterns do not need to be
used individually.

The LinalgToStandard patterns are kept in a header because they are
expected to be used individually. However, they are moved to
LinalgToStandard.h, the header corresponding to the file where they are
defined.
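
As a rough illustration of that individual use, a downstream pass can now
pull in just the patterns it needs from LinalgToStandard.h. This is only a
sketch: the helper name below is hypothetical, while the pattern classes
come from this patch.

    #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"

    using namespace mlir;

    // Hypothetical helper: lower only linalg.copy to library calls, leaving
    // other linalg ops untouched, instead of calling the full
    // populateLinalgToStandardConversionPatterns.
    static void addCopyLoweringPatterns(MLIRContext *ctx,
                                        OwningRewritePatternList &patterns) {
      patterns.insert<linalg::CopyOpToLibraryCallRewrite,
                      linalg::CopyTransposeRewrite>(ctx);
    }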

This also removes TensorCastOpConverter, which is now handled by
populateStdBufferizePatterns. Eventually, the constant op lowering will
be handled there as well, but there are currently holdups on moving it
(see https://reviews.llvm.org/D89916).
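
A minimal sketch of the resulting pattern assembly in the Linalg bufferize
pass follows. The helper name and the exact include set are assumptions;
the two populate calls match the diff below.

    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
    #include "mlir/Transforms/Bufferize.h"

    using namespace mlir;

    // Assumed helper mirroring the pass body: tensor_cast bufferization now
    // comes from the Standard dialect patterns instead of a Linalg-local
    // TensorCastOpConverter.
    static void collectBufferizePatterns(MLIRContext &context,
                                         BufferizeTypeConverter &converter,
                                         OwningRewritePatternList &patterns) {
      // Linalg ops and tensor constants.
      linalg::populateLinalgBufferizePatterns(&context, converter, patterns);
      // std.tensor_cast (replaces the removed TensorCastOpConverter).
      populateStdBufferizePatterns(&context, converter, patterns);
    }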

Differential Revision: https://reviews.llvm.org/D90254

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
index 08b3981d0b67..3a6c8bba614b 100644
--- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
 #define MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_
 
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -16,6 +17,63 @@ class ModuleOp;
 template <typename T>
 class OperationPass;
 
+namespace linalg {
+
+//===----------------------------------------------------------------------===//
+// Patterns to convert a LinalgOp to std.call @external library implementation.
+//===----------------------------------------------------------------------===//
+// These patterns are exposed individually because they are expected to be
+// typically used individually.
+
+// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
+// function. The implementation of the function can be either in the same module
+// or in an externally linked library.
+// This is a generic entry point for all LinalgOp, except for CopyOp and
+// IndexedGenericOp, for which more specialized patterns are provided.
+class LinalgOpToLibraryCallRewrite : public RewritePattern {
+public:
+  LinalgOpToLibraryCallRewrite()
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Rewrite pattern specialization for CopyOp, kicks in when both input and
+/// output permutations are left unspecified or are the identity.
+class CopyOpToLibraryCallRewrite : public OpRewritePattern<CopyOp> {
+public:
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(CopyOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Rewrite CopyOp with permutations into a sequence of TransposeOp and
+/// permutation-free CopyOp. This interplays with TransposeOpConversion and
+/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
+class CopyTransposeRewrite : public OpRewritePattern<CopyOp> {
+public:
+  using OpRewritePattern<CopyOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(CopyOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Conversion pattern specialization for IndexedGenericOp, has special handling
+/// for the extra index operands.
+class IndexedGenericOpToLibraryCallRewrite
+    : public OpRewritePattern<IndexedGenericOp> {
+public:
+  using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(IndexedGenericOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Populate the given list with patterns that convert from Linalg to Standard.
+void populateLinalgToStandardConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+} // namespace linalg
+
 /// Create a pass to convert Linalg operations to the Standard dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertLinalgToStandardPass();
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 30690636ac8c..e34150d26594 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -754,98 +754,6 @@ class ConvOpVectorization : public OpRewritePattern<ConvOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
-//===----------------------------------------------------------------------===//
-// Patterns to convert a LinalgOp to std.call @external library implementation.
-//===----------------------------------------------------------------------===//
-// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()`
-// function. The implementation of the function can be either in the same module
-// or in an externally linked library.
-// This is a generic entry point for all LinalgOp, except for CopyOp and
-// IndexedGenericOp, for which omre specialized patterns are provided.
-class LinalgOpToLibraryCallRewrite : public RewritePattern {
-public:
-  LinalgOpToLibraryCallRewrite()
-      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
-
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Rewrite pattern specialization for CopyOp, kicks in when both input and
-/// output permutations are left unspecified or are the identity.
-class CopyOpToLibraryCallRewrite : public OpRewritePattern<CopyOp> {
-public:
-  using OpRewritePattern<CopyOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(CopyOp op,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Rewrite CopyOp with permutations into a sequence of TransposeOp and
-/// permutation-free CopyOp. This interplays with TransposeOpConversion and
-/// LinalgConversion<CopyOp> to create a path to the LLVM dialect.
-class CopyTransposeRewrite : public OpRewritePattern<CopyOp> {
-public:
-  using OpRewritePattern<CopyOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(CopyOp op,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Conversion pattern specialization for IndexedGenericOp, has special handling
-/// for the extra index operands.
-class IndexedGenericOpToLibraryCallRewrite
-    : public OpRewritePattern<IndexedGenericOp> {
-public:
-  using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(IndexedGenericOp op,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Populate the given list with patterns that convert from Linalg to Standard.
-void populateLinalgToStandardConversionPatterns(
-    OwningRewritePatternList &patterns, MLIRContext *ctx);
-
-//===----------------------------------------------------------------------===//
-// Buffer allocation patterns.
-//===----------------------------------------------------------------------===//
-
-/// Generic BufferizeConversionPattern that matches any Operation* and
-/// dispatches internally. This avoids template instantiating one pattern for
-/// each LinalgOp op.
-class LinalgOpConverter : public BufferizeConversionPattern {
-public:
-  LinalgOpConverter(MLIRContext *context, BufferizeTypeConverter &converter)
-      : BufferizeConversionPattern(context, converter) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final;
-};
-
-/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
-/// stored in memory. A linalg.reshape is introduced to convert to the desired
-/// n-D buffer form.
-class TensorConstantOpConverter
-    : public BufferizeOpConversionPattern<ConstantOp> {
-public:
-  using BufferizeOpConversionPattern<ConstantOp>::BufferizeOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final;
-};
-
-/// TensorCastOp converts 1-1 to MemRefCastOp.
-class TensorCastOpConverter
-    : public BufferizeOpConversionPattern<TensorCastOp> {
-public:
-  using BufferizeOpConversionPattern<
-      TensorCastOp>::BufferizeOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final;
-};
-
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 2f70063957b8..2ad287e22013 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Operation.h"
@@ -185,105 +186,124 @@ static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
   rewriter.replaceOp(linalgOp, outputs);
 }
 
-LogicalResult mlir::linalg::LinalgOpConverter::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp)
-    return failure();
-
-  // We abuse the GenericOpAdaptor here.
-  // TODO: Manually create an Adaptor that captures inputs, output_buffers and
-  // init_tensors for all linalg::LinalgOp interface ops.
-  linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
-
-  // All inputs need to be turned into buffers first. Until then, bail out.
-  if (llvm::any_of(adaptor.inputs(),
-                   [](Value in) { return !in.getType().isa<MemRefType>(); }))
-    return failure();
-
-  // All init_tensors need to be turned into buffers first. Until then, bail
-  // out.
-  if (llvm::any_of(adaptor.init_tensors(),
-                   [](Value in) { return !in.getType().isa<MemRefType>(); }))
-    return failure();
-
-  Location loc = linalgOp.getLoc();
-  SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
-                                         adaptor.output_buffers().end());
-
-  if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, newOutputBuffers,
-                                       rewriter))) {
-    linalgOp.emitOpError() << "Failed to allocate buffers for tensor results.";
-    return failure();
-  }
+//===----------------------------------------------------------------------===//
+// Buffer allocation patterns.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Generic BufferizeConversionPattern that matches any Operation* and
+/// dispatches internally. This avoids template instantiating one pattern for
+/// each LinalgOp op.
+class LinalgOpConverter : public BufferizeConversionPattern {
+public:
+  LinalgOpConverter(MLIRContext *context, BufferizeTypeConverter &converter)
+      : BufferizeConversionPattern(context, converter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+
+    LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+
+    // We abuse the GenericOpAdaptor here.
+    // TODO: Manually create an Adaptor that captures inputs, output_buffers and
+    // init_tensors for all linalg::LinalgOp interface ops.
+    linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
+
+    // All inputs need to be turned into buffers first. Until then, bail out.
+    if (llvm::any_of(adaptor.inputs(),
+                     [](Value in) { return !in.getType().isa<MemRefType>(); }))
+      return failure();
+
+    // All init_tensors need to be turned into buffers first. Until then, bail
+    // out.
+    if (llvm::any_of(adaptor.init_tensors(),
+                     [](Value in) { return !in.getType().isa<MemRefType>(); }))
+      return failure();
+
+    Location loc = linalgOp.getLoc();
+    SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
+                                           adaptor.output_buffers().end());
+
+    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
+                                         newOutputBuffers, rewriter))) {
+      linalgOp.emitOpError()
+          << "Failed to allocate buffers for tensor results.";
+      return failure();
+    }
+
+    // Delegate to the linalg generic pattern.
+    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
+      finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
+                               newOutputBuffers);
+      return success();
+    }
 
-  // Delegate to the linalg generic pattern.
-  if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
-    finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
+    finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
                              newOutputBuffers);
     return success();
   }
+};
+} // namespace
 
-  finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
-                           newOutputBuffers);
-  return success();
-}
-
-LogicalResult mlir::linalg::TensorConstantOpConverter::matchAndRewrite(
-    ConstantOp op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  RankedTensorType rankedTensorType = op.getType().dyn_cast<RankedTensorType>();
-  if (!rankedTensorType)
-    return failure();
-  if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
-        return s == 0 || ShapedType::isDynamic(s);
-      }))
-    return failure();
-
-  int64_t nElements = 1;
-  for (int64_t s : rankedTensorType.getShape())
-    nElements *= s;
-  Type elementType = rankedTensorType.getElementType();
-  MemRefType memrefType =
-      converter.convertType(op.getType()).cast<MemRefType>();
-  VectorType flatVectorType = VectorType::get({nElements}, elementType);
-  MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
-  MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);
-
-  Location loc = op.getLoc();
-  auto attr = op.getValue().cast<DenseElementsAttr>();
-  Value alloc =
-      rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
-  Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
-                                             attr.reshape(flatVectorType));
-  rewriter.create<StoreOp>(loc, cstVec, alloc);
-
-  Value memref =
-      rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
-  if (rankedTensorType.getRank() > 1) {
-    // Introduce a linalg.reshape to flatten the memref.
-    AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
-        /*numDims=*/rankedTensorType.getRank(), op.getContext());
-    memref = rewriter.create<linalg::ReshapeOp>(
-        loc, memrefType, memref,
-        rewriter.getAffineMapArrayAttr(collapseAllDims));
-  }
-  rewriter.replaceOp(op, memref);
+namespace {
+/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
+/// stored in memory. A linalg.reshape is introduced to convert to the desired
+/// n-D buffer form.
+class TensorConstantOpConverter
+    : public BufferizeOpConversionPattern<ConstantOp> {
+public:
+  using BufferizeOpConversionPattern<ConstantOp>::BufferizeOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+
+    RankedTensorType rankedTensorType =
+        op.getType().dyn_cast<RankedTensorType>();
+    if (!rankedTensorType)
+      return failure();
+    if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
+          return s == 0 || ShapedType::isDynamic(s);
+        }))
+      return failure();
 
-  return success();
-}
+    int64_t nElements = 1;
+    for (int64_t s : rankedTensorType.getShape())
+      nElements *= s;
+    Type elementType = rankedTensorType.getElementType();
+    MemRefType memrefType =
+        converter.convertType(op.getType()).cast<MemRefType>();
+    VectorType flatVectorType = VectorType::get({nElements}, elementType);
+    MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
+    MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);
+
+    Location loc = op.getLoc();
+    auto attr = op.getValue().cast<DenseElementsAttr>();
+    Value alloc =
+        rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
+    Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
+                                               attr.reshape(flatVectorType));
+    rewriter.create<StoreOp>(loc, cstVec, alloc);
+
+    Value memref =
+        rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
+    if (rankedTensorType.getRank() > 1) {
+      // Introduce a linalg.reshape to flatten the memref.
+      AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
+          /*numDims=*/rankedTensorType.getRank(), op.getContext());
+      memref = rewriter.create<linalg::ReshapeOp>(
+          loc, memrefType, memref,
+          rewriter.getAffineMapArrayAttr(collapseAllDims));
+    }
+    rewriter.replaceOp(op, memref);
 
-LogicalResult mlir::linalg::TensorCastOpConverter::matchAndRewrite(
-    TensorCastOp op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  if (op.getType().hasRank())
-    return failure();
-  Type t = UnrankedMemRefType::get(op.getType().getElementType(),
-                                   /*memorySpace=*/0);
-  rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
-  return success();
-}
+    return success();
+  }
+};
+} // namespace
 
 namespace {
 
@@ -347,6 +367,7 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
 
     OwningRewritePatternList patterns;
     populateLinalgBufferizePatterns(&context, converter, patterns);
+    populateStdBufferizePatterns(&context, converter, patterns);
     populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
                                               linalg::CopyOp>(
         &context, converter, patterns);
@@ -366,7 +387,6 @@ void mlir::linalg::populateLinalgBufferizePatterns(
   patterns.insert<
       // clang-format off
       LinalgOpConverter,
-      TensorCastOpConverter,
       TensorConstantOpConverter
       // clang-format on
       >(context, converter);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 844972e56c6f..88242c1d6f28 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRSCFTransforms
   MLIRPass
   MLIRStandard
+  MLIRStandardOpsTransforms
   MLIRStandardToLLVM
   MLIRTransforms
   MLIRTransformUtils


        

