[Mlir-commits] [mlir] 52b0fe6 - [mlir] Add func-bufferize pass.

Sean Silva llvmlistbot at llvm.org
Mon Nov 2 12:49:03 PST 2020


Author: Sean Silva
Date: 2020-11-02T12:42:32-08:00
New Revision: 52b0fe64045d3fbbb7604f70066ac91970da612f

URL: https://github.com/llvm/llvm-project/commit/52b0fe64045d3fbbb7604f70066ac91970da612f
DIFF: https://github.com/llvm/llvm-project/commit/52b0fe64045d3fbbb7604f70066ac91970da612f.diff

LOG: [mlir] Add func-bufferize pass.

This is the most basic possible finalizing bufferization pass, which I
also think is sufficient for most new use cases. The more concentrated
nature of this pass also greatly clarifies the invariants that it
requires on its input to safely transform the program (see the
pass description in Passes.td).

With this pass, I have now upstreamed practically all of the
bufferizations from npcomp (the exception being std.constant, which can
be upstreamed when std.global_memref lands:
https://llvm.discourse.group/t/rfc-global-variables-in-mlir/2076/16 )

Differential Revision: https://reviews.llvm.org/D90205

Added: 
    mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
    mlir/test/Dialect/Standard/func-bufferize.mlir

Modified: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
    mlir/include/mlir/Transforms/Bufferize.h
    mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/lib/Transforms/Bufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index 714acdff997e3..76fa79a77b25f 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -38,6 +38,9 @@ void populateStdBufferizePatterns(MLIRContext *context,
 /// Creates an instance of std bufferization pass.
 std::unique_ptr<Pass> createStdBufferizePass();
 
+/// Creates an instance of func bufferization pass.
+std::unique_ptr<Pass> createFuncBufferizePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index 1ccef1d8f4ea5..b0b172c33e82c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -22,4 +22,33 @@ def StdBufferize : FunctionPass<"std-bufferize"> {
   let dependentDialects = ["scf::SCFDialect"];
 }
 
+def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
+  let summary = "Bufferize func/call/return ops";
+  let description = [{
+    A finalizing bufferize pass that bufferizes std.func and std.call ops.
+
+    Because this pass updates std.func ops, it must be a module pass. It is
+    useful to keep this pass separate from other bufferizations so that the
+    other ones can be run at function-level in parallel.
+
+    This pass must be done atomically for two reasons:
+    1. This pass changes func op signatures, which requires atomically updating
+       calls as well throughout the entire module.
+    2. This pass changes the type of block arguments, which requires that all
+       successor arguments of predecessors be converted. Terminators are not
+       a closed universe (and need not implement BranchOpInterface), and so we
+       cannot in general rewrite them.
+
+    Note, because this is a "finalizing" bufferize step, it can create
+    invalid IR because it will not create materializations. To avoid this
+    situation, the pass must only be run when the only SSA values of
+    tensor type are:
+    - block arguments
+    - the result of tensor_load
+    Other values of tensor type should be eliminated by earlier
+    bufferization passes.
+  }];
+  let constructor = "mlir::createFuncBufferizePass()";
+}
+
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
index e40c1bccdd9b8..920eb6c301907 100644
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -150,9 +150,18 @@ class BufferizeTypeConverter : public TypeConverter {
 /// This function should be called by all bufferization passes using
 /// BufferizeTypeConverter so that materializations work proprely. One exception
 /// is bufferization passes doing "full" conversions, where it can be desirable
-/// for even the materializations to remain illegal so that they are eliminated.
+/// for even the materializations to remain illegal so that they are eliminated,
+/// such as via the patterns in
+/// populateEliminateBufferizeMaterializationsPatterns.
 void populateBufferizeMaterializationLegality(ConversionTarget &target);
 
+/// Populate patterns to eliminate bufferize materializations.
+///
+/// In particular, these are the tensor_load/tensor_to_memref ops.
+void populateEliminateBufferizeMaterializationsPatterns(
+    MLIRContext *context, BufferizeTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns);
+
 /// Helper conversion pattern that encapsulates a BufferizeTypeConverter
 /// instance.
 template <typename SourceOp>

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index aabb81cf3d06e..1334e7f83d557 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
   ExpandAtomic.cpp
   ExpandMemRefReshape.cpp
   ExpandTanh.cpp
+  FuncBufferize.cpp
   FuncConversions.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
new file mode 100644
index 0000000000000..4aadb72e6368c
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
@@ -0,0 +1,56 @@
+//===- FuncBufferize.cpp - Bufferization for func/call ops ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of std.func's and std.call's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Transforms/Bufferize.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    auto *context = &getContext();
+
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    populateFuncOpTypeConversionPattern(patterns, context, typeConverter);
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getType()) &&
+             typeConverter.isLegal(&op.getBody());
+    });
+    populateCallOpTypeConversionPattern(patterns, context, typeConverter);
+    populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
+                                                       patterns);
+    target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
+
+    // If all result types are legal, and all block arguments are legal (ensured
+    // by func conversion above), then all types in the program are legal.
+    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+      return typeConverter.isLegal(op->getResultTypes());
+    });
+
+    if (failed(applyFullConversion(module, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createFuncBufferizePass() {
+  return std::make_unique<FuncBufferizePass>();
+}

diff  --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index d1f91090a62d4..1564290cce4a5 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -76,6 +76,45 @@ void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
   target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
 }
 
+namespace {
+// In a finalizing bufferize conversion, we know that all tensors have been
+// converted to memrefs, thus, this op becomes an identity.
+class BufferizeTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorLoadOp::Adaptor adaptor(operands);
+    rewriter.replaceOp(op, adaptor.memref());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+// In a finalizing bufferize conversion, we know that all tensors have been
+// converted to memrefs, thus, this op becomes an identity.
+class BufferizeTensorToMemrefOp : public OpConversionPattern<TensorToMemrefOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorToMemrefOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorToMemrefOp::Adaptor adaptor(operands);
+    rewriter.replaceOp(op, adaptor.tensor());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateEliminateBufferizeMaterializationsPatterns(
+    MLIRContext *context, BufferizeTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
+      typeConverter, context);
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeFuncOpConverter
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir
new file mode 100644
index 0000000000000..20af66cdb8806
--- /dev/null
+++ b/mlir/test/Dialect/Standard/func-bufferize.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL:   func @identity(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           return %[[ARG]] : memref<f32>
+func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+// CHECK-LABEL:   func @block_arguments(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           br ^bb1(%[[ARG]] : memref<f32>)
+// CHECK:         ^bb1(%[[BBARG:.*]]: memref<f32>):
+// CHECK:           return %[[BBARG]] : memref<f32>
+func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
+  br ^bb1(%arg0: tensor<f32>)
+^bb1(%bbarg: tensor<f32>):
+  return %bbarg : tensor<f32>
+}
+
+// CHECK-LABEL:   func @eliminate_target_materialization(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           return %[[ARG]] : memref<f32>
+func @eliminate_target_materialization(%arg0: tensor<f32>) -> memref<f32> {
+  %0 = tensor_to_memref %arg0 : memref<f32>
+  return %0 : memref<f32>
+}
+
+// CHECK-LABEL:   func @eliminate_source_materialization(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) -> memref<f32> {
+// CHECK:           return %[[ARG]] : memref<f32>
+func @eliminate_source_materialization(%arg0: memref<f32>) -> tensor<f32> {
+  %0 = tensor_load %arg0 : memref<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL:   func @source() -> memref<f32>
+// CHECK-LABEL:   func @call_source() -> memref<f32> {
+// CHECK:           %[[RET:.*]] = call @source() : () -> memref<f32>
+// CHECK:           return %[[RET]] : memref<f32>
+func @source() -> tensor<f32>
+func @call_source() -> tensor<f32> {
+  %0 = call @source() : () -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL:   func @sink(memref<f32>)
+// CHECK-LABEL:   func @call_sink(
+// CHECK-SAME:        %[[ARG:.*]]: memref<f32>) {
+// CHECK:           call @sink(%[[ARG]]) : (memref<f32>) -> ()
+// CHECK:           return
+func @sink(tensor<f32>)
+func @call_sink(%arg0: tensor<f32>) {
+  call @sink(%arg0) : (tensor<f32>) -> ()
+  return
+}
+
+// -----
+
+func @failed_to_legalize() -> tensor<f32> {
+  // expected-error @+1 {{failed to legalize operation 'test.source'}}
+  %0 = "test.source"() : () -> (tensor<f32>)
+  return %0 : tensor<f32>
+}


        


More information about the Mlir-commits mailing list