[Mlir-commits] [mlir] f30ec8f - [mlir][linalg][bufferize][NFC] Allow passing custom BufferizationOptions to pass

Matthias Springer llvmlistbot at llvm.org
Wed Feb 9 02:19:54 PST 2022


Author: Matthias Springer
Date: 2022-02-09T19:15:31+09:00
New Revision: f30ec8f627404202a6266ed844a001500824f814

URL: https://github.com/llvm/llvm-project/commit/f30ec8f627404202a6266ed844a001500824f814
DIFF: https://github.com/llvm/llvm-project/commit/f30ec8f627404202a6266ed844a001500824f814.diff

LOG: [mlir][linalg][bufferize][NFC] Allow passing custom BufferizationOptions to pass

Differential Revision: https://reviews.llvm.org/D118891

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
index 194465e29c4c1..d6e50072317b4 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
@@ -27,9 +27,9 @@ namespace comprehensive_bufferize {
 /// Run Module Bufferization on the given module. Performs a simple function
 /// call analysis to determine which function arguments are inplaceable. Then
 /// analyzes and bufferizes FuncOps one-by-one with One-Shot Bufferize.
-LogicalResult runComprehensiveBufferize(
-    ModuleOp moduleOp,
-    std::unique_ptr<bufferization::AnalysisBufferizationOptions> options);
+LogicalResult
+runModuleBufferize(ModuleOp moduleOp,
+                   bufferization::AnalysisBufferizationOptions options);
 
 namespace std_ext {
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 4b145a944e7fa..487362c62e60a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -18,6 +18,9 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+namespace bufferization {
+struct AnalysisBufferizationOptions;
+} // namespace bufferization
 
 std::unique_ptr<Pass> createConvertElementwiseToLinalgPass();
 
@@ -64,8 +67,8 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
 /// on SSA use-def chains starting from function operands that are annotated
 /// with the 'inplaceable' attribute.
 std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
-std::unique_ptr<Pass>
-createLinalgComprehensiveModuleBufferizePass(bool useLinalgCopy);
+std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass(
+    const bufferization::AnalysisBufferizationOptions &options);
 
 /// Create a pass to convert Linalg operations which work on tensors to use
 /// buffers instead.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index a90376cbd9a78..44fac0cb46914 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -52,9 +52,6 @@ def LinalgComprehensiveModuleBufferize :
     Option<"useAlloca", "use-alloca", "bool",
            /*default=*/"false",
            "Use stack allocations for memrefs (for testing purposes only)">,
-    Option<"useLinalgCopy", "use-memref.copy", "bool",
-           /*default=*/"false",
-           "Use a copy operation implemented as a Linalg op.">,
     Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool",
            /*default=*/"true",
            "Generate MemRef types with dynamic offset+strides by default.">,

diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index ae9b8a337f793..673c27ee2a7f4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -135,6 +135,11 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
                           Value outputTensor,
                           ArrayRef<int64_t> transposeVector);
 
+/// Returns GenericOp that copies an n-D memref. Unlike the current
+/// implementation of memref::CopyOp, this op can further tile, lower to loops
+/// or vectorize.
+GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 6f04a2fd40c27..8eb6df075d3ce 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -10,10 +10,10 @@
 // bufferizes function boundaries. It provides `BufferizableOpInterface`
 // implementations for FuncOp, CallOp and ReturnOp.
 //
-// Module Bufferization is run via `runComprehensiveBufferize(ModuleOp, ...)`.
-// This function analyzed the given module and determines the order of
-// analysis and bufferization: Functions that are called are processed before
-// their respective callers.
+// Module Bufferization is run via `runModuleBufferize(ModuleOp, ...)`. This
+// function analyzes the given module and determines the order of analysis and
+// bufferization: Functions that are called are processed before their
+// respective callers.
 //
 // After analyzing a FuncOp, additional information about its bbArgs is
 // gathered through PostAnalysisStepFns and stored in
@@ -971,10 +971,10 @@ annotateOpsWithBufferizationMarkers(FuncOp funcOp,
       setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state));
 }
 
-LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
-    ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options) {
+LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
+    ModuleOp moduleOp, AnalysisBufferizationOptions options) {
   IRRewriter rewriter(moduleOp.getContext());
-  AnalysisBufferizationState state(moduleOp, *options);
+  AnalysisBufferizationState state(moduleOp, options);
   ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
   BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
 
@@ -983,8 +983,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     return failure();
 
   // Collect bbArg/return value information after the analysis.
-  options->postAnalysisSteps.push_back(equivalentFuncOpBBArgsAnalysis);
-  options->postAnalysisSteps.push_back(funcOpBbArgReadWriteAnalysis);
+  options.addPostAnalysisStep(equivalentFuncOpBBArgsAnalysis);
+  options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis);
 
   // Analyze ops.
   for (FuncOp funcOp : moduleState.orderedFuncOps) {
@@ -1007,11 +1007,11 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
 
     // Add annotations to function arguments.
-    if (options->testAnalysisOnly)
+    if (options.testAnalysisOnly)
       annotateOpsWithBufferizationMarkers(funcOp, state);
   }
 
-  if (options->testAnalysisOnly)
+  if (options.testAnalysisOnly)
     return success();
 
   // Bufferize function bodies.
@@ -1031,7 +1031,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state)))
       return failure();
 
-    if (!options->allowReturnMemref &&
+    if (!options.allowReturnMemref &&
         llvm::any_of(funcOp.getType().getResults(), [](Type t) {
           return t.isa<MemRefType, UnrankedMemRefType>();
         })) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 708db1e089072..cd71264064168 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -38,9 +38,9 @@ struct LinalgComprehensiveModuleBufferize
   LinalgComprehensiveModuleBufferize(
       const LinalgComprehensiveModuleBufferize &p) = default;
 
-  LinalgComprehensiveModuleBufferize(bool linalgCopy) {
-    this->useLinalgCopy = linalgCopy;
-  }
+  explicit LinalgComprehensiveModuleBufferize(
+      AnalysisBufferizationOptions options)
+      : options(options) {}
 
   void runOnOperation() override;
 
@@ -58,6 +58,9 @@ struct LinalgComprehensiveModuleBufferize
     tensor::registerBufferizableOpInterfaceExternalModels(registry);
     vector::registerBufferizableOpInterfaceExternalModels(registry);
   }
+
+private:
+  llvm::Optional<AnalysisBufferizationOptions> options;
 };
 } // namespace
 
@@ -76,71 +79,44 @@ static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
   return allocated;
 }
 
-/// Create a linalg::GenericOp version of an n-D copy that can further tile,
-/// lower to loops or vectorize, unlike the current implementation of
-/// memref::CopyOp.
-/// Do not depend on memref::CopyOp that is getting deprecated.
-static LogicalResult createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
-                                        Value to) {
-  auto memrefTypeFrom = from.getType().cast<MemRefType>();
-  auto memrefTypeTo = to.getType().cast<MemRefType>();
-  if (!memrefTypeFrom || !memrefTypeTo ||
-      memrefTypeFrom.getRank() != memrefTypeTo.getRank())
-    return failure();
-  AffineMap id =
-      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
-  SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
-                                       getParallelIteratorTypeName());
-  b.create<linalg::GenericOp>(loc,
-                              /*inputs=*/from,
-                              /*outputs=*/to,
-                              /*indexingMaps=*/llvm::makeArrayRef({id, id}),
-                              /*iteratorTypes=*/iteratorTypes,
-                              [](OpBuilder &b, Location loc, ValueRange args) {
-                                b.create<linalg::YieldOp>(loc, args.front());
-                              });
-  return success();
-}
-
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
-  auto options = std::make_unique<AnalysisBufferizationOptions>();
-  if (useAlloca) {
-    options->allocationFn = allocationFnUsingAlloca;
-    options->deallocationFn = [](OpBuilder &b, Location loc, Value v) {
-      return success();
-    };
-  }
-  // TODO: atm memref::CopyOp can be 200x slower than linalg::GenericOp.
-  // Once this perf bug is fixed more systematically, we can revisit.
-  if (useLinalgCopy)
-    options->memCpyFn = createLinalgCopyOp;
-
-  options->allowReturnMemref = allowReturnMemref;
-  options->allowUnknownOps = allowUnknownOps;
-  options->analysisFuzzerSeed = analysisFuzzerSeed;
-  options->createDeallocs = createDeallocs;
-  options->fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
-  options->printConflicts = printConflicts;
-  options->testAnalysisOnly = testAnalysisOnly;
-
-  // Enable InitTensorOp elimination.
-  if (initTensorElimination) {
-    options->addPostAnalysisStep(
-        linalg_ext::insertSliceAnchoredInitTensorEliminationStep);
+  AnalysisBufferizationOptions opt;
+  if (!options) {
+    // Make new bufferization options if none were provided when creating the
+    // pass.
+    if (useAlloca) {
+      opt.allocationFn = allocationFnUsingAlloca;
+      opt.deallocationFn = [](OpBuilder &b, Location loc, Value v) {
+        return success();
+      };
+    }
+    opt.allowReturnMemref = allowReturnMemref;
+    opt.allowUnknownOps = allowUnknownOps;
+    opt.analysisFuzzerSeed = analysisFuzzerSeed;
+    opt.createDeallocs = createDeallocs;
+    opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
+    opt.printConflicts = printConflicts;
+    opt.testAnalysisOnly = testAnalysisOnly;
+    if (initTensorElimination) {
+      opt.addPostAnalysisStep(
+          linalg_ext::insertSliceAnchoredInitTensorEliminationStep);
+    }
+  } else {
+    opt = *options;
   }
 
   // Only certain scf.for ops are supported by the analysis.
-  options->addPostAnalysisStep(scf::assertScfForAliasingProperties);
+  opt.addPostAnalysisStep(scf::assertScfForAliasingProperties);
 
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
 
-  if (failed(runComprehensiveBufferize(moduleOp, std::move(options)))) {
+  if (failed(runModuleBufferize(moduleOp, opt))) {
     signalPassFailure();
     return;
   }
 
-  if (testAnalysisOnly)
+  if (opt.testAnalysisOnly)
     return;
 
   OpPassManager cleanupPipeline("builtin.module");
@@ -154,7 +130,7 @@ std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
   return std::make_unique<LinalgComprehensiveModuleBufferize>();
 }
 
-std::unique_ptr<Pass>
-mlir::createLinalgComprehensiveModuleBufferizePass(bool useLinalgCopy) {
-  return std::make_unique<LinalgComprehensiveModuleBufferize>(useLinalgCopy);
+std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass(
+    const AnalysisBufferizationOptions &options) {
+  return std::make_unique<LinalgComprehensiveModuleBufferize>(options);
 }

diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index ae2ccdf076656..98a62a0f3cd6f 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -423,6 +423,29 @@ GenericOp makeTransposeOp(OpBuilder &b, Location loc, Value inputTensor,
   return transposeOp;
 }
 
+GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
+  auto memrefTypeTo = to.getType().cast<MemRefType>();
+#ifndef NDEBUG
+  auto memrefTypeFrom = from.getType().cast<MemRefType>();
+  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
+         "`from` and `to` memref must have the same rank");
+#endif // NDEBUG
+
+  AffineMap id =
+      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
+  SmallVector<StringRef> iteratorTypes(memrefTypeTo.getRank(),
+                                       getParallelIteratorTypeName());
+  return b.create<linalg::GenericOp>(
+      loc,
+      /*inputs=*/from,
+      /*outputs=*/to,
+      /*indexingMaps=*/llvm::makeArrayRef({id, id}),
+      /*iteratorTypes=*/iteratorTypes,
+      [](OpBuilder &b, Location loc, ValueRange args) {
+        b.create<linalg::YieldOp>(loc, args.front());
+      });
+}
+
 /// Specialization to build an scf "for" nest.
 template <>
 void GenerateLoopNest<scf::ForOp>::doit(


        


More information about the Mlir-commits mailing list