[Mlir-commits] [mlir] 87f2dee - [mlir][bufferization] Add DeallocationSimplification pass
Martin Erhart
llvmlistbot at llvm.org
Thu Aug 10 05:46:22 PDT 2023
Author: Martin Erhart
Date: 2023-08-10T12:45:38Z
New Revision: 87f2dee423d283e584d55ea6169bd4b0be0246c0
URL: https://github.com/llvm/llvm-project/commit/87f2dee423d283e584d55ea6169bd4b0be0246c0
DIFF: https://github.com/llvm/llvm-project/commit/87f2dee423d283e584d55ea6169bd4b0be0246c0.diff
LOG: [mlir][bufferization] Add DeallocationSimplification pass
Adds a pass that can be run after buffer deallocation to simplify the deallocation operations.
In particular, there are patterns that need alias information and thus cannot be added as a regular canonicalization pattern.
This initial commit moves an incorrect canonicalization pattern over to this new pass and fixes it by querying the alias analysis for the additional information it needs to be correct (there must not be any potentially aliasing memref in the retain list other than the currently matched one).
It also improves this pattern by considering the `extract_strided_metadata` operation, which is inserted by the deallocation pass by default.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D157398
Added:
mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
mlir/test/Dialect/Bufferization/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 972fa2b0c49d09..8c09ffe17bebd0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -24,6 +24,11 @@ struct OneShotBufferizationOptions;
/// buffers.
std::unique_ptr<Pass> createBufferDeallocationPass();
+/// Creates a pass that optimizes `bufferization.dealloc` operations. For
+/// example, it reduces the number of alias checks needed at runtime using
+/// static alias analysis.
+std::unique_ptr<Pass> createBufferDeallocationSimplificationPass();
+
/// Run buffer deallocation.
LogicalResult deallocateBuffers(Operation *op);
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index d801aec2a2dc24..7f1474e26be432 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -88,6 +88,26 @@ def BufferDeallocation : Pass<"buffer-deallocation", "func::FuncOp"> {
let constructor = "mlir::bufferization::createBufferDeallocationPass()";
}
+def BufferDeallocationSimplification :
+    Pass<"buffer-deallocation-simplification", "func::FuncOp"> {
+  let summary = "Optimizes `bufferization.dealloc` operations for more "
+                "efficient codegen";
+  let description = [{
+    This pass uses static alias analysis to reduce the number of alias checks
+    required at runtime. Such checks are sometimes necessary to make sure that
+    memrefs aren't deallocated before their last usage (use after free) or that
+    some memref isn't deallocated twice (double free).
+  }];
+
+  let constructor =
+    "mlir::bufferization::createBufferDeallocationSimplificationPass()";
+
+  // The simplification patterns create `arith.ori` ops and inspect
+  // `memref.extract_strided_metadata` ops, so those dialects must be loaded.
+  let dependentDialects = [
+    "mlir::bufferization::BufferizationDialect", "mlir::arith::ArithDialect",
+    "mlir::memref::MemRefDialect"
+  ];
+}
+
def BufferHoisting : Pass<"buffer-hoisting", "func::FuncOp"> {
let summary = "Optimizes placement of allocation operations by moving them "
"into common dominators and out of nested regions";
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index d5237164bd0e83..bc6f2cd29523a3 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -869,57 +869,6 @@ struct DeallocRemoveDuplicateRetainedMemrefs
}
};
-/// Remove memrefs to be deallocated that are also present in the retained list
-/// since they will always alias and thus never actually be deallocated.
-/// Example:
-/// ```mlir
-/// %0 = bufferization.dealloc (%arg0 : ...) if (%arg1) retain (%arg0 : ...)
-/// ```
-/// is canonicalized to
-/// ```mlir
-/// %0 = bufferization.dealloc retain (%arg0 : ...)
-/// ```
-struct DeallocRemoveDeallocMemrefsContainedInRetained
- : public OpRewritePattern<DeallocOp> {
- using OpRewritePattern<DeallocOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DeallocOp deallocOp,
- PatternRewriter &rewriter) const override {
- // Unique memrefs to be deallocated.
- DenseMap<Value, unsigned> retained;
- for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained()))
- retained[ret] = i;
-
- // There must not be any duplicates in the retain list anymore because we
- // would miss updating one of the result values otherwise.
- if (retained.size() != deallocOp.getRetained().size())
- return failure();
-
- SmallVector<Value> newMemrefs, newConditions;
- for (auto [memref, cond] :
- llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
- if (retained.contains(memref)) {
- rewriter.setInsertionPointAfter(deallocOp);
- auto orOp = rewriter.create<arith::OrIOp>(
- deallocOp.getLoc(),
- deallocOp.getUpdatedConditions()[retained[memref]], cond);
- rewriter.replaceAllUsesExcept(
- deallocOp.getUpdatedConditions()[retained[memref]],
- orOp.getResult(), orOp);
- continue;
- }
-
- newMemrefs.push_back(memref);
- newConditions.push_back(cond);
- }
-
- // Return failure if we don't change anything such that we don't run into an
- // infinite loop of pattern applications.
- return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
- rewriter);
- }
-};
-
/// Erase deallocation operations where the variadic list of memrefs to
/// deallocate is empty. Example:
/// ```mlir
@@ -1021,8 +970,7 @@ struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DeallocRemoveDuplicateDeallocMemrefs,
- DeallocRemoveDuplicateRetainedMemrefs,
- DeallocRemoveDeallocMemrefsContainedInRetained, EraseEmptyDealloc,
+ DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc>(context);
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
new file mode 100644
index 00000000000000..c5f0450121e73c
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -0,0 +1,188 @@
+//===- BufferDeallocationSimplification.cpp -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for optimizing `bufferization.dealloc` operations
+// that requires more analysis than what can be supported by regular
+// canonicalization patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AliasAnalysis.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+//===----------------------------------------------------------------------===//
+// Helpers
+//===----------------------------------------------------------------------===//
+
+/// Overwrites the `memrefs` and `conditions` operand lists of `deallocOp` with
+/// the given values. If neither list would change, the op is left untouched
+/// and failure is returned, so that a pattern calling this helper does not get
+/// re-applied forever by the greedy rewrite driver.
+static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
+                                            ValueRange memrefs,
+                                            ValueRange conditions,
+                                            PatternRewriter &rewriter) {
+  const bool memrefsUnchanged = deallocOp.getMemrefs() == memrefs;
+  const bool conditionsUnchanged = deallocOp.getConditions() == conditions;
+  if (memrefsUnchanged && conditionsUnchanged)
+    return failure();
+
+  // Notify the rewriter about the in-place operand modification.
+  auto replaceOperands = [&]() {
+    deallocOp.getMemrefsMutable().assign(memrefs);
+    deallocOp.getConditionsMutable().assign(conditions);
+  };
+  rewriter.updateRootInPlace(deallocOp, replaceOperands);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Remove values from the `memref` operand list that are also present in the
+/// `retained` list since they will always alias and thus never actually be
+/// deallocated. However, we also need to be certain that no other value in the
+/// `retained` list can alias, for which we use a static alias analysis. This is
+/// necessary because the `dealloc` operation is defined to return one `i1`
+/// value per memref in the `retained` list which represents the disjunction of
+/// the condition values corresponding to all aliasing values in the `memref`
+/// list. In particular, this means that if there is some value R in the
+/// `retained` list which aliases with a value M in the `memref` list (but can
+/// only be statically determined to may-alias) and M is also present in the
+/// `retained` list, then it would be illegal to remove M because the result
+/// corresponding to R would be computed incorrectly afterwards.
+/// Because we require an alias analysis, this pattern cannot be applied as a
+/// regular canonicalization pattern.
+///
+/// A value in the `memref` list is also matched when it is defined by a
+/// `memref.extract_strided_metadata` op (i.e. it is that op's base buffer)
+/// whose source operand is in the `retained` list. In that case the alias
+/// check is performed against the source operand.
+///
+/// Example:
+/// ```mlir
+/// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0)
+///                        retain (%m0, %r0, %r1 : ...)
+/// ```
+/// is rewritten to
+/// ```mlir
+/// // bufferization.dealloc without memrefs and conditions returns %false for
+/// // every retained value
+/// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...)
+/// %1 = arith.ori %0#0, %cond0 : i1
+/// // replace %0#0 with %1
+/// ```
+/// given that the alias analysis proves that `%r0` and `%r1` do not alias
+/// `%m0`.
+struct DeallocRemoveDeallocMemrefsContainedInRetained
+    : public OpRewritePattern<DeallocOp> {
+  DeallocRemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
+                                                 AliasAnalysis &aliasAnalysis)
+      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    // Map each retained memref to its position in the `retained` list, which
+    // is also the index of its corresponding `updatedConditions` result.
+    DenseMap<Value, unsigned> retained;
+    for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained()))
+      retained[ret] = i;
+
+    // There must not be any duplicates in the retain list anymore because we
+    // would miss updating one of the result values otherwise.
+    if (retained.size() != deallocOp.getRetained().size())
+      return failure();
+
+    SmallVector<Value> newMemrefs, newConditions;
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      // Case 1: the memref to deallocate is itself retained.
+      if (retained.contains(memref) &&
+          foldIntoRetainedResult(deallocOp, memref, cond, retained, rewriter))
+        continue;
+
+      // Case 2: the memref to deallocate is the base buffer extracted from a
+      // retained memref via `memref.extract_strided_metadata`.
+      auto extractOp = memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
+      if (extractOp && retained.contains(extractOp.getOperand()) &&
+          foldIntoRetainedResult(deallocOp, extractOp.getOperand(), cond,
+                                 retained, rewriter))
+        continue;
+
+      newMemrefs.push_back(memref);
+      newConditions.push_back(cond);
+    }
+
+    // Return failure if we don't change anything such that we don't run into an
+    // infinite loop of pattern applications.
+    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+                                  rewriter);
+  }
+
+private:
+  /// Checks whether `retainedValue`, which the caller has verified to be a key
+  /// of `retained`, is known not to alias any other value of the `retained`
+  /// list. If so, inserts `arith.ori(result, cond)` right after `deallocOp` for
+  /// the `updatedConditions` result belonging to `retainedValue`. It then
+  /// redirects all other users of that result to the new value and returns
+  /// true. Otherwise the IR is left untouched and false is returned.
+  bool foldIntoRetainedResult(DeallocOp deallocOp, Value retainedValue,
+                              Value cond,
+                              const DenseMap<Value, unsigned> &retained,
+                              PatternRewriter &rewriter) const {
+    // `lookup` is safe here: callers only pass values contained in `retained`.
+    unsigned resultIdx = retained.lookup(retainedValue);
+
+    // The current memref must not have a may-alias relation to any retained
+    // memref, and exactly one must-alias relation (with itself).
+    // TODO: it is possible to extend this pattern to allow an arbitrary
+    // number of must-alias relations as long as there is no may-alias. If
+    // it's no-alias, then just proceed (only supported case as of now), if
+    // it's must-alias, we also need to update the condition for that alias.
+    bool onlyAliasesItself =
+        llvm::all_of(deallocOp.getRetained(), [&](Value other) {
+          return other == retainedValue ||
+                 aliasAnalysis.alias(other, retainedValue).isNo();
+        });
+    if (!onlyAliasesItself)
+      return false;
+
+    Value oldResult = deallocOp.getUpdatedConditions()[resultIdx];
+    rewriter.setInsertionPointAfter(deallocOp);
+    auto orOp =
+        rewriter.create<arith::OrIOp>(deallocOp.getLoc(), oldResult, cond);
+    // Every user except the new `or` itself now observes the combined value.
+    rewriter.replaceAllUsesExcept(oldResult, orOp.getResult(), orOp);
+    return true;
+  }
+
+  AliasAnalysis &aliasAnalysis;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// BufferDeallocationSimplificationPass
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Simplifies the `bufferization.dealloc` operations of a function. It greedily
+/// applies rewrite patterns that need a static alias analysis and therefore
+/// cannot be registered as regular `DeallocOp` canonicalization patterns.
+struct BufferDeallocationSimplificationPass
+    : public bufferization::impl::BufferDeallocationSimplificationBase<
+          BufferDeallocationSimplificationPass> {
+  void runOnOperation() override {
+    AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
+    RewritePatternSet patterns(&getContext());
+    patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained>(&getContext(),
+                                                                 aliasAnalysis);
+
+    // Propagate a failure reported by the greedy rewrite driver as a pass
+    // failure.
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Factory referenced by the `constructor` of the
+/// `buffer-deallocation-simplification` pass definition.
+std::unique_ptr<Pass>
+mlir::bufferization::createBufferDeallocationSimplificationPass() {
+  return std::make_unique<BufferDeallocationSimplificationPass>();
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 8940c58c08f6aa..4c6731f6aec117 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRBufferizationTransforms
Bufferize.cpp
BufferDeallocation.cpp
+ BufferDeallocationSimplification.cpp
BufferOptimizations.cpp
BufferResultsToOutParams.cpp
BufferUtils.cpp
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
new file mode 100644
index 00000000000000..e5de8569353e35
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s --buffer-deallocation-simplification --split-input-file | FileCheck %s
+
+func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) {
+ %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
+ %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
+ %2:2 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
+ return %0, %1, %2#0, %2#1 : i1, i1, i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_deallocated_in_retained
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
+// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>)
+// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]]
+// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
+// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
+// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
+// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 :
+
+// -----
+
+func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) {
+ %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref<i32>, index, index, index
+ %base_buffer0, %offset0, %size0, %stride0 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref<i32>, index, index, index
+ %0 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0 : memref<2xi32>)
+ %1 = bufferization.dealloc (%base_buffer, %base_buffer0 : memref<i32>, memref<i32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
+ %2:2 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
+ return %0, %1, %2#0, %2#1 : i1, i1, i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_deallocated_in_retained_extract_base_memref
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
+// CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
+// CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
+// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>)
+// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]]
+// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
+// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
+// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
+// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 :
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index af222899e5bbd5..0f0ac678d25110 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -297,19 +297,16 @@ func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg
// -----
-func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1) {
- %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
- %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
+func.func @dealloc_erase_empty(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> i1 {
bufferization.dealloc
- bufferization.dealloc retain (%arg0 : memref<2xi32>)
- return %0, %1 : i1, i1
+ %0 = bufferization.dealloc retain (%arg0 : memref<2xi32>)
+ return %0 : i1
}
-// CHECK-LABEL: func @dealloc_canonicalize_retained_and_deallocated
+// CHECK-LABEL: func @dealloc_erase_empty
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
-// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
-// CHECK-NEXT: [[V1:%.+]] = arith.ori [[V0]], [[ARG1]]
-// CHECK-NEXT: return [[ARG1]], [[V1]] :
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: return [[FALSE]] :
// -----
More information about the Mlir-commits
mailing list