[Mlir-commits] [mlir] [mlir] Generalize OneShotModuleBufferize to operate on any Operation (PR #148327)

Evan Liu llvmlistbot at llvm.org
Fri Jul 11 19:52:51 PDT 2025


https://github.com/Evanyl updated https://github.com/llvm/llvm-project/pull/148327

>From 2aa588e7192a4c4afd419a3ea40458cab6443e2a Mon Sep 17 00:00:00 2001
From: Evan Liu <liuyievan at gmail.com>
Date: Fri, 11 Jul 2025 16:05:35 -0700
Subject: [PATCH] [mlir] Generalize OneShotModuleBufferize to operate on any
 Operation

---
 ...duleBufferize.h => OneShotRootBufferize.h} |  32 ++---
 .../BufferizationTransformOps.cpp             |   6 +-
 .../Bufferization/Transforms/Bufferize.cpp    |   5 +-
 .../Bufferization/Transforms/CMakeLists.txt   |   2 +-
 .../Transforms/EmptyTensorElimination.cpp     |   4 +-
 ...Bufferize.cpp => OneShotRootBufferize.cpp} | 124 ++++++++++--------
 .../Transforms/TensorCopyInsertion.cpp        |   4 +-
 .../SparsificationAndBufferizationPass.cpp    |  10 +-
 ...t-root-bufferize-allow-return-allocs.mlir} |   0
 ... => one-shot-root-bufferize-analysis.mlir} |   0
 ...ot-bufferize-force-copy-before-write.mlir} |   0
 ...r => one-shot-root-bufferize-invalid.mlir} |   0
 ...> one-shot-root-bufferize-out-params.mlir} |   0
 ...rize.mlir => one-shot-root-bufferize.mlir} |   0
 .../one-shot-root-non-module-bufferize.mlir   |  33 +++++
 .../lib/Dialect/Bufferization/CMakeLists.txt  |   1 +
 .../TestOneShotRootBufferize.cpp              |  54 ++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         |   9 ++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 19 files changed, 197 insertions(+), 89 deletions(-)
 rename mlir/include/mlir/Dialect/Bufferization/Transforms/{OneShotModuleBufferize.h => OneShotRootBufferize.h} (64%)
 rename mlir/lib/Dialect/Bufferization/Transforms/{OneShotModuleBufferize.cpp => OneShotRootBufferize.cpp} (85%)
 rename mlir/test/Dialect/Bufferization/Transforms/{one-shot-module-bufferize-allow-return-allocs.mlir => one-shot-root-bufferize-allow-return-allocs.mlir} (100%)
 rename mlir/test/Dialect/Bufferization/Transforms/{one-shot-module-bufferize-analysis.mlir => one-shot-root-bufferize-analysis.mlir} (100%)
 rename mlir/test/Dialect/Bufferization/Transforms/{one-shot-module-bufferize-force-copy-before-write.mlir => one-shot-root-bufferize-force-copy-before-write.mlir} (100%)
 rename mlir/test/Dialect/Bufferization/Transforms/{one-shot-module-bufferize-invalid.mlir => one-shot-root-bufferize-invalid.mlir} (100%)
 rename mlir/test/Dialect/Bufferization/Transforms/{one-shot-module-bufferize-out-params.mlir => one-shot-root-bufferize-out-params.mlir} (100%)
 rename mlir/test/Dialect/Bufferization/Transforms/{one-shot-module-bufferize.mlir => one-shot-root-bufferize.mlir} (100%)
 create mode 100644 mlir/test/Dialect/Bufferization/Transforms/one-shot-root-non-module-bufferize.mlir
 create mode 100644 mlir/test/lib/Dialect/Bufferization/TestOneShotRootBufferize.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h
similarity index 64%
rename from mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
rename to mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h
index 2cf801dd1d951..32b76269e2c03 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h
@@ -1,4 +1,4 @@
-//===- OneShotModuleBufferize.h - Bufferization across Func. Boundaries ---===//
+//===- OneShotRootBufferize.h - Bufferization across Func. Boundaries -----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,15 +6,15 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
-#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
+#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
+#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
 
 namespace llvm {
 struct LogicalResult;
 } // namespace llvm
 
 namespace mlir {
-class ModuleOp;
+class Operation;
 
 namespace bufferization {
 struct BufferizationStatistics;
@@ -22,11 +22,11 @@ class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
 class BufferizationState;
 
-/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
+/// Analyze `rootOp` and its nested ops. Bufferization decisions are stored in
 /// `state`.
 llvm::LogicalResult
-analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
-                BufferizationStatistics *statistics = nullptr);
+analyzeRootOp(Operation *rootOp, OneShotAnalysisState &state,
+              BufferizationStatistics *statistics = nullptr);
 
 /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
 ///
@@ -38,23 +38,23 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
 ///   is not empty. The FuncOps it contains were not analyzed. Buffer copies
 ///   will be inserted only to these FuncOps.
 llvm::LogicalResult
-bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
-                  BufferizationState &state,
-                  BufferizationStatistics *statistics = nullptr);
+bufferizeRootOp(Operation *rootOp, const OneShotBufferizationOptions &options,
+                BufferizationState &state,
+                BufferizationStatistics *statistics = nullptr);
 
-/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
-void removeBufferizationAttributesInModule(ModuleOp moduleOp);
+/// Remove bufferization attributes on all FuncOp arguments in `rootOp`.
+void removeBufferizationAttributesInRoot(Operation *rootOp);
 
-/// Run One-Shot Module Bufferization on the given module. Performs a simple
+/// Run One-Shot Root Bufferization on the given root op. Performs a simple
 /// function call analysis to determine which function arguments are
 /// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot
 /// Bufferize.
-llvm::LogicalResult runOneShotModuleBufferize(
-    ModuleOp moduleOp,
+llvm::LogicalResult runOneShotRootBufferize(
+    Operation *rootOp,
     const bufferization::OneShotBufferizationOptions &options,
     BufferizationState &state, BufferizationStatistics *statistics = nullptr);
 
 } // namespace bufferization
 } // namespace mlir
 
-#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
+#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index db1eb20512033..5a52daf6c7698 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -10,7 +10,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -92,8 +92,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
     if (options.bufferizeFunctionBoundaries) {
       if (!moduleOp)
         return emitSilenceableError() << "expected module target";
-      if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
-                                                          bufferizationState)))
+      if (failed(bufferization::runOneShotRootBufferize(moduleOp, options,
+                                                        bufferizationState)))
         return emitSilenceableError() << "bufferization failed";
     } else {
       if (failed(bufferization::runOneShotBufferize(target, options,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 246555dc8c699..baef091eeebd1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -12,7 +12,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Diagnostics.h"
@@ -163,8 +163,7 @@ struct OneShotBufferizePass
     BufferizationStatistics statistics;
     ModuleOp moduleOp = getOperation();
     if (opt.bufferizeFunctionBoundaries) {
-      if (failed(
-              runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
+      if (failed(runOneShotRootBufferize(moduleOp, opt, state, &statistics))) {
         signalPassFailure();
         return;
       }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 7c38621be1bb5..fa310b95df4bd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -11,7 +11,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   FuncBufferizableOpInterfaceImpl.cpp
   LowerDeallocations.cpp
   OneShotAnalysis.cpp
-  OneShotModuleBufferize.cpp
+  OneShotRootBufferize.cpp
   OwnershipBasedBufferDeallocation.cpp
   TensorCopyInsertion.cpp
   OptimizeAllocationLiveness.cpp
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index b7db2e847a335..ee1a9178a9d0d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -11,7 +11,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dominance.h"
@@ -209,7 +209,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
   OneShotAnalysisState state(op, options);
   if (moduleOp) {
     // Module analysis takes into account function boundaries.
-    if (failed(analyzeModuleOp(moduleOp, state)))
+    if (failed(analyzeRootOp(moduleOp, state)))
       return failure();
   } else {
     // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp
similarity index 85%
rename from mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
rename to mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp
index d1d106220a38c..a7865050e6e38 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp
@@ -1,4 +1,4 @@
-//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
+//===- OneShotRootBufferize.cpp - Bufferization across Func. Boundaries ---===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,12 +7,12 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Module Bufferization is an extension of One-Shot Bufferize that
+// Root Bufferization is an extension of One-Shot Bufferize that
 // bufferizes function boundaries. It provides `BufferizableOpInterface`
 // implementations for FuncOp, CallOp and ReturnOp.
 //
-// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
-// This function analyzes the given module and determines the order of analysis
+// Root Bufferization is run via `runOneShotRootBufferize(Operation *, ...)`.
+// This function analyzes the given op and determines the order of analysis
 // and bufferization: Functions that are called are processed before their
 // respective callers.
 //
@@ -24,7 +25,7 @@
 // * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
 //   read/written.
 //
-// Module Bufferization implements the following calling convention.
+// Root Bufferization implements the following calling convention.
 //
 // * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
 //   be written to in-place.
@@ -57,7 +58,7 @@
 // TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
 // as "not reading" and/or "not writing".
 
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -299,7 +300,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
                       llvm::IsaPred<TensorType>);
 }
 
-/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
+/// Store all functions of the `rootOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e., callees without callers first). Store all
 /// remaining functions (i.e., the ones that call each other recursively) in
 /// `remainingFuncOps`. Does not traverse nested symbol tables.
@@ -309,7 +310,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 /// Return `failure()` if we are unable to retrieve the called FuncOp from
 /// any func::CallOp.
 static LogicalResult getFuncOpsOrderedByCalls(
-    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+    Operation *rootOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
     SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
     SymbolTableCollection &symbolTables) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
@@ -317,26 +318,29 @@ static LogicalResult getFuncOpsOrderedByCalls(
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-
-  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
-    // Collect function calls and populate the caller map.
-    numberCallOpsContainedInFuncOp[funcOp] = 0;
-    WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
-      func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
-      assert(calledFunction && "could not retrieved called func::FuncOp");
-      // If the called function does not have any tensors in its signature, then
-      // it is not necessary to bufferize the callee before the caller.
-      if (!hasTensorSignature(calledFunction))
-        return WalkResult::skip();
-
-      callerMap[calledFunction].insert(callOp);
-      if (calledBy[calledFunction].insert(funcOp).second) {
-        numberCallOpsContainedInFuncOp[funcOp]++;
+  for (mlir::Region &region : rootOp->getRegions()) {
+    for (mlir::Block &block : region.getBlocks()) {
+      for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+        // Collect function calls and populate the caller map.
+        numberCallOpsContainedInFuncOp[funcOp] = 0;
+        WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+          func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
+          assert(calledFunction && "could not retrieved called func::FuncOp");
+          // If the called function does not have any tensors in its signature,
+          // then it is not necessary to bufferize the callee before the caller.
+          if (!hasTensorSignature(calledFunction))
+            return WalkResult::skip();
+
+          callerMap[calledFunction].insert(callOp);
+          if (calledBy[calledFunction].insert(funcOp).second) {
+            numberCallOpsContainedInFuncOp[funcOp]++;
+          }
+          return WalkResult::advance();
+        });
+        if (res.wasInterrupted())
+          return failure();
       }
-      return WalkResult::advance();
-    });
-    if (res.wasInterrupted())
-      return failure();
+    }
   }
 
   // Iteratively remove function operations that do not call any of the
@@ -447,9 +451,9 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
 }
 
 LogicalResult
-mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
-                                     OneShotAnalysisState &state,
-                                     BufferizationStatistics *statistics) {
+mlir::bufferization::analyzeRootOp(Operation *rootOp,
+                                   OneShotAnalysisState &state,
+                                   BufferizationStatistics *statistics) {
   assert(state.getOptions().bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
@@ -465,9 +469,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;
 
-  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
-                                      remainingFuncOps, callerMap,
-                                      funcState.symbolTables)))
+  if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps,
+                                      callerMap, funcState.symbolTables)))
     return failure();
 
   // Analyze functions in order. Starting with functions that are not calling
@@ -511,20 +514,24 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   return success();
 }
 
-void mlir::bufferization::removeBufferizationAttributesInModule(
-    ModuleOp moduleOp) {
-  for (auto op : moduleOp.getOps<func::FuncOp>()) {
-    for (BlockArgument bbArg : op.getArguments())
-      removeBufferizationAttributes(bbArg);
+void mlir::bufferization::removeBufferizationAttributesInRoot(
+    Operation *rootOp) {
+  for (mlir::Region &region : rootOp->getRegions()) {
+    for (mlir::Block &block : region.getBlocks()) {
+      for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+        for (BlockArgument bbArg : funcOp.getArguments())
+          removeBufferizationAttributes(bbArg);
+      }
+    }
   }
 }
 
-LogicalResult mlir::bufferization::bufferizeModuleOp(
-    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+LogicalResult mlir::bufferization::bufferizeRootOp(
+    Operation *rootOp, const OneShotBufferizationOptions &options,
     BufferizationState &state, BufferizationStatistics *statistics) {
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
-  IRRewriter rewriter(moduleOp.getContext());
+  IRRewriter rewriter(rootOp->getContext());
 
   // A list of non-circular functions in the order in which they are analyzed
   // and bufferized.
@@ -542,9 +549,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   // accurate buffer types for function return values. Functions that call
   // each other recursively are bufferized in an unspecified order at the end.
   // We may use unnecessarily "complex" (in terms of layout map) buffer types.
-  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
-                                      remainingFuncOps, callerMap,
-                                      state.getSymbolTables())))
+  if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps,
+                                      callerMap, state.getSymbolTables())))
     return failure();
   llvm::append_range(orderedFuncOps, remainingFuncOps);
 
@@ -571,22 +577,27 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   }
 
   // Bufferize all other ops.
-  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
-    // Functions were already bufferized.
-    if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
-      continue;
-    if (failed(bufferizeOp(&op, options, state, statistics)))
-      return failure();
+  for (mlir::Region &region : rootOp->getRegions()) {
+    for (mlir::Block &block : region.getBlocks()) {
+      for (mlir::Operation &op :
+           llvm::make_early_inc_range(block.getOperations())) {
+        // Functions were already bufferized.
+        if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
+          continue;
+        if (failed(bufferizeOp(&op, options, state, statistics)))
+          return failure();
+      }
+    }
   }
 
   // Post-pass cleanup of function argument attributes.
-  removeBufferizationAttributesInModule(moduleOp);
+  removeBufferizationAttributesInRoot(rootOp);
 
   return success();
 }
 
-LogicalResult mlir::bufferization::runOneShotModuleBufferize(
-    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+LogicalResult mlir::bufferization::runOneShotRootBufferize(
+    Operation *rootOp, const OneShotBufferizationOptions &options,
     BufferizationState &state, BufferizationStatistics *statistics) {
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
@@ -594,7 +605,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
          "invalid combination of bufferization flags");
   if (!options.copyBeforeWrite) {
     if (options.noAnalysisFuncFilter.empty()) {
-      if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
+      if (failed(insertTensorCopies(rootOp, options, state, statistics)))
         return failure();
     } else {
       // FuncOps whose names are specified in options.noAnalysisFuncFilter will
@@ -610,14 +621,13 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
       };
       OneShotBufferizationOptions updatedOptions(options);
       updatedOptions.opFilter.denyOperation(analysisFilterFn);
-      if (failed(
-              insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
+      if (failed(insertTensorCopies(rootOp, updatedOptions, state, statistics)))
         return failure();
     }
   }
   if (options.testAnalysisOnly)
     return success();
-  if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
+  if (failed(bufferizeRootOp(rootOp, options, state, statistics)))
     return failure();
   return success();
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 784d95a5dd22a..84dc48fb3237d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -12,7 +12,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 
@@ -35,7 +35,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
   // analysis depending on whether function boundary bufferization is enabled or
   // not.
   if (options.bufferizeFunctionBoundaries) {
-    if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
+    if (failed(analyzeRootOp(op, analysisState, statistics)))
       return failure();
   } else {
     if (failed(analyzeOp(op, analysisState, statistics)))
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 15e5102462ad7..c14b2784be6bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -13,7 +13,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -116,12 +116,12 @@ class SparsificationAndBufferizationPass
 
     bufferization::BufferizationState bufferizationState;
 
-    if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
-                                                updatedOptions,
-                                                bufferizationState)))
+    if (failed(bufferization::bufferizeRootOp(cast<ModuleOp>(getOperation()),
+                                              updatedOptions,
+                                              bufferizationState)))
       return failure();
 
-    bufferization::removeBufferizationAttributesInModule(getOperation());
+    bufferization::removeBufferizationAttributesInRoot(getOperation());
     return success();
   }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-allow-return-allocs.mlir
similarity index 100%
rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir
rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-allow-return-allocs.mlir
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-analysis.mlir
similarity index 100%
rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-analysis.mlir
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-force-copy-before-write.mlir
similarity index 100%
rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir
rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-force-copy-before-write.mlir
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-invalid.mlir
similarity index 100%
rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-invalid.mlir
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-out-params.mlir
similarity index 100%
rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-out-params.mlir
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize.mlir
similarity index 100%
rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize.mlir
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-non-module-bufferize.mlir
new file mode 100644
index 0000000000000..25ff512a885d7
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-non-module-bufferize.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-root-bufferize))' -split-input-file | FileCheck %s
+
+"test.symbol_scope_isolated"() ({
+  // CHECK-LABEL: func @inner_func(
+  //  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+  func.func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
+    // CHECK-NOT: copy
+    %f = arith.constant 1.0 : f32
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    // CHECK: memref.store %{{.*}}, %[[arg0]]
+    %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+    // CHECK: %[[load:.*]] = memref.load %[[arg0]]
+    %1 = tensor.extract %0[%c1] : tensor<?xf32>
+    // CHECK: return %[[arg0]], %[[load]] : memref<?xf32{{.*}}>, f32
+    return %0, %1 : tensor<?xf32>, f32
+  }
+
+  // CHECK-LABEL: func @call_func_with_non_tensor_return(
+  //  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+  func.func @call_func_with_non_tensor_return(
+      %t0: tensor<?xf32> {bufferization.writable = true}) -> (f32, tensor<?xf32>) {
+    // CHECK-NOT: alloc
+    // CHECK-NOT: copy
+    // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]])
+    %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
+    // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}>
+    return %1, %0 : f32, tensor<?xf32>
+  }
+  "test.finish" () : () -> ()
+}) : () -> ()
+
+
diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
index 226e0bb97732d..50e1e9f8e06c7 100644
--- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRBufferizationTestPasses
+  TestOneShotRootBufferize.cpp
   TestTensorCopyInsertion.cpp
   TestTensorLikeAndBufferLike.cpp
 
diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotRootBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotRootBufferize.cpp
new file mode 100644
index 0000000000000..30f0b312eb1af
--- /dev/null
+++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotRootBufferize.cpp
@@ -0,0 +1,54 @@
+//===- TestOneShotRootBufferize.cpp - Bufferization Test --------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct TestOneShotRootBufferizePass
+    : public PassWrapper<TestOneShotRootBufferizePass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotRootBufferizePass)
+
+  TestOneShotRootBufferizePass() = default;
+  TestOneShotRootBufferizePass(const TestOneShotRootBufferizePass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect>();
+  }
+  StringRef getArgument() const final { return "test-one-shot-root-bufferize"; }
+  StringRef getDescription() const final {
+    return "Op-agnostic pass to test One-Shot Root Bufferization";
+  }
+
+  void runOnOperation() override {
+    // Root bufferization always bufferizes function boundaries;
+    // runOneShotRootBufferize asserts that the option is set, so it is
+    // enabled unconditionally here.
+    bufferization::OneShotBufferizationOptions opt;
+    opt.bufferizeFunctionBoundaries = true;
+
+    bufferization::BufferizationState bufferizationState;
+
+    if (failed(bufferization::runOneShotRootBufferize(getOperation(), opt,
+                                                      bufferizationState)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestOneShotRootBufferizePass() {
+  PassRegistration<TestOneShotRootBufferizePass>();
+}
+} // namespace mlir::test
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ab3f847ca2acf..2727ad34b23cb 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -125,6 +125,15 @@ def SymbolScopeOp : TEST_Op<"symbol_scope",
   let regions = (region SizedRegion<1>:$region);
 }
 
+def SymbolScopeIsolatedOp
+    : TEST_Op<"symbol_scope_isolated", [IsolatedFromAbove, SymbolTable,
+                                        SingleBlockImplicitTerminator<
+                                            "TerminatorOp">]> {
+  let summary =
+      "operation which defines a new symbol table that is IsolatedFromAbove";
+  let regions = (region SizedRegion<1>:$region);
+}
+
 def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> {
   let summary =  "operation which defines a new symbol table without a "
                  "restriction on a terminator";
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 143a5e8e8f8dd..b0aa5c86615e4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -135,6 +135,7 @@ void registerTestMeshSimplificationsPass();
 void registerTestMultiBuffering();
 void registerTestNextAccessPass();
 void registerTestNVGPULowerings();
+void registerTestOneShotRootBufferizePass();
 void registerTestOpaqueLoc();
 void registerTestOpLoweringPasses();
 void registerTestPadFusion();
@@ -281,6 +282,7 @@ void registerTestPasses() {
   mlir::test::registerTestMultiBuffering();
   mlir::test::registerTestNextAccessPass();
   mlir::test::registerTestNVGPULowerings();
+  mlir::test::registerTestOneShotRootBufferizePass();
   mlir::test::registerTestOpaqueLoc();
   mlir::test::registerTestOpLoweringPasses();
   mlir::test::registerTestPadFusion();



More information about the Mlir-commits mailing list