[Mlir-commits] [mlir] 07c079a - [mlir][bufferization] Add lowering of bufferization.dealloc to memref.dealloc
Martin Erhart
llvmlistbot at llvm.org
Wed Jul 19 07:28:57 PDT 2023
Author: Martin Erhart
Date: 2023-07-19T14:28:01Z
New Revision: 07c079a97adf7f028b63aeba9952c75d5f8da030
URL: https://github.com/llvm/llvm-project/commit/07c079a97adf7f028b63aeba9952c75d5f8da030
DIFF: https://github.com/llvm/llvm-project/commit/07c079a97adf7f028b63aeba9952c75d5f8da030.diff
LOG: [mlir][bufferization] Add lowering of bufferization.dealloc to memref.dealloc
Adds a generic lowering that supports all cases of bufferization.dealloc
and one specialized, more efficient lowering for the simple case. Using
a helper function with for loops in the general case enables
O(|num_dealloc_memrefs|+|num_retain_memrefs|) size of the lowered code.
Depends on D155467
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D155468
Added:
Modified:
mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
index 73112bccdbc171..90d299181aaef8 100644
--- a/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
+++ b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
@@ -9,21 +9,16 @@
#ifndef MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
#define MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
+#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
-class Pass;
-class RewritePatternSet;
+class ModuleOp;
#define GEN_PASS_DECL_CONVERTBUFFERIZATIONTOMEMREF
#include "mlir/Conversion/Passes.h.inc"
-/// Collect a set of patterns to convert memory-related operations from the
-/// Bufferization dialect to the MemRef dialect.
-void populateBufferizationToMemRefConversionPatterns(
- RewritePatternSet &patterns);
-
-std::unique_ptr<Pass> createBufferizationToMemRefPass();
+std::unique_ptr<OperationPass<ModuleOp>> createBufferizationToMemRefPass();
} // namespace mlir
#endif // MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 767843b73098ab..de07f3e4ccbaf0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -167,7 +167,8 @@ def ConvertAsyncToLLVMPass : Pass<"convert-async-to-llvm", "ModuleOp"> {
// BufferizationToMemRef
//===----------------------------------------------------------------------===//
-def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
+def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref",
+ "mlir::ModuleOp"> {
let summary = "Convert operations from the Bufferization dialect to the "
"MemRef dialect";
let description = [{
@@ -195,7 +196,10 @@ def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
}];
let constructor = "mlir::createBufferizationToMemRefPass()";
- let dependentDialects = ["arith::ArithDialect", "memref::MemRefDialect"];
+ let dependentDialects = [
+ "arith::ArithDialect", "memref::MemRefDialect", "scf::SCFDialect",
+ "func::FuncDialect"
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index a01fc9f117c2be..886bcfa2f8530d 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -15,7 +15,9 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
@@ -77,12 +79,317 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
return success();
}
};
-} // namespace
-void mlir::populateBufferizationToMemRefConversionPatterns(
- RewritePatternSet &patterns) {
- patterns.add<CloneOpConversion>(patterns.getContext());
-}
+/// The DeallocOpConversion transforms all bufferization dealloc operations into
+/// memref dealloc operations potentially guarded by scf if operations.
+/// Additionally, memref extract_aligned_pointer_as_index and arith operations
+/// are inserted to compute the guard conditions. We distinguish multiple cases
+/// to provide an overall more efficient lowering. In the general case, a helper
+/// func is created to avoid quadratic code size explosion (relative to the
+/// number of operands of the dealloc operation). For examples of each case,
+/// refer to the documentation of the member functions of this class.
+class DeallocOpConversion
+ : public OpConversionPattern<bufferization::DeallocOp> {
+
+ /// Lower a simple case avoiding the helper function. Ideally, static analysis
+ /// can provide enough aliasing information to split the dealloc operations up
+ /// into this simple case as much as possible before running this pass.
+ ///
+ /// Example:
+ /// ```
+ /// %0 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+ /// ```
+ /// is lowered to
+ /// ```
+ /// scf.if %arg1 {
+ /// memref.dealloc %arg0 : memref<2xf32>
+ /// }
+ /// %0 = arith.constant false
+ /// ```
+ LogicalResult
+ rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ rewriter.create<scf::IfOp>(op.getLoc(), adaptor.getConditions()[0],
+ [&](OpBuilder &builder, Location loc) {
+ builder.create<memref::DeallocOp>(
+ loc, adaptor.getMemrefs()[0]);
+ builder.create<scf::YieldOp>(loc);
+ });
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(op,
+ rewriter.getBoolAttr(false));
+ return success();
+ }
+
+ /// Lowering that supports all features the dealloc operation has to offer. It
+ /// computes the base pointer of each memref (as an index), stores them in a
+ /// new memref and passes it to the helper function generated in
+ /// 'buildDeallocationHelperFunction'. The two return values are used as
+ /// condition for the scf if operation containing the memref deallocate and as
+ /// replacement for the original bufferization dealloc respectively.
+ ///
+ /// Example:
+ /// ```
+ /// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>)
+ /// if (%arg3, %arg4) retain (%arg2 : memref<1xf32>)
+ /// ```
+ /// lowers to (simplified):
+ /// ```
+ /// %c0 = arith.constant 0 : index
+ /// %c1 = arith.constant 1 : index
+ /// %alloc = memref.alloc() : memref<2xindex>
+ /// %alloc_0 = memref.alloc() : memref<1xindex>
+ /// %intptr = memref.extract_aligned_pointer_as_index %arg0
+ /// memref.store %intptr, %alloc[%c0] : memref<2xindex>
+ /// %intptr_1 = memref.extract_aligned_pointer_as_index %arg1
+ /// memref.store %intptr_1, %alloc[%c1] : memref<2xindex>
+ /// %intptr_2 = memref.extract_aligned_pointer_as_index %arg2
+ /// memref.store %intptr_2, %alloc_0[%c0] : memref<1xindex>
+ /// %cast = memref.cast %alloc : memref<2xindex> to memref<?xindex>
+ /// %cast_4 = memref.cast %alloc_0 : memref<1xindex> to memref<?xindex>
+ /// %0:2 = call @dealloc_helper(%cast, %cast_4, %c0)
+ /// %1 = arith.andi %0#0, %arg3 : i1
+ /// %2 = arith.andi %0#1, %arg3 : i1
+ /// scf.if %1 {
+ /// memref.dealloc %arg0 : memref<2xf32>
+ /// }
+ /// %3:2 = call @dealloc_helper(%cast, %cast_4, %c1)
+ /// %4 = arith.andi %3#0, %arg4 : i1
+ /// %5 = arith.andi %3#1, %arg4 : i1
+ /// scf.if %4 {
+ /// memref.dealloc %arg1 : memref<5xf32>
+ /// }
+ /// memref.dealloc %alloc : memref<2xindex>
+ /// memref.dealloc %alloc_0 : memref<1xindex>
+ /// // replace %0#0 with %2
+ /// // replace %0#1 with %5
+ /// ```
+ LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Allocate two memrefs holding the base pointer indices of the list of
+ // memrefs to be deallocated and the ones to be retained. These can then be
+ // passed to the helper function and the for-loops can iterate over them.
+ // Without storing them to memrefs, we could not use for-loops but only a
+ // completely unrolled version of it, potentially leading to code-size
+ // blow-up.
+ Value toDeallocMemref = rewriter.create<memref::AllocOp>(
+ op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
+ rewriter.getIndexType()));
+ Value toRetainMemref = rewriter.create<memref::AllocOp>(
+ op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
+ rewriter.getIndexType()));
+
+ auto getConstValue = [&](uint64_t value) -> Value {
+ return rewriter.create<arith::ConstantOp>(op.getLoc(),
+ rewriter.getIndexAttr(value));
+ };
+
+ // Extract the base pointers of the memrefs as indices to check for aliasing
+ // at runtime.
+ for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) {
+ Value memrefAsIdx =
+ rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
+ toDealloc);
+ rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
+ toDeallocMemref, getConstValue(i));
+ }
+ for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
+ Value memrefAsIdx =
+ rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
+ toRetain);
+ rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref,
+ getConstValue(i));
+ }
+
+ // Cast the allocated memrefs to dynamic shape because we want only one
+ // helper function no matter how many operands the bufferization.dealloc
+ // has.
+ Value castedDeallocMemref = rewriter.create<memref::CastOp>(
+ op->getLoc(),
+ MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
+ toDeallocMemref);
+ Value castedRetainMemref = rewriter.create<memref::CastOp>(
+ op->getLoc(),
+ MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
+ toRetainMemref);
+
+ SmallVector<Value> replacements;
+ for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
+ auto callOp = rewriter.create<func::CallOp>(
+ op.getLoc(), deallocHelperFunc,
+ SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
+ getConstValue(i)});
+ Value shouldDealloc = rewriter.create<arith::AndIOp>(
+ op.getLoc(), callOp.getResult(0), adaptor.getConditions()[i]);
+ Value ownership = rewriter.create<arith::AndIOp>(
+ op.getLoc(), callOp.getResult(1), adaptor.getConditions()[i]);
+ replacements.push_back(ownership);
+ rewriter.create<scf::IfOp>(
+ op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
+ builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
+ builder.create<scf::YieldOp>(loc);
+ });
+ }
+
+ // Deallocate above allocated memrefs again to avoid memory leaks.
+ // Deallocation will not be run on code after this stage.
+ rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
+ rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
+
+ rewriter.replaceOp(op, replacements);
+ return success();
+ }
+
+public:
+ DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
+ : OpConversionPattern<bufferization::DeallocOp>(context),
+ deallocHelperFunc(deallocHelperFunc) {}
+
+ LogicalResult
+ matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Lower the trivial case.
+ if (adaptor.getMemrefs().empty())
+ return rewriter.eraseOp(op), success();
+
+ if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
+ return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
+
+ return rewriteGeneralCase(op, adaptor, rewriter);
+ }
+
+ /// Build a helper function per compilation unit that can be called at
+ /// bufferization dealloc sites to determine aliasing and ownership.
+ ///
+ /// The generated function takes two memrefs of indices and one index value as
+ /// arguments and returns two boolean values:
+ /// * The first memref argument A should contain the result of the
+ /// extract_aligned_pointer_as_index operation applied to the memrefs to be
+ /// deallocated
+ /// * The second memref argument B should contain the result of the
+ /// extract_aligned_pointer_as_index operation applied to the memrefs to be
+ /// retained
+ /// * The index argument I represents the currently processed index of
+ /// memref A and is needed because aliasing with all previously deallocated
+ /// memrefs has to be checked to avoid double deallocation
+ /// * The first result indicates whether the memref at position I should be
+ /// deallocated
+ /// * The second result provides the updated ownership value corresponding
+ /// to the memref at position I
+ ///
+ /// This helper function is supposed to be called for each element in the list
+ /// of memrefs to be deallocated to determine the deallocation need and new
+ /// ownership indicator, but does not perform the deallocation itself.
+ ///
+ /// The first scf for loop in the body computes whether the memref at index I
+ /// aliases with any memref in the list of retained memrefs.
+ /// The second loop additionally checks whether one of the previously
+ /// deallocated memrefs aliases with the currently processed one.
+ ///
+ /// Generated code:
+ /// ```
+ /// func.func @dealloc_helper(%arg0: memref<?xindex>,
+ /// %arg1: memref<?xindex>,
+ /// %arg2: index) -> (i1, i1) {
+ /// %c0 = arith.constant 0 : index
+ /// %c1 = arith.constant 1 : index
+ /// %true = arith.constant true
+ /// %dim = memref.dim %arg1, %c0 : memref<?xindex>
+ /// %0 = memref.load %arg0[%arg2] : memref<?xindex>
+ /// %1 = scf.for %i = %c0 to %dim step %c1 iter_args(%arg4 = %true) -> (i1){
+ /// %4 = memref.load %arg1[%i] : memref<?xindex>
+ /// %5 = arith.cmpi ne, %4, %0 : index
+ /// %6 = arith.andi %arg4, %5 : i1
+ /// scf.yield %6 : i1
+ /// }
+ /// %2 = scf.for %i = %c0 to %arg2 step %c1 iter_args(%arg4 = %1) -> (i1) {
+ /// %4 = memref.load %arg0[%i] : memref<?xindex>
+ /// %5 = arith.cmpi ne, %4, %0 : index
+ /// %6 = arith.andi %arg4, %5 : i1
+ /// scf.yield %6 : i1
+ /// }
+ /// %3 = arith.xori %1, %true : i1
+ /// return %2, %3 : i1, i1
+ /// }
+ /// ```
+ static func::FuncOp
+ buildDeallocationHelperFunction(OpBuilder &builder, Location loc,
+ SymbolTable &symbolTable) {
+ Type idxType = builder.getIndexType();
+ Type memrefArgType = MemRefType::get({ShapedType::kDynamic}, idxType);
+ SmallVector<Type> argTypes{memrefArgType, memrefArgType, idxType};
+ builder.clearInsertionPoint();
+
+ // Generate the func operation itself.
+ auto helperFuncOp = func::FuncOp::create(
+ loc, "dealloc_helper",
+ builder.getFunctionType(argTypes,
+ {builder.getI1Type(), builder.getI1Type()}));
+ symbolTable.insert(helperFuncOp);
+ auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
+ block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc));
+
+ builder.setInsertionPointToStart(&block);
+ Value toDeallocMemref = helperFuncOp.getArguments()[0];
+ Value toRetainMemref = helperFuncOp.getArguments()[1];
+ Value idxArg = helperFuncOp.getArguments()[2];
+
+ // Insert some prerequisites.
+ Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
+ Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
+ Value trueValue =
+ builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
+ Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0);
+ Value toDealloc =
+ builder.create<memref::LoadOp>(loc, toDeallocMemref, idxArg);
+
+ // Build the first for loop that computes aliasing with retained memrefs.
+ Value noRetainAlias =
+ builder
+ .create<scf::ForOp>(
+ loc, c0, toRetainSize, c1, trueValue,
+ [&](OpBuilder &builder, Location loc, Value i,
+ ValueRange iterArgs) {
+ Value retainValue =
+ builder.create<memref::LoadOp>(loc, toRetainMemref, i);
+ Value doesntAlias = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne, retainValue, toDealloc);
+ Value yieldValue = builder.create<arith::AndIOp>(
+ loc, iterArgs[0], doesntAlias);
+ builder.create<scf::YieldOp>(loc, yieldValue);
+ })
+ .getResult(0);
+
+ // Build the second for loop that adds aliasing with previously deallocated
+ // memrefs.
+ Value noAlias =
+ builder
+ .create<scf::ForOp>(
+ loc, c0, idxArg, c1, noRetainAlias,
+ [&](OpBuilder &builder, Location loc, Value i,
+ ValueRange iterArgs) {
+ Value prevDeallocValue =
+ builder.create<memref::LoadOp>(loc, toDeallocMemref, i);
+ Value doesntAlias = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne, prevDeallocValue,
+ toDealloc);
+ Value yieldValue = builder.create<arith::AndIOp>(
+ loc, iterArgs[0], doesntAlias);
+ builder.create<scf::YieldOp>(loc, yieldValue);
+ })
+ .getResult(0);
+
+ Value ownership =
+ builder.create<arith::XOrIOp>(loc, noRetainAlias, trueValue);
+ builder.create<func::ReturnOp>(loc, SmallVector<Value>{noAlias, ownership});
+
+ return helperFuncOp;
+ }
+
+private:
+ func::FuncOp deallocHelperFunc;
+};
+} // namespace
namespace {
struct BufferizationToMemRefPass
@@ -90,12 +397,30 @@ struct BufferizationToMemRefPass
BufferizationToMemRefPass() = default;
void runOnOperation() override {
+ ModuleOp module = cast<ModuleOp>(getOperation());
+ OpBuilder builder =
+ OpBuilder::atBlockBegin(&module.getBodyRegion().front());
+ SymbolTable symbolTable(module);
+
+ // Build dealloc helper function if there are deallocs.
+ func::FuncOp helperFuncOp;
+ getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
+ if (deallocOp.getMemrefs().size() > 1 ||
+ !deallocOp.getRetained().empty()) {
+ helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction(
+ builder, getOperation()->getLoc(), symbolTable);
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+
RewritePatternSet patterns(&getContext());
- populateBufferizationToMemRefConversionPatterns(patterns);
+ patterns.add<CloneOpConversion>(patterns.getContext());
+ patterns.add<DeallocOpConversion>(patterns.getContext(), helperFuncOp);
ConversionTarget target(getContext());
- target.addLegalDialect<memref::MemRefDialect>();
- target.addLegalOp<arith::ConstantOp>();
+ target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
+ scf::SCFDialect, func::FuncDialect>();
target.addIllegalDialect<bufferization::BufferizationDialect>();
if (failed(applyPartialConversion(getOperation(), target,
@@ -105,6 +430,7 @@ struct BufferizationToMemRefPass
};
} // namespace
-std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createBufferizationToMemRefPass() {
return std::make_unique<BufferizationToMemRefPass>();
}
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
index 3fce07be8b138c..28d9f3b6e6ac32 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
+++ b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
@@ -9,6 +9,10 @@ add_mlir_conversion_library(MLIRBufferizationToMemRef
LINK_LIBS PUBLIC
MLIRBufferizationDialect
+ MLIRSCFDialect
+ MLIRFuncDialect
+ MLIRArithDialect
+ MLIRMemRefDialect
MLIRPass
MLIRTransforms
)
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
index e73c170b5b99e9..aed232d29d175c 100644
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -66,3 +66,93 @@ func.func @conversion_with_invalid_layout_map(%arg0 : memref<?xf32, strided<[10]
memref.dealloc %arg0 : memref<?xf32, strided<[10], offset: ?>>
return %1 : memref<?xf32, strided<[10], offset: ?>>
}
+// -----
+
+// CHECK-LABEL: func @conversion_dealloc_empty
+func.func @conversion_dealloc_empty() {
+ // CHECK-NEXT: return
+ bufferization.dealloc
+ return
+}
+
+// -----
+
+// CHECK-NOT: func @dealloc_helper
+// CHECK-LABEL: func @conversion_dealloc_simple
+// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>
+// CHECK-SAME: [[ARG1:%.+]]: i1
+func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) -> i1 {
+ %0 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+ return %0 : i1
+}
+
+// CHECK: scf.if [[ARG1]] {
+// CHECK-NEXT: memref.dealloc [[ARG0]] : memref<2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: return [[FALSE]] : i1
+
+// -----
+
+func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2 : memref<1xf32>)
+ return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @conversion_dealloc_multiple_memrefs_and_retained
+// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>,
+// CHECK-SAME: [[ARG1:%.+]]: memref<5xf32>,
+// CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>,
+// CHECK-SAME: [[ARG3:%.+]]: i1,
+// CHECK-SAME: [[ARG4:%.+]]: i1
+// CHECK: [[TO_DEALLOC_MR:%.+]] = memref.alloc() : memref<2xindex>
+// CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<1xindex>
+// CHECK-DAG: [[V0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: memref.store [[V0]], [[TO_DEALLOC_MR]][[[C0]]]
+// CHECK-DAG: [[V1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]]
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK-DAG: memref.store [[V1]], [[TO_DEALLOC_MR]][[[C1]]]
+// CHECK-DAG: [[V2:%.+]] = memref.extract_aligned_pointer_as_index [[ARG2]]
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: memref.store [[V2]], [[TO_RETAIN_MR]][[[C0]]]
+// CHECK-DAG: [[CAST_DEALLOC:%.+]] = memref.cast [[TO_DEALLOC_MR]] : memref<2xindex> to memref<?xindex>
+// CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<1xindex> to memref<?xindex>
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK: [[RES0:%.+]]:2 = call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[C0]])
+// CHECK: [[SHOULD_DEALLOC_0:%.+]] = arith.andi [[RES0]]#0, [[ARG3]]
+// CHECK: [[OWNERSHIP0:%.+]] = arith.andi [[RES0]]#1, [[ARG3]]
+// CHECK: scf.if [[SHOULD_DEALLOC_0]] {
+// CHECK: memref.dealloc %arg0
+// CHECK: }
+// CHECK: [[C1:%.+]] = arith.constant 1 : index
+// CHECK: [[RES1:%.+]]:2 = call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[C1]])
+// CHECK: [[SHOULD_DEALLOC_1:%.+]] = arith.andi [[RES1:%.+]]#0, [[ARG4]]
+// CHECK: [[OWNERSHIP1:%.+]] = arith.andi [[RES1:%.+]]#1, [[ARG4]]
+// CHECK: scf.if [[SHOULD_DEALLOC_1]]
+// CHECK: memref.dealloc [[ARG1]]
+// CHECK: }
+// CHECK: memref.dealloc [[TO_DEALLOC_MR]]
+// CHECK: memref.dealloc [[TO_RETAIN_MR]]
+// CHECK: return [[OWNERSHIP0]], [[OWNERSHIP1]]
+
+// CHECK: func @dealloc_helper
+// CHECK-SAME: [[ARG0:%.+]]: memref<?xindex>, [[ARG1:%.+]]: memref<?xindex>
+// CHECK-SAME: [[ARG2:%.+]]: index
+// CHECK-SAME: -> (i1, i1)
+// CHECK: [[TO_RETAIN_SIZE:%.+]] = memref.dim [[ARG1]], %c0
+// CHECK: [[TO_DEALLOC:%.+]] = memref.load [[ARG0]][[[ARG2]]] : memref<?xindex>
+// CHECK-NEXT: [[NO_RETAIN_ALIAS:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 iter_args([[ITER_ARG:%.+]] = %true) -> (i1) {
+// CHECK-NEXT: [[RETAIN_VAL:%.+]] = memref.load [[ARG1]][[[ITER]]] : memref<?xindex>
+// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT: [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
+// CHECK-NEXT: scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK-NEXT: }
+// CHECK-NEXT: [[SHOULD_DEALLOC:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[ARG2]] step %c1 iter_args([[ITER_ARG:%.+]] = [[NO_RETAIN_ALIAS]]) -> (i1) {
+// CHECK-NEXT: [[OTHER_DEALLOC_VAL:%.+]] = memref.load [[ARG0]][[[ITER]]] : memref<?xindex>
+// CHECK-NEXT: [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT: [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
+// CHECK-NEXT: scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK-NEXT: }
+// CHECK-NEXT: [[OWNERSHIP:%.+]] = arith.xori [[NO_RETAIN_ALIAS]], %true : i1
+// CHECK-NEXT: return [[SHOULD_DEALLOC]], [[OWNERSHIP]] : i1, i1
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e6b24a1da43daa..c535beffa44604 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11689,6 +11689,7 @@ cc_library(
":IR",
":MemRefDialect",
":Pass",
+ ":SCFDialect",
":Support",
":Transforms",
],
More information about the Mlir-commits
mailing list