[Mlir-commits] [mlir] 422aaf3 - [mlir][Linalg] Add named Linalg ops on tensor to buffer support.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Oct 12 04:21:12 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-12T11:20:23Z
New Revision: 422aaf31daa520899303adaf82ba76743624ee0a

URL: https://github.com/llvm/llvm-project/commit/422aaf31daa520899303adaf82ba76743624ee0a
DIFF: https://github.com/llvm/llvm-project/commit/422aaf31daa520899303adaf82ba76743624ee0a.diff

LOG: [mlir][Linalg] Add named Linalg ops on tensor to buffer support.

This revision introduces support for buffer allocation for any named linalg op.
To avoid instantiating one templated pattern per op, a single new ConversionPattern is created that matches any op implementing the LinalgOp interface.

Some APIs are updated to remain consistent with MLIR style:
`OwningRewritePatternList * -> OwningRewritePatternList &`
`BufferAssignmentTypeConverter * -> BufferAssignmentTypeConverter &`

Differential revision: https://reviews.llvm.org/D89226

Added: 
    mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir

Modified: 
    mlir/docs/Tutorials/QuickstartRewrites.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
    mlir/include/mlir/Transforms/Bufferize.h
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
    mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
    mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
    mlir/lib/Transforms/BufferPlacement.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/lib/Transforms/TestBufferPlacement.cpp
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md
index fbc2406a0f31..447f8a62f91e 100644
--- a/mlir/docs/Tutorials/QuickstartRewrites.md
+++ b/mlir/docs/Tutorials/QuickstartRewrites.md
@@ -155,7 +155,7 @@ add_public_tablegen_target(<name-of-the-cmake-target>)
 Then you can `#include` the generated file in any C++ implementation file you
 like. (You will also need to make sure the library depends on the CMake target
 defined in the above.) The generated file will have a `populateWithGenerated(
-MLIRContext *context, OwningRewritePatternList *patterns)` function that you can
+MLIRContext *context, OwningRewritePatternList &patterns)` function that you can
 use to collect all the generated patterns inside `patterns` and then use
 `patterns` in any pass you would like.
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 2332f516c44a..55df0bccbb64 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -39,7 +39,7 @@ def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">;
 def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">;
 
 // Base Tablegen class for Linalg ops.
-// Linalg ops that correspond to library calls operate on linalg::View as their
+// Linalg ops that correspond to library calls operate on ShapedType as their
 // first operands. These may be optionally followed by non-view operands
 // depending on the specific Linalg op.
 class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index dbb89c73954b..845873ff83df 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -628,7 +628,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     InterfaceMethod<
       /*desc=*/[{
         Clone the current operation with the given location and operands. This
-        is used to abstract away the optional underlying region creation.
+        is used to abstract away the optional underlying region creation. This
+        does not change how operands are split between the input,
+        output_buffer and init_tensors segments.
       }],
       /*retTy=*/"Operation *",
       /*methodName=*/"clone",
@@ -666,6 +668,23 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       }
       return res;
     }
+    //========================================================================//
+    // Helper functions to mutate the `operand_segment_sizes` attribute.
+    // These are useful when cloning and changing operand types.
+    //========================================================================//
+    void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); }
+    void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); }
+    void setNumInitTensors(unsigned num) { setOperandSegmentAt(2, num); }
+
+    private:
+    void setOperandSegmentAt(unsigned idx, unsigned val) {
+      auto attr = getOperation()->getAttr("operand_segment_sizes")
+        .cast<DenseIntElementsAttr>();
+      unsigned i = 0;
+      auto newAttr = attr.mapValues(IntegerType::get(32, getContext()),
+        [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
+      getOperation()->setAttr("operand_segment_sizes", newAttr);
+    }
   }];
 }
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7512f69608a4..a2dee8c3ae65 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/Bufferize.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 namespace mlir {
@@ -51,8 +52,8 @@ void populateConvVectorizationPatterns(
 /// Populates the given list with patterns to convert Linalg operations on
 /// tensors to buffers.
 void populateConvertLinalgOnTensorsToBuffersPatterns(
-    MLIRContext *context, BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns);
+    MLIRContext *context, BufferAssignmentTypeConverter &converter,
+    OwningRewritePatternList &patterns);
 
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
@@ -797,6 +798,46 @@ class IndexedGenericOpToLibraryCallRewrite
 void populateLinalgToStandardConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
+//===----------------------------------------------------------------------===//
+// Buffer allocation patterns.
+//===----------------------------------------------------------------------===//
+
+/// Generic BufferAssignmentConversionPattern that matches any Operation* and
+/// dispatches internally, failing on ops that are not LinalgOps. This avoids
+/// instantiating a separate templated pattern for each LinalgOp.
+class LinalgOpConverter : public BufferAssignmentConversionPattern {
+public:
+  LinalgOpConverter(MLIRContext *context,
+                    BufferAssignmentTypeConverter &converter)
+      : BufferAssignmentConversionPattern(context, converter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+
+class TensorConstantOpConverter
+    : public BufferAssignmentOpConversionPattern<ConstantOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      ConstantOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+
+class TensorCastOpConverter
+    : public BufferAssignmentOpConversionPattern<TensorCastOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      TensorCastOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
+};
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 72816b72f41e..81f64dc3ccb3 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -41,8 +41,8 @@ void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
 std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
 
 void populateShapeTypeConversionPatterns(
-    MLIRContext *ctx, BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns);
+    MLIRContext *ctx, BufferAssignmentTypeConverter &converter,
+    OwningRewritePatternList &patterns);
 // Collects a set of patterns to replace tensors as inputs and outputs to shape
 // operations with buffers. This only modifies the shape operations.
 std::unique_ptr<FunctionPass> createShapeTensorToMemrefPass();

diff  --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
index e8fffcfe4925..26452e9db513 100644
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -140,14 +140,28 @@ class BufferAssignmentOpConversionPattern
     : public OpConversionPattern<SourceOp> {
 public:
   explicit BufferAssignmentOpConversionPattern(
-      MLIRContext *context, BufferAssignmentTypeConverter *converter,
+      MLIRContext *context, BufferAssignmentTypeConverter &converter,
       PatternBenefit benefit = 1)
-      : OpConversionPattern<SourceOp>(context, benefit), converter(converter) {
-    assert(converter && "The type converter has not been defined");
-  }
+      : OpConversionPattern<SourceOp>(context, benefit), converter(converter) {}
+
+protected:
+  BufferAssignmentTypeConverter &converter;
+};
+
+/// Helper conversion pattern that encapsulates a BufferAssignmentTypeConverter
+/// instance and operates on Operation* to be compatible with OpInterfaces.
+/// This avoids instantiating N patterns for ops that can be subsumed by a
+/// single op interface (e.g. Linalg named ops).
+class BufferAssignmentConversionPattern : public ConversionPattern {
+public:
+  explicit BufferAssignmentConversionPattern(
+      MLIRContext *context, BufferAssignmentTypeConverter &converter,
+      PatternBenefit benefit = 1)
+      : ConversionPattern(benefit, converter, MatchAnyOpTypeTag()),
+        converter(converter) {}
 
 protected:
-  BufferAssignmentTypeConverter *converter;
+  BufferAssignmentTypeConverter &converter;
 };
 
 /// Converts the signature of the function using BufferAssignmentTypeConverter.
@@ -191,15 +205,15 @@ class BufferAssignmentReturnOpConverter
     OpBuilder builder(returnOp);
     for (auto operand : llvm::enumerate(operands)) {
       SmallVector<Value, 2> values;
-      this->converter->tryDecomposeValue(
-          builder, loc, operand.value().getType(), operand.value(), values);
+      this->converter.tryDecomposeValue(builder, loc, operand.value().getType(),
+                                        operand.value(), values);
       Type type = returnOp.getOperand(operand.index()).getType();
       SmallVector<Type, 2> originTypes;
-      this->converter->tryDecomposeType(type, originTypes);
+      this->converter.tryDecomposeType(type, originTypes);
       for (auto value : llvm::enumerate(values)) {
         Type origin = originTypes[value.index()];
         Type converted = value.value().getType();
-        auto kind = this->converter->getResultConversionKind(origin, converted);
+        auto kind = this->converter.getResultConversionKind(origin, converted);
         if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult)
           newOperands.push_back(value.value());
         else
@@ -247,10 +261,10 @@ class BufferAssignmentCallOpConverter
 template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
           typename CopyOpTy>
 static void populateWithBufferAssignmentOpConversionPatterns(
-    MLIRContext *context, BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns) {
+    MLIRContext *context, BufferAssignmentTypeConverter &converter,
+    OwningRewritePatternList &patterns) {
   // clang-format off
-  patterns->insert<
+  patterns.insert<
     BufferAssignmentCallOpConverter,
     BufferAssignmentFuncOpConverter,
     BufferAssignmentReturnOpConverter

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
new file mode 100644
index 000000000000..2c01a688cfa0
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -convert-linalg-on-tensors-to-buffers -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %A = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
+  %B = constant dense<[[1.0, 2.0, 3.0, 4.0],
+                       [5.0, 6.0, 7.0, 8.0],
+                       [9.0, 10.0, 11.0, 12.0]]> : tensor<3x4xf32>
+  %C = constant dense<1000.0> : tensor<2x4xf32>
+
+  %D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
+                     init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
+
+  %unranked = tensor_cast %D : tensor<2x4xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+  //      CHECK: Unranked Memref base@ = {{0x[0-9a-f]*}}
+  // CHECK-SAME: rank = 2 offset = 0 sizes = [2, 4] strides = [4, 1] data =
+  // CHECK-NEXT: [1038,   1044,   1050,   1056]
+  // CHECK-NEXT: [1065,   1074,   1083,   1092]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index d11cc51d1d59..69786823dd32 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -152,7 +152,7 @@ struct LowerGpuOpsToNVVMOpsPass
 
 void mlir::populateGpuToNVVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  populateWithGenerated(converter.getDialect()->getContext(), &patterns);
+  populateWithGenerated(converter.getDialect()->getContext(), patterns);
   patterns
       .insert<GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                           NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 40cf097c9c5a..e9b44a9fef52 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -86,7 +86,7 @@ struct LowerGpuOpsToROCDLOpsPass
 
 void mlir::populateGpuToROCDLConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  populateWithGenerated(converter.getDialect()->getContext(), &patterns);
+  populateWithGenerated(converter.getDialect()->getContext(), patterns);
   patterns.insert<
       GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, ROCDL::ThreadIdXOp,
                                   ROCDL::ThreadIdYOp, ROCDL::ThreadIdZOp>,

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
index f4273a44bb9c..ee1e4131854f 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -330,7 +330,7 @@ namespace {
 void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
                                       SPIRVTypeConverter &typeConverter,
                                       OwningRewritePatternList &patterns) {
-  populateWithGenerated(context, &patterns);
+  populateWithGenerated(context, patterns);
   patterns.insert<
       GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 442117c619de..650c44f9e922 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -22,40 +22,37 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Bufferize.h"
 
-namespace {
-
 using namespace ::mlir;
 using namespace ::mlir::linalg;
 
-SmallVector<Range, 4>
-computeLoopRanges(Location loc, linalg::GenericOp linalgOp, OpBuilder *b) {
+static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp,
+                                               OpBuilder &b) {
   auto indexingMaps = llvm::to_vector<4>(
       linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
   auto inputIndexingMaps =
       llvm::makeArrayRef(indexingMaps).take_front(linalgOp.getNumInputs());
 
-  mlir::edsc::ScopedContext scope(*b, loc);
+  mlir::edsc::ScopedContext scope(b, loc);
   return emitLoopRanges(scope.getBuilderRef(), loc,
                         concatAffineMaps(inputIndexingMaps),
-                        getShape(*b, linalgOp));
+                        getShape(b, linalgOp));
 }
 
-Value maybeConvertToIndex(Location loc, Value val, OpBuilder *b) {
+static Value maybeConvertToIndex(Location loc, Value val, OpBuilder &b) {
   if (val.getType().isIndex())
     return val;
-  return b->create<IndexCastOp>(loc, val, b->getIndexType());
+  return b.create<IndexCastOp>(loc, val, b.getIndexType());
 }
 
-LogicalResult allocateBuffersForResults(Location loc,
-                                        linalg::GenericOp linalgOp,
-                                        linalg::GenericOpAdaptor &adaptor,
-                                        SmallVectorImpl<Value> *resultBuffers,
-                                        OpBuilder *b) {
+static LogicalResult
+allocateBuffersForResults(Location loc, LinalgOp linalgOp,
+                          linalg::GenericOpAdaptor &adaptor,
+                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
   // Lazily compute loopRanges.
   SmallVector<Range, 4> loopRanges;
 
   // Allocate a buffer for every tensor result.
-  for (auto en : llvm::enumerate(linalgOp.getResultTypes())) {
+  for (auto en : llvm::enumerate(linalgOp.getOperation()->getResultTypes())) {
     size_t resultIndex = en.index();
     Type resultType = en.value();
 
@@ -79,24 +76,24 @@ LogicalResult allocateBuffersForResults(Location loc,
       Value initTensor = linalgOp.getInitTensor(resultIndex);
       Value initBuffer = adaptor.init_tensors()[resultIndex];
       if (initTensor.hasOneUse()) {
-        resultBuffers->push_back(initBuffer);
+        resultBuffers.push_back(initBuffer);
         continue;
       }
       SmallVector<Value, 4> dynOperands;
       for (auto dim : llvm::enumerate(tensorShape)) {
         if (dim.value() == TensorType::kDynamicSize) {
-          dynOperands.push_back(b->create<DimOp>(loc, initTensor, dim.index()));
+          dynOperands.push_back(b.create<DimOp>(loc, initTensor, dim.index()));
         }
       }
-      auto alloc = b->create<AllocOp>(loc, memrefType, dynOperands);
-      b->create<linalg::CopyOp>(loc, initBuffer, alloc);
-      resultBuffers->push_back(alloc);
+      auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
+      b.create<linalg::CopyOp>(loc, initBuffer, alloc);
+      resultBuffers.push_back(alloc);
       continue;
     }
 
     // Allocate buffers for statically-shaped results.
     if (memrefType.hasStaticShape()) {
-      resultBuffers->push_back(b->create<AllocOp>(loc, memrefType));
+      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
       continue;
     }
 
@@ -123,148 +120,157 @@ LogicalResult allocateBuffersForResults(Location loc,
         return failure();
       }
     }
-    resultBuffers->push_back(b->create<AllocOp>(loc, memrefType, dynOperands));
+    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
   }
   return success();
 }
 
+// Specialization for `linalg::GenericOp`.
 /// A pattern to convert Generic Linalg operations which work on tensors to
 /// use buffers. A buffer is allocated using BufferAssignmentPlacer for
 /// each operation result. BufferPlacement pass should be later used to move
 /// Alloc operations to the correct positions and insert the missing Dealloc
 /// operations in the correct places.
-class GenericOpConverter
-    : public BufferAssignmentOpConversionPattern<linalg::GenericOp> {
-public:
-  using BufferAssignmentOpConversionPattern<
-      linalg::GenericOp>::BufferAssignmentOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(linalg::GenericOp linalgOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    linalg::GenericOpAdaptor adaptor(
-        operands, linalgOp.getOperation()->getAttrDictionary());
-
-    // All inputs need to be turned into buffers first. Until then, bail out.
-    if (llvm::any_of(adaptor.inputs(),
-                     [](Value in) { return !in.getType().isa<MemRefType>(); }))
-      return failure();
-
-    // All init_tensors need to be turned into buffers first. Until then, bail
-    // out.
-    if (llvm::any_of(adaptor.init_tensors(),
-                     [](Value in) { return !in.getType().isa<MemRefType>(); }))
-      return failure();
+static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
+                                     linalg::GenericOp genericOp,
+                                     ValueRange inputs, ValueRange outputs) {
+  // Generate a new linalg operation that works on buffers.
+  auto newGenericOp = rewriter.create<linalg::GenericOp>(
+      genericOp.getLoc(),
+      /*resultTensorTypes=*/llvm::None,
+      /*inputs=*/inputs,
+      /*outputBuffers=*/outputs,
+      /*initTensors=*/llvm::None, genericOp.indexing_maps(),
+      genericOp.iterator_types(), genericOp.docAttr(),
+      genericOp.library_callAttr(), genericOp.symbol_sourceAttr());
+
+  // Create a new block in the region of the new Generic Op.
+  Block *oldBlock = genericOp.getBody();
+  Region &newRegion = newGenericOp.region();
+  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
+                                         oldBlock->getArgumentTypes());
+
+  // Add the result arguments to the new block.
+  for (Value v : ValueRange(outputs).drop_front(genericOp.getNumInitTensors()))
+    newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
+
+  // Clone the body of the old block to the new block.
+  BlockAndValueMapping mapping;
+  mapping.map(oldBlock->getArguments(), newBlock->getArguments());
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToEnd(newBlock);
+  for (auto &op : oldBlock->getOperations()) {
+    Operation *clonedOp = rewriter.clone(op, mapping);
+    mapping.map(op.getResults(), clonedOp->getResults());
+  }
 
-    Location loc = linalgOp.getLoc();
-    SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
-                                           adaptor.output_buffers().end());
+  // Replace the results of the old op with the new output buffers.
+  rewriter.replaceOp(genericOp, outputs);
+}
 
-    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
-                                         &newOutputBuffers, &rewriter))) {
-      linalgOp.emitOpError()
-          << "Failed to allocate buffers for tensor results.";
-      return failure();
-    }
+// TODO: Specialize for `linalg::IndexedGenericOp`; it currently asserts below.
+
+// Specialization for all other `linalg::LinalgOp`.
+static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
+                                     linalg::LinalgOp linalgOp,
+                                     ValueRange inputs, ValueRange outputs) {
+  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
+  assert(!isa<linalg::IndexedGenericOp>(linalgOp.getOperation()));
+  SmallVector<Value, 8> newOperands = inputs;
+  newOperands.append(outputs.begin(), outputs.end());
+  auto otherOperands = linalgOp.getAssumedNonShapedOperands();
+  newOperands.append(otherOperands.begin(), otherOperands.end());
+  LinalgOp res = cast<LinalgOp>(linalgOp.clone(rewriter, linalgOp.getLoc(),
+                                               /*resultTypes=*/ArrayRef<Type>{},
+                                               newOperands));
+  // Need to mutate the operand_segment_sizes in the resulting op.
+  res.setNumOutputBuffers(outputs.size());
+  res.setNumInitTensors(0);
+  // Replace the results of the old op with the new output buffers.
+  rewriter.replaceOp(linalgOp, outputs);
+}
 
-    // Generate a new linalg operation that works on buffers.
-    auto newLinalgOp = rewriter.create<linalg::GenericOp>(
-        loc,
-        /*resultTensorTypes=*/llvm::None,
-        /*inputs=*/adaptor.inputs(),
-        /*outputBuffers=*/newOutputBuffers,
-        /*initTensors=*/llvm::None, linalgOp.indexing_maps(),
-        linalgOp.iterator_types(), linalgOp.docAttr(),
-        linalgOp.library_callAttr(), linalgOp.symbol_sourceAttr());
-
-    // Create a new block in the region of the new Generic Op.
-    Block *oldBlock = linalgOp.getBody();
-    Region &newRegion = newLinalgOp.region();
-    Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
-                                           oldBlock->getArgumentTypes());
-
-    // Add the result arguments that do not come from init_tensors to the new
-    // block.
-    // TODO: update this assumption because the reality is more complex under
-    // linalg on tensor based transformations.
-    for (Value v :
-         ValueRange(newOutputBuffers).drop_front(adaptor.init_tensors().size()))
-      newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
-
-    // Clone the body of the old block to the new block.
-    BlockAndValueMapping mapping;
-    mapping.map(oldBlock->getArguments(), newBlock->getArguments());
-
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToEnd(newBlock);
-    for (auto &op : oldBlock->getOperations()) {
-      Operation *clonedOp = rewriter.clone(op, mapping);
-      mapping.map(op.getResults(), clonedOp->getResults());
-    }
+LogicalResult mlir::linalg::LinalgOpConverter::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  LinalgOp linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
+    return failure();
+
+  // We abuse the GenericOpAdaptor here.
+  // TODO: Manually create an Adaptor that captures inputs, output_buffers and
+  // init_tensors for all linalg::LinalgOp interface ops.
+  linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
+
+  // All inputs need to be turned into buffers first. Until then, bail out.
+  if (llvm::any_of(adaptor.inputs(),
+                   [](Value in) { return !in.getType().isa<MemRefType>(); }))
+    return failure();
+
+  // All init_tensors need to be turned into buffers first. Until then, bail
+  // out.
+  if (llvm::any_of(adaptor.init_tensors(),
+                   [](Value in) { return !in.getType().isa<MemRefType>(); }))
+    return failure();
+
+  Location loc = linalgOp.getLoc();
+  SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
+                                         adaptor.output_buffers().end());
+
+  if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, newOutputBuffers,
+                                       rewriter))) {
+    linalgOp.emitOpError() << "Failed to allocate buffers for tensor results.";
+    return failure();
+  }
 
-    // Replace the results of the old op with the new output buffers.
-    rewriter.replaceOp(linalgOp, newOutputBuffers);
+  // GenericOp needs its region rebuilt: use the GenericOp specialization.
+  if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
+    finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(),
+                             newOutputBuffers);
     return success();
   }
-};
 
-// Rewrite a tensor `constant` to a vector constant folloed by a vector store
-// and a vector.type_cast.
-class TensorConstantOpConverter
-    : public BufferAssignmentOpConversionPattern<ConstantOp> {
-public:
-  using BufferAssignmentOpConversionPattern<
-      ConstantOp>::BufferAssignmentOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    if (!op.getType().isa<RankedTensorType>())
-      return failure();
-    auto attr = op.getValue().cast<DenseElementsAttr>();
-
-    Location loc = op.getLoc();
-    MemRefType memrefType =
-        converter->convertType(op.getType()).cast<MemRefType>();
-    VectorType vectorType =
-        VectorType::get(memrefType.getShape(), memrefType.getElementType());
-
-    // vector constant takes attributes that are compatible with tensor
-    // constant.
-    Value cstVec =
-        rewriter.create<ConstantOp>(loc, vectorType, attr.reshape(vectorType));
-
-    // Alloc a memref<vector<...>>, store the constant and typecast the vector
-    // away.
-    MemRefType memrefOfVectorType = MemRefType::get({}, vectorType);
-    Value alloc =
-        rewriter.create<AllocOp>(loc, memrefOfVectorType, ValueRange{});
-    rewriter.create<StoreOp>(loc, cstVec, alloc);
-    rewriter.replaceOpWithNewOp<vector::TypeCastOp>(op, memrefType, alloc);
+  finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(),
+                           newOutputBuffers);
+  return success();
+}
 
-    return success();
-  }
-};
+LogicalResult mlir::linalg::TensorConstantOpConverter::matchAndRewrite(
+    ConstantOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (!op.getType().isa<RankedTensorType>())
+    return failure();
+  auto attr = op.getValue().cast<DenseElementsAttr>();
+
+  Location loc = op.getLoc();
+  MemRefType memrefType =
+      converter.convertType(op.getType()).cast<MemRefType>();
+  VectorType vectorType =
+      VectorType::get(memrefType.getShape(), memrefType.getElementType());
+  Value cstVec =
+      rewriter.create<ConstantOp>(loc, vectorType, attr.reshape(vectorType));
+
+  MemRefType memrefOfVectorType = MemRefType::get({}, vectorType);
+  Value alloc = rewriter.create<AllocOp>(loc, memrefOfVectorType, ValueRange{});
+  rewriter.create<StoreOp>(loc, cstVec, alloc);
+  rewriter.replaceOpWithNewOp<vector::TypeCastOp>(op, memrefType, alloc);
 
-// Rewrite a `tensor_cast` as a `memref_cast` with no layout, in the 0-memory
-// space.
-class TensorCastOpConverter
-    : public BufferAssignmentOpConversionPattern<TensorCastOp> {
-public:
-  using BufferAssignmentOpConversionPattern<
-      TensorCastOp>::BufferAssignmentOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    if (op.getType().hasRank())
-      return failure();
-    Type t = UnrankedMemRefType::get(op.getType().getElementType(),
-                                     /*memorySpace=*/0);
-    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
-    return success();
-  }
-};
+  return success();
+}
+
+LogicalResult mlir::linalg::TensorCastOpConverter::matchAndRewrite(
+    TensorCastOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (op.getType().hasRank())
+    return failure();
+  Type t = UnrankedMemRefType::get(op.getType().getElementType(),
+                                   /*memorySpace=*/0);
+  rewriter.replaceOpWithNewOp<MemRefCastOp>(op, t, operands.front());
+  return success();
+}
+
+namespace {
 
 /// Converts Linalg operations that work on tensor-type operands or results to
 /// work on buffers.
@@ -326,11 +332,11 @@ struct ConvertLinalgOnTensorsToBuffers
         BufferAssignmentTypeConverter::AppendToArgumentsList);
 
     OwningRewritePatternList patterns;
-    populateConvertLinalgOnTensorsToBuffersPatterns(&context, &converter,
-                                                    &patterns);
+    populateConvertLinalgOnTensorsToBuffersPatterns(&context, converter,
+                                                    patterns);
     populateWithBufferAssignmentOpConversionPatterns<
-        mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(&context, &converter,
-                                                        &patterns);
+        mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(&context, converter,
+                                                        patterns);
     if (failed(applyFullConversion(this->getOperation(), target, patterns)))
       this->signalPassFailure();
   }
@@ -341,13 +347,13 @@ std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createConvertLinalgOnTensorsToBuffersPass() {
   return std::make_unique<ConvertLinalgOnTensorsToBuffers>();
 }
-
 void mlir::linalg::populateConvertLinalgOnTensorsToBuffersPatterns(
-    MLIRContext *context, BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns) {
-  patterns->insert<
+
+    MLIRContext *context, BufferAssignmentTypeConverter &converter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<
       // clang-format off
-      GenericOpConverter,
+      LinalgOpConverter,
       TensorCastOpConverter,
       TensorConstantOpConverter
       // clang-format on

diff  --git a/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
index e281cea24e73..83e432855f6e 100644
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeTypeConversion.cpp
@@ -38,7 +38,7 @@ class TypeConversionAssumingOpConverter
     newResultTypes.reserve(assumingOp.getNumResults());
     for (auto result : assumingOp.getResults()) {
       auto originalType = result.getType();
-      Type convertedType = converter->convertType(originalType);
+      Type convertedType = converter.convertType(originalType);
       newResultTypes.push_back(convertedType);
     }
 
@@ -60,7 +60,7 @@ struct ShapeTensorToMemrefPass
 
     OwningRewritePatternList patterns;
     BufferAssignmentTypeConverter converter;
-    populateShapeTypeConversionPatterns(&ctx, &converter, &patterns);
+    populateShapeTypeConversionPatterns(&ctx, converter, patterns);
 
     ConversionTarget target(getContext());
     auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
@@ -81,9 +81,9 @@ struct ShapeTensorToMemrefPass
 //
 // TODO: Change this to work generally with any type conversions.
 void mlir::populateShapeTypeConversionPatterns(
-    MLIRContext *context, BufferAssignmentTypeConverter *converter,
-    OwningRewritePatternList *patterns) {
-  patterns->insert<TypeConversionAssumingOpConverter>(context, converter);
+    MLIRContext *context, BufferAssignmentTypeConverter &converter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<TypeConversionAssumingOpConverter>(context, converter);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index 380c72087bbd..3a0d6ebd962b 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -875,8 +875,8 @@ LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
   TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
   for (auto argType : llvm::enumerate(funcType.getInputs())) {
     SmallVector<Type, 2> decomposedTypes, convertedTypes;
-    converter->tryDecomposeType(argType.value(), decomposedTypes);
-    converter->convertTypes(decomposedTypes, convertedTypes);
+    converter.tryDecomposeType(argType.value(), decomposedTypes);
+    converter.convertTypes(decomposedTypes, convertedTypes);
     conversion.addInputs(argType.index(), convertedTypes);
   }
 
@@ -885,10 +885,10 @@ LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
   newResultTypes.reserve(funcOp.getNumResults());
   for (Type resultType : funcType.getResults()) {
     SmallVector<Type, 2> originTypes;
-    converter->tryDecomposeType(resultType, originTypes);
+    converter.tryDecomposeType(resultType, originTypes);
     for (auto origin : originTypes) {
-      Type converted = converter->convertType(origin);
-      auto kind = converter->getResultConversionKind(origin, converted);
+      Type converted = converter.convertType(origin);
+      auto kind = converter.getResultConversionKind(origin, converted);
       if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList)
         conversion.addInputs(converted);
       else
@@ -897,7 +897,7 @@ LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
     }
   }
 
-  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
+  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter,
                                          &conversion)))
     return failure();
 
@@ -986,8 +986,8 @@ LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
   // values if a decompose callback function has been provided by the user.
   for (auto operand : operands) {
     SmallVector<Value, 2> values;
-    this->converter->tryDecomposeValue(builder, loc, operand.getType(), operand,
-                                       values);
+    this->converter.tryDecomposeValue(builder, loc, operand.getType(), operand,
+                                      values);
     newOperands.append(values.begin(), values.end());
   }
 
@@ -998,11 +998,11 @@ LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
   mappings.resize(callOp.getNumResults());
   for (auto result : llvm::enumerate(callOp.getResults())) {
     SmallVector<Type, 2> originTypes;
-    converter->tryDecomposeType(result.value().getType(), originTypes);
+    converter.tryDecomposeType(result.value().getType(), originTypes);
     auto &resultMapping = mappings[result.index()];
     for (Type origin : originTypes) {
-      Type converted = converter->convertType(origin);
-      auto kind = converter->getResultConversionKind(origin, converted);
+      Type converted = converter.convertType(origin);
+      auto kind = converter.getResultConversionKind(origin, converted);
       if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) {
         newResultTypes.push_back(converted);
         // The result value is not yet available. Its index is kept and it is
@@ -1039,7 +1039,7 @@ LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
     } else {
       // Values need to be packed using callback function. The same callback
       // that is used for materializeArgumentConversion is used for packing.
-      Value packed = converter->materializeArgumentConversion(
+      Value packed = converter.materializeArgumentConversion(
           nextBuilder, loc, callOp.getType(i), valuesToPack);
       replacedValues.push_back(packed);
     }

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 5f2b9e32dac7..32d618d9008e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -77,7 +77,7 @@ struct FoldingPattern : public RewritePattern {
 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
   void runOnFunction() override {
     mlir::OwningRewritePatternList patterns;
-    populateWithGenerated(&getContext(), &patterns);
+    populateWithGenerated(&getContext(), patterns);
 
     // Verify named pattern is generated with expected name.
     patterns.insert<FoldingPattern, TestNamedPatternRule>(&getContext());
@@ -547,7 +547,7 @@ struct TestLegalizePatternDriver
   void runOnOperation() override {
     TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
-    populateWithGenerated(&getContext(), &patterns);
+    populateWithGenerated(&getContext(), patterns);
     patterns.insert<
         TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
         TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,

diff  --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index aecf99f69729..ead4c5d6fb38 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -147,12 +147,12 @@ struct TestBufferPlacementPreparationPass
   };
 
   void populateTensorLinalgToBufferLinalgConversionPattern(
-      MLIRContext *context, BufferAssignmentTypeConverter *converter,
-      OwningRewritePatternList *patterns) {
+      MLIRContext *context, BufferAssignmentTypeConverter &converter,
+      OwningRewritePatternList &patterns) {
     populateWithBufferAssignmentOpConversionPatterns<
         mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, converter,
                                                         patterns);
-    patterns->insert<GenericOpConverter>(context, converter);
+    patterns.insert<GenericOpConverter>(context, converter);
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -230,8 +230,8 @@ struct TestBufferPlacementPreparationPass
     });
 
     OwningRewritePatternList patterns;
-    populateTensorLinalgToBufferLinalgConversionPattern(&context, &converter,
-                                                        &patterns);
+    populateTensorLinalgToBufferLinalgConversionPattern(&context, converter,
+                                                        patterns);
     if (failed(applyFullConversion(this->getOperation(), target, patterns)))
       this->signalPassFailure();
   };

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 495fbe1715e0..ff6138f73914 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1147,9 +1147,9 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
 
   // Emit function to add the generated matchers to the pattern list.
   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext "
-        "*context, ::mlir::OwningRewritePatternList *patterns) {\n";
+        "*context, ::mlir::OwningRewritePatternList &patterns) {\n";
   for (const auto &name : rewriterNames) {
-    os << "  patterns->insert<" << name << ">(context);\n";
+    os << "  patterns.insert<" << name << ">(context);\n";
   }
   os << "}\n";
 }


        


More information about the Mlir-commits mailing list