[Mlir-commits] [mlir] 07c079a - [mlir][bufferization] Add lowering of bufferization.dealloc to memref.dealloc

Martin Erhart llvmlistbot at llvm.org
Wed Jul 19 07:28:57 PDT 2023

Author: Martin Erhart
Date: 2023-07-19T14:28:01Z
New Revision: 07c079a97adf7f028b63aeba9952c75d5f8da030

URL: https://github.com/llvm/llvm-project/commit/07c079a97adf7f028b63aeba9952c75d5f8da030
DIFF: https://github.com/llvm/llvm-project/commit/07c079a97adf7f028b63aeba9952c75d5f8da030.diff

LOG: [mlir][bufferization] Add lowering of bufferization.dealloc to memref.dealloc

Adds a generic lowering that supports all cases of bufferization.dealloc
and one specialized, more efficient lowering for the simple case. Using
a helper function with for loops in the general case enables
O(|num_dealloc_memrefs|+|num_retain_memrefs|) size of the lowered code.

Depends on D155467

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D155468




diff  --git a/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
index 73112bccdbc171..90d299181aaef8 100644
--- a/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
+++ b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
@@ -9,21 +9,16 @@
+#include "mlir/Pass/Pass.h"
 #include <memory>
 namespace mlir {
-class Pass;
-class RewritePatternSet;
+class ModuleOp;
 #include "mlir/Conversion/Passes.h.inc"
-/// Collect a set of patterns to convert memory-related operations from the
-/// Bufferization dialect to the MemRef dialect.
-void populateBufferizationToMemRefConversionPatterns(
-    RewritePatternSet &patterns);
-std::unique_ptr<Pass> createBufferizationToMemRefPass();
+std::unique_ptr<OperationPass<ModuleOp>> createBufferizationToMemRefPass();
 } // namespace mlir

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 767843b73098ab..de07f3e4ccbaf0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -167,7 +167,8 @@ def ConvertAsyncToLLVMPass : Pass<"convert-async-to-llvm", "ModuleOp"> {
 // BufferizationToMemRef
-def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
+def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref",
+                                        "mlir::ModuleOp"> {
   let summary = "Convert operations from the Bufferization dialect to the "
                 "MemRef dialect";
   let description = [{
@@ -195,7 +196,10 @@ def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
   let constructor = "mlir::createBufferizationToMemRefPass()";
-  let dependentDialects = ["arith::ArithDialect", "memref::MemRefDialect"];
+  let dependentDialects = [
+    "arith::ArithDialect", "memref::MemRefDialect", "scf::SCFDialect",
+    "func::FuncDialect"
+  ];

diff  --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index a01fc9f117c2be..886bcfa2f8530d 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -15,7 +15,9 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
@@ -77,12 +79,317 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
     return success();
-} // namespace
-void mlir::populateBufferizationToMemRefConversionPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<CloneOpConversion>(patterns.getContext());
+/// The DeallocOpConversion transforms all bufferization dealloc operations into
+/// memref dealloc operations potentially guarded by scf if operations.
+/// Additionally, memref extract_aligned_pointer_as_index and arith operations
+/// are inserted to compute the guard conditions. We distinguish multiple cases
+/// to provide an overall more efficient lowering. In the general case, a helper
+/// func is created to avoid quadratic code size explosion (relative to the
+/// number of operands of the dealloc operation). For examples of each case,
+/// refer to the documentation of the member functions of this class.
+class DeallocOpConversion
+    : public OpConversionPattern<bufferization::DeallocOp> {
+  /// Lower a simple case avoiding the helper function. Ideally, static analysis
+  /// can provide enough aliasing information to split the dealloc operations up
+  /// into this simple case as much as possible before running this pass.
+  ///
+  /// Example:
+  /// ```
+  /// %0 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+  /// ```
+  /// is lowered to
+  /// ```
+  /// scf.if %arg1 {
+  ///   memref.dealloc %arg0 : memref<2xf32>
+  /// }
+  /// %0 = arith.constant false
+  /// ```
+  LogicalResult
+  rewriteOneMemrefNoRetainCase(bufferization::DeallocOp op, OpAdaptor adaptor,
+                               ConversionPatternRewriter &rewriter) const {
+    rewriter.create<scf::IfOp>(op.getLoc(), adaptor.getConditions()[0],
+                               [&](OpBuilder &builder, Location loc) {
+                                 builder.create<memref::DeallocOp>(
+                                     loc, adaptor.getMemrefs()[0]);
+                                 builder.create<scf::YieldOp>(loc);
+                               });
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op,
+                                                   rewriter.getBoolAttr(false));
+    return success();
+  }
+  /// Lowering that supports all features the dealloc operation has to offer. It
+  /// computes the base pointer of each memref (as an index), stores them in a
+  /// new memref and passes it to the helper function generated in
+  /// 'buildDeallocationHelperFunction'. The two return values are used as
+  /// condition for the scf if operation containing the memref deallocate and as
+  /// replacement for the original bufferization dealloc respectively.
+  ///
+  /// Example:
+  /// ```
+  /// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>)
+  ///                           if (%arg3, %arg4) retain (%arg2 : memref<1xf32>)
+  /// ```
+  /// lowers to (simplified):
+  /// ```
+  /// %c0 = arith.constant 0 : index
+  /// %c1 = arith.constant 1 : index
+  /// %alloc = memref.alloc() : memref<2xindex>
+  /// %alloc_0 = memref.alloc() : memref<1xindex>
+  /// %intptr = memref.extract_aligned_pointer_as_index %arg0
+  /// memref.store %intptr, %alloc[%c0] : memref<2xindex>
+  /// %intptr_1 = memref.extract_aligned_pointer_as_index %arg1
+  /// memref.store %intptr_1, %alloc[%c1] : memref<2xindex>
+  /// %intptr_2 = memref.extract_aligned_pointer_as_index %arg2
+  /// memref.store %intptr_2, %alloc_0[%c0] : memref<1xindex>
+  /// %cast = memref.cast %alloc : memref<2xindex> to memref<?xindex>
+  /// %cast_4 = memref.cast %alloc_0 : memref<1xindex> to memref<?xindex>
+  /// %0:2 = call @dealloc_helper(%cast, %cast_4, %c0)
+  /// %1 = arith.andi %0#0, %arg3 : i1
+  /// %2 = arith.andi %0#1, %arg3 : i1
+  /// scf.if %1 {
+  ///   memref.dealloc %arg0 : memref<2xf32>
+  /// }
+  /// %3:2 = call @dealloc_helper(%cast, %cast_4, %c1)
+  /// %4 = arith.andi %3#0, %arg4 : i1
+  /// %5 = arith.andi %3#1, %arg4 : i1
+  /// scf.if %4 {
+  ///   memref.dealloc %arg1 : memref<5xf32>
+  /// }
+  /// memref.dealloc %alloc : memref<2xindex>
+  /// memref.dealloc %alloc_0 : memref<1xindex>
+  /// // replace %0#0 with %2
+  /// // replace %0#1 with %5
+  /// ```
+  LogicalResult rewriteGeneralCase(bufferization::DeallocOp op,
+                                   OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter) const {
+    // Allocate two memrefs holding the base pointer indices of the list of
+    // memrefs to be deallocated and the ones to be retained. These can then be
+    // passed to the helper function and the for-loops can iterate over them.
+    // Without storing them to memrefs, we could not use for-loops but only a
+    // completely unrolled version of it, potentially leading to code-size
+    // blow-up.
+    Value toDeallocMemref = rewriter.create<memref::AllocOp>(
+        op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()},
+                                     rewriter.getIndexType()));
+    Value toRetainMemref = rewriter.create<memref::AllocOp>(
+        op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()},
+                                     rewriter.getIndexType()));
+    auto getConstValue = [&](uint64_t value) -> Value {
+      return rewriter.create<arith::ConstantOp>(op.getLoc(),
+                                                rewriter.getIndexAttr(value));
+    };
+    // Extract the base pointers of the memrefs as indices to check for aliasing
+    // at runtime.
+    for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) {
+      Value memrefAsIdx =
+          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
+                                                                  toDealloc);
+      rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx,
+                                       toDeallocMemref, getConstValue(i));
+    }
+    for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) {
+      Value memrefAsIdx =
+          rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(),
+                                                                  toRetain);
+      rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref,
+                                       getConstValue(i));
+    }
+    // Cast the allocated memrefs to dynamic shape because we want only one
+    // helper function no matter how many operands the bufferization.dealloc
+    // has.
+    Value castedDeallocMemref = rewriter.create<memref::CastOp>(
+        op->getLoc(),
+        MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
+        toDeallocMemref);
+    Value castedRetainMemref = rewriter.create<memref::CastOp>(
+        op->getLoc(),
+        MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()),
+        toRetainMemref);
+    SmallVector<Value> replacements;
+    for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) {
+      auto callOp = rewriter.create<func::CallOp>(
+          op.getLoc(), deallocHelperFunc,
+          SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
+                             getConstValue(i)});
+      Value shouldDealloc = rewriter.create<arith::AndIOp>(
+          op.getLoc(), callOp.getResult(0), adaptor.getConditions()[i]);
+      Value ownership = rewriter.create<arith::AndIOp>(
+          op.getLoc(), callOp.getResult(1), adaptor.getConditions()[i]);
+      replacements.push_back(ownership);
+      rewriter.create<scf::IfOp>(
+          op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
+            builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]);
+            builder.create<scf::YieldOp>(loc);
+          });
+    }
+    // Deallocate the memrefs allocated above to avoid memory leaks.
+    // Deallocation will not be run on code after this stage.
+    rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref);
+    rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref);
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+  DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
+      : OpConversionPattern<bufferization::DeallocOp>(context),
+        deallocHelperFunc(deallocHelperFunc) {}
+  LogicalResult
+  matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Lower the trivial case.
+    if (adaptor.getMemrefs().empty())
+      return rewriter.eraseOp(op), success();
+    if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
+      return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
+    return rewriteGeneralCase(op, adaptor, rewriter);
+  }
+  /// Build a helper function per compilation unit that can be called at
+  /// bufferization dealloc sites to determine aliasing and ownership.
+  ///
+  /// The generated function takes two memrefs of indices and one index value as
+  /// arguments and returns two boolean values:
+  ///   * The first memref argument A should contain the result of the
+  ///   extract_aligned_pointer_as_index operation applied to the memrefs to be
+  ///   deallocated
+  ///   * The second memref argument B should contain the result of the
+  ///   extract_aligned_pointer_as_index operation applied to the memrefs to be
+  ///   retained
+  ///   * The index argument I represents the currently processed index of
+  ///   memref A and is needed because aliasing with all previously deallocated
+  ///   memrefs has to be checked to avoid double deallocation
+  ///   * The first result indicates whether the memref at position I should be
+  ///   deallocated
+  ///   * The second result provides the updated ownership value corresponding
+  ///   to the memref at position I
+  ///
+  /// This helper function is supposed to be called for each element in the list
+  /// of memrefs to be deallocated to determine the deallocation need and new
+  /// ownership indicator, but does not perform the deallocation itself.
+  ///
+  /// The first scf for loop in the body computes whether the memref at index I
+  /// aliases with any memref in the list of retained memrefs.
+  /// The second loop additionally checks whether one of the previously
+  /// deallocated memrefs aliases with the currently processed one.
+  ///
+  /// Generated code:
+  /// ```
+  /// func.func @dealloc_helper(%arg0: memref<?xindex>,
+  ///                           %arg1: memref<?xindex>,
+  ///                           %arg2: index) -> (i1, i1) {
+  ///   %c0 = arith.constant 0 : index
+  ///   %c1 = arith.constant 1 : index
+  ///   %true = arith.constant true
+  ///   %dim = memref.dim %arg1, %c0 : memref<?xindex>
+  ///   %0 = memref.load %arg0[%arg2] : memref<?xindex>
+  ///   %1 = scf.for %i = %c0 to %dim step %c1 iter_args(%arg4 = %true) -> (i1){
+  ///     %4 = memref.load %arg1[%i] : memref<?xindex>
+  ///     %5 = arith.cmpi ne, %4, %0 : index
+  ///     %6 = arith.andi %arg4, %5 : i1
+  ///     scf.yield %6 : i1
+  ///   }
+  ///   %2 = scf.for %i = %c0 to %arg2 step %c1 iter_args(%arg4 = %1) -> (i1) {
+  ///     %4 = memref.load %arg0[%i] : memref<?xindex>
+  ///     %5 = arith.cmpi ne, %4, %0 : index
+  ///     %6 = arith.andi %arg4, %5 : i1
+  ///     scf.yield %6 : i1
+  ///   }
+  ///   %3 = arith.xori %1, %true : i1
+  ///   return %2, %3 : i1, i1
+  /// }
+  /// ```
+  static func::FuncOp
+  buildDeallocationHelperFunction(OpBuilder &builder, Location loc,
+                                  SymbolTable &symbolTable) {
+    Type idxType = builder.getIndexType();
+    Type memrefArgType = MemRefType::get({ShapedType::kDynamic}, idxType);
+    SmallVector<Type> argTypes{memrefArgType, memrefArgType, idxType};
+    builder.clearInsertionPoint();
+    // Generate the func operation itself.
+    auto helperFuncOp = func::FuncOp::create(
+        loc, "dealloc_helper",
+        builder.getFunctionType(argTypes,
+                                {builder.getI1Type(), builder.getI1Type()}));
+    symbolTable.insert(helperFuncOp);
+    auto &block = helperFuncOp.getFunctionBody().emplaceBlock();
+    block.addArguments(argTypes, SmallVector<Location>(argTypes.size(), loc));
+    builder.setInsertionPointToStart(&block);
+    Value toDeallocMemref = helperFuncOp.getArguments()[0];
+    Value toRetainMemref = helperFuncOp.getArguments()[1];
+    Value idxArg = helperFuncOp.getArguments()[2];
+    // Insert some prerequisites.
+    Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0));
+    Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
+    Value trueValue =
+        builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true));
+    Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0);
+    Value toDealloc =
+        builder.create<memref::LoadOp>(loc, toDeallocMemref, idxArg);
+    // Build the first for loop that computes aliasing with retained memrefs.
+    Value noRetainAlias =
+        builder
+            .create<scf::ForOp>(
+                loc, c0, toRetainSize, c1, trueValue,
+                [&](OpBuilder &builder, Location loc, Value i,
+                    ValueRange iterArgs) {
+                  Value retainValue =
+                      builder.create<memref::LoadOp>(loc, toRetainMemref, i);
+                  Value doesntAlias = builder.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, retainValue, toDealloc);
+                  Value yieldValue = builder.create<arith::AndIOp>(
+                      loc, iterArgs[0], doesntAlias);
+                  builder.create<scf::YieldOp>(loc, yieldValue);
+                })
+            .getResult(0);
+    // Build the second for loop that adds aliasing with previously deallocated
+    // memrefs.
+    Value noAlias =
+        builder
+            .create<scf::ForOp>(
+                loc, c0, idxArg, c1, noRetainAlias,
+                [&](OpBuilder &builder, Location loc, Value i,
+                    ValueRange iterArgs) {
+                  Value prevDeallocValue =
+                      builder.create<memref::LoadOp>(loc, toDeallocMemref, i);
+                  Value doesntAlias = builder.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::ne, prevDeallocValue,
+                      toDealloc);
+                  Value yieldValue = builder.create<arith::AndIOp>(
+                      loc, iterArgs[0], doesntAlias);
+                  builder.create<scf::YieldOp>(loc, yieldValue);
+                })
+            .getResult(0);
+    Value ownership =
+        builder.create<arith::XOrIOp>(loc, noRetainAlias, trueValue);
+    builder.create<func::ReturnOp>(loc, SmallVector<Value>{noAlias, ownership});
+    return helperFuncOp;
+  }
+  func::FuncOp deallocHelperFunc;
+} // namespace
 namespace {
 struct BufferizationToMemRefPass
@@ -90,12 +397,30 @@ struct BufferizationToMemRefPass
   BufferizationToMemRefPass() = default;
   void runOnOperation() override {
+    ModuleOp module = cast<ModuleOp>(getOperation());
+    OpBuilder builder =
+        OpBuilder::atBlockBegin(&module.getBodyRegion().front());
+    SymbolTable symbolTable(module);
+    // Build the dealloc helper function only if some dealloc needs the general
+    // lowering (more than one memref or at least one retained memref).
+    func::FuncOp helperFuncOp;
+    getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
+      if (deallocOp.getMemrefs().size() > 1 ||
+          !deallocOp.getRetained().empty()) {
+        helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction(
+            builder, getOperation()->getLoc(), symbolTable);
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
     RewritePatternSet patterns(&getContext());
-    populateBufferizationToMemRefConversionPatterns(patterns);
+    patterns.add<CloneOpConversion>(patterns.getContext());
+    patterns.add<DeallocOpConversion>(patterns.getContext(), helperFuncOp);
     ConversionTarget target(getContext());
-    target.addLegalDialect<memref::MemRefDialect>();
-    target.addLegalOp<arith::ConstantOp>();
+    target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
+                           scf::SCFDialect, func::FuncDialect>();
     if (failed(applyPartialConversion(getOperation(), target,
@@ -105,6 +430,7 @@ struct BufferizationToMemRefPass
 } // namespace
-std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
+mlir::createBufferizationToMemRefPass() {
   return std::make_unique<BufferizationToMemRefPass>();

diff  --git a/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
index 3fce07be8b138c..28d9f3b6e6ac32 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
+++ b/mlir/lib/Conversion/BufferizationToMemRef/CMakeLists.txt
@@ -9,6 +9,10 @@ add_mlir_conversion_library(MLIRBufferizationToMemRef
+  MLIRSCFDialect
+  MLIRFuncDialect
+  MLIRArithDialect
+  MLIRMemRefDialect

diff  --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
index e73c170b5b99e9..aed232d29d175c 100644
--- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir
@@ -66,3 +66,93 @@ func.func @conversion_with_invalid_layout_map(%arg0 : memref<?xf32, strided<[10]
   memref.dealloc %arg0 : memref<?xf32, strided<[10], offset: ?>>
   return %1 : memref<?xf32, strided<[10], offset: ?>>
+// -----
+// CHECK-LABEL: func @conversion_dealloc_empty
+func.func @conversion_dealloc_empty() {
+  // CHECK-NEXT: return
+  bufferization.dealloc
+  return
+// -----
+// CHECK-NOT: func @dealloc_helper
+// CHECK-LABEL: func @conversion_dealloc_simple
+// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>
+// CHECK-SAME: [[ARG1:%.+]]: i1
+func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) -> i1 {
+  %0 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+  return %0 : i1
+//      CHECK: scf.if [[ARG1]] {
+// CHECK-NEXT:   memref.dealloc [[ARG0]] : memref<2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: return [[FALSE]] : i1
+// -----
+func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1) -> (i1, i1) {
+  %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2 : memref<1xf32>)
+  return %0#0, %0#1 : i1, i1
+// CHECK-LABEL: func @conversion_dealloc_multiple_memrefs_and_retained
+// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>,
+// CHECK-SAME: [[ARG1:%.+]]: memref<5xf32>,
+// CHECK-SAME: [[ARG2:%.+]]: memref<1xf32>,
+// CHECK-SAME: [[ARG3:%.+]]: i1,
+// CHECK-SAME: [[ARG4:%.+]]: i1
+//      CHECK: [[TO_DEALLOC_MR:%.+]] = memref.alloc() : memref<2xindex>
+//      CHECK: [[TO_RETAIN_MR:%.+]] = memref.alloc() : memref<1xindex>
+//  CHECK-DAG: [[V0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
+//  CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+//  CHECK-DAG: memref.store [[V0]], [[TO_DEALLOC_MR]][[[C0]]]
+//  CHECK-DAG: [[V1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]]
+//  CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+//  CHECK-DAG: memref.store [[V1]], [[TO_DEALLOC_MR]][[[C1]]]
+//  CHECK-DAG: [[V2:%.+]] = memref.extract_aligned_pointer_as_index [[ARG2]]
+//  CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+//  CHECK-DAG: memref.store [[V2]], [[TO_RETAIN_MR]][[[C0]]]
+//  CHECK-DAG: [[CAST_DEALLOC:%.+]] = memref.cast [[TO_DEALLOC_MR]] : memref<2xindex> to memref<?xindex>
+//  CHECK-DAG: [[CAST_RETAIN:%.+]] = memref.cast [[TO_RETAIN_MR]] : memref<1xindex> to memref<?xindex>
+//  CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+//      CHECK: [[RES0:%.+]]:2 = call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[C0]])
+//      CHECK: [[SHOULD_DEALLOC_0:%.+]] = arith.andi [[RES0]]#0, [[ARG3]]
+//      CHECK: [[OWNERSHIP0:%.+]] = arith.andi [[RES0]]#1, [[ARG3]]
+//      CHECK: scf.if [[SHOULD_DEALLOC_0]] {
+//      CHECK:   memref.dealloc %arg0
+//      CHECK: }
+//      CHECK: [[C1:%.+]] = arith.constant 1 : index
+//      CHECK: [[RES1:%.+]]:2 = call @dealloc_helper([[CAST_DEALLOC]], [[CAST_RETAIN]], [[C1]])
+//      CHECK: [[SHOULD_DEALLOC_1:%.+]] = arith.andi [[RES1]]#0, [[ARG4]]
+//      CHECK: [[OWNERSHIP1:%.+]] = arith.andi [[RES1]]#1, [[ARG4]]
+//      CHECK: scf.if [[SHOULD_DEALLOC_1]]
+//      CHECK:   memref.dealloc [[ARG1]]
+//      CHECK: }
+//      CHECK: memref.dealloc [[TO_DEALLOC_MR]]
+//      CHECK: memref.dealloc [[TO_RETAIN_MR]]
+//      CHECK: return [[OWNERSHIP0]], [[OWNERSHIP1]]
+//      CHECK: func @dealloc_helper
+// CHECK-SAME: [[ARG0:%.+]]: memref<?xindex>, [[ARG1:%.+]]: memref<?xindex>
+// CHECK-SAME: [[ARG2:%.+]]: index
+// CHECK-SAME:  -> (i1, i1)
+//      CHECK:   [[TO_RETAIN_SIZE:%.+]] = memref.dim [[ARG1]], %c0
+//      CHECK:   [[TO_DEALLOC:%.+]] = memref.load [[ARG0]][[[ARG2]]] : memref<?xindex>
+// CHECK-NEXT:   [[NO_RETAIN_ALIAS:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[TO_RETAIN_SIZE]] step %c1 iter_args([[ITER_ARG:%.+]] = %true) -> (i1) {
+// CHECK-NEXT:     [[RETAIN_VAL:%.+]] = memref.load [[ARG1]][[[ITER]]] : memref<?xindex>
+// CHECK-NEXT:     [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[RETAIN_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT:     [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
+// CHECK-NEXT:     scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK-NEXT:   }
+// CHECK-NEXT:   [[SHOULD_DEALLOC:%.+]] = scf.for [[ITER:%.+]] = %c0 to [[ARG2]] step %c1 iter_args([[ITER_ARG:%.+]] = [[NO_RETAIN_ALIAS]]) -> (i1) {
+// CHECK-NEXT:     [[OTHER_DEALLOC_VAL:%.+]] = memref.load [[ARG0]][[[ITER]]] : memref<?xindex>
+// CHECK-NEXT:     [[DOES_ALIAS:%.+]] = arith.cmpi ne, [[OTHER_DEALLOC_VAL]], [[TO_DEALLOC]] : index
+// CHECK-NEXT:     [[AGG_DOES_ALIAS:%.+]] = arith.andi [[ITER_ARG]], [[DOES_ALIAS]] : i1
+// CHECK-NEXT:     scf.yield [[AGG_DOES_ALIAS]] : i1
+// CHECK-NEXT:   }
+// CHECK-NEXT:   [[OWNERSHIP:%.+]] = arith.xori [[NO_RETAIN_ALIAS]], %true : i1
+// CHECK-NEXT:   return [[SHOULD_DEALLOC]], [[OWNERSHIP]] : i1, i1

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e6b24a1da43daa..c535beffa44604 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11689,6 +11689,7 @@ cc_library(
+        ":SCFDialect",


More information about the Mlir-commits mailing list