[Mlir-commits] [mlir] [mlir][bufferize] Update empty_tensor_elimination transform op (PR #68497)

Matthias Springer llvmlistbot at llvm.org
Sat Oct 7 14:00:00 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/68497

The semantics of the empty tensor elimination pass have changed recently: when the pass is applied to a module, the One-Shot Module Analysis is run; when it is applied to any other op, the regular One-Shot Analysis is run. The latter is slightly different because it ignores function boundaries and treats function block arguments as "read-only".

This commit updates the transform dialect op to behave in the same way.

From bcc7329f72a71c8a747af8a79008530703158564 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 7 Oct 2023 22:56:15 +0200
Subject: [PATCH] [mlir][bufferize] Update empty_tensor_elimination transform
 op

The empty tensor elimination pass semantics have changed recently: when
applied to a module, the One-Shot Module Analysis is run. Otherwise, the
regular One-Shot Analysis is run. The latter one is slightly different
because it ignores function boundaries and treats function block
arguments as "read-only".

This commit updates the transform dialect op to behave in the same way.
---
 .../Bufferization/Transforms/Transforms.h     |  7 ++++++
 .../BufferizationTransformOps.cpp             | 11 ++-------
 .../Transforms/EmptyTensorElimination.cpp     | 24 +++++++++----------
 3 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index df866daf1ab1ff6..892675954493b99 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -32,6 +32,13 @@ struct OneShotBufferizationOptions;
 /// In the above example, the subset op is "tensor.insert_slice". When tracing
 /// back the reverse use-def chain of a the source, we end up at a
 /// "tensor.empty" op.
+LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op);
+
+/// Try to eliminate "tensor.empty" ops inside `op`.
+///
+/// This function overload accepts an existing `OneShotAnalysisState`, which
+/// contains in-place bufferization decisions. This overload is useful if an
+/// existing analysis should be reused for empty tensor elimination.
 LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
                                     OneShotAnalysisState &state);
 
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index b7db4917a4138ec..cbb36639a3383d7 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -117,17 +117,10 @@ void transform::EliminateEmptyTensorsOp::getEffects(
 DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
     transform::TransformRewriter &rewriter, TransformResults &transformResults,
     TransformState &state) {
-  OneShotBufferizationOptions options;
-  options.allowReturnAllocsFromLoops = true;
-
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    OneShotAnalysisState state(target, options);
-    if (failed(analyzeOp(target, state)))
-      return mlir::emitSilenceableFailure(target->getLoc())
-             << "failed to analyze op";
-    if (failed(bufferization::eliminateEmptyTensors(rewriter, target, state)))
+    if (failed(bufferization::eliminateEmptyTensors(rewriter, target)))
       return mlir::emitSilenceableFailure(target->getLoc())
-             << "failed to eliminate insert_slice anchored tensor.empty ops";
+             << "empty tensor elimination failed";
   }
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 2ddc51357448a7c..1a5a65bfac132a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -183,8 +183,8 @@ struct EmptyTensorElimination
 };
 } // namespace
 
-void EmptyTensorElimination::runOnOperation() {
-  Operation *op = getOperation();
+LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
+                                                         Operation *op) {
   auto moduleOp = dyn_cast<ModuleOp>(op);
   OneShotBufferizationOptions options;
   options.allowReturnAllocsFromLoops = true;
@@ -193,21 +193,21 @@ void EmptyTensorElimination::runOnOperation() {
   OneShotAnalysisState state(op, options);
   if (moduleOp) {
     // Module analysis takes into account function boundaries.
-    if (failed(analyzeModuleOp(moduleOp, state))) {
-      signalPassFailure();
-      return;
-    }
+    if (failed(analyzeModuleOp(moduleOp, state)))
+      return failure();
   } else {
     // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
     // func.return.
-    if (failed(analyzeOp(op, state))) {
-      signalPassFailure();
-      return;
-    }
+    if (failed(analyzeOp(op, state)))
+      return failure();
   }
 
-  IRRewriter rewriter(op->getContext());
-  if (failed(bufferization::eliminateEmptyTensors(rewriter, op, state)))
+  return bufferization::eliminateEmptyTensors(rewriter, op, state);
+}
+
+void EmptyTensorElimination::runOnOperation() {
+  IRRewriter rewriter(getOperation()->getContext());
+  if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
     signalPassFailure();
 }
 



More information about the Mlir-commits mailing list