[Mlir-commits] [mlir] 17aaa65 - [mlir][bufferization] Add DeallocOp canonicalizer to remove duplicate values

Martin Erhart llvmlistbot at llvm.org
Fri Jul 28 09:28:39 PDT 2023


Author: Martin Erhart
Date: 2023-07-28T16:27:32Z
New Revision: 17aaa651dbfbc79adfe73d5010252dc90dc5752f

URL: https://github.com/llvm/llvm-project/commit/17aaa651dbfbc79adfe73d5010252dc90dc5752f
DIFF: https://github.com/llvm/llvm-project/commit/17aaa651dbfbc79adfe73d5010252dc90dc5752f.diff

LOG: [mlir][bufferization] Add DeallocOp canonicalizer to remove duplicate values

Duplicate values in the retained list can simply be removed. However, for duplicates in the list of memrefs to deallocate, we also need to check the conditions, and if they don't match, we need to compute their OR so that no case is missed, since a missed case would lead to a memory leak.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D156157

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
    mlir/test/Dialect/Bufferization/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
index 23eb2abf66d1a7..cd77645bedeeca 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
@@ -28,7 +28,8 @@ def Bufferization_Dialect : Dialect {
     deallocation](/docs/BufferDeallocationInternals/).
   }];
   let dependentDialects = [
-    "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
+    "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect",
+    "arith::ArithDialect"
   ];
 
   let extraClassDeclaration = [{

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index a91ec5bad6aaf9..55182f83832b47 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -514,6 +514,7 @@ def Bufferization_DeallocOp : Bufferization_Op<"dealloc", [
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 #endif // BUFFERIZATION_OPS

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ab9af5cef987d4..6f475735d633c5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -766,6 +766,86 @@ LogicalResult DeallocOp::verify() {
   return success();
 }
 
+namespace {
+
+/// Remove duplicate values in the list of retained memrefs as well as the list
+/// of memrefs to be deallocated. For the latter, we need to make sure the
+/// corresponding condition values match as well, or otherwise have to combine
+/// them by computing their disjunction (arith.ori).
+/// Example:
+/// ```mlir
+/// %0:2 = bufferization.dealloc (%arg0, %arg0 : ...)
+///                           if (%arg1, %arg2)
+///                       retain (%arg3, %arg3 : ...)
+/// ```
+/// is canonicalized to
+/// ```mlir
+/// %0 = arith.ori %arg1, %arg2 : i1
+/// %1 = bufferization.dealloc (%arg0 : memref<2xi32>)
+///                         if (%0)
+///                     retain (%arg3 : memref<2xi32>)
+/// ```
+struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    // Deduplicate memrefs to be deallocated (first occurrence wins).
+    DenseMap<Value, unsigned> memrefToCondition;
+    SmallVector<Value> newMemrefs, newConditions, newRetained;
+    SmallVector<unsigned> resultIndices; // Old result i -> new result index.
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      if (memrefToCondition.count(memref)) {
+        // Duplicate memref: if the conditions differ, OR them so that the
+        // memref is still deallocated whenever either condition holds.
+        Value &newCond = newConditions[memrefToCondition[memref]];
+        if (newCond != cond)
+          newCond =
+              rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond);
+      } else {
+        memrefToCondition.insert({memref, newConditions.size()});
+        newMemrefs.push_back(memref);
+        newConditions.push_back(cond);
+      }
+      resultIndices.push_back(memrefToCondition[memref]);
+    }
+
+    // Unique retained values, keeping the order of first occurrence.
+    DenseSet<Value> seen;
+    for (auto retained : deallocOp.getRetained()) {
+      if (!seen.contains(retained)) {
+        seen.insert(retained);
+        newRetained.push_back(retained);
+      }
+    }
+
+    // Return failure if nothing changed to avoid an infinite loop of pattern
+    // applications. Any arith.ori above implies a change, so none leaks here.
+    if (newConditions.size() == deallocOp.getConditions().size() &&
+        newRetained.size() == deallocOp.getRetained().size())
+      return failure();
+
+    // We need to create a new op because the number of results is always the
+    // same as the number of condition operands.
+    auto newDealloc = rewriter.create<DeallocOp>(deallocOp.getLoc(), newMemrefs,
+                                                 newConditions, newRetained);
+    for (auto [i, newIdx] : llvm::enumerate(resultIndices))
+      rewriter.replaceAllUsesWith(deallocOp.getResult(i),
+                                  newDealloc.getResult(newIdx));
+
+    rewriter.eraseOp(deallocOp);
+    return success();
+  }
+};
+
+} // anonymous namespace
+
+void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<DeallocRemoveDuplicates>(context); // Merge duplicate operands.
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index e77414067e3a84..5b6e850643daec 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
+  MLIRArithDialect
   MLIRDestinationStyleOpInterface
   MLIRDialect
   MLIRFuncDialect

diff  --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 07959e29cd788f..d8088b627d23a8 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -279,3 +279,18 @@ func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<
 //   CHECK-NOT:   bufferization.clone
 //   CHECK-NOT:   memref.dealloc
 //       CHECK:   return {{.*}}
+
+// -----
+
+func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>) -> (i1, i1, i1, i1, i1) {
+  %0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg4, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
+  %1:2 = bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2) // Differing conditions: expect arith.ori.
+  return %0#0, %0#1, %0#2, %1#0, %1#1 : i1, i1, i1, i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_canonicalize_duplicates
+//  CHECK-SAME:  ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>)
+//  CHECK-NEXT:   [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG4]] : memref<2xi32>, memref<2xi32>)
+//  CHECK-NEXT:   [[NEW_COND:%.+]] = arith.ori [[ARG1]], [[ARG2]] : i1
+//  CHECK-NEXT:   [[V1:%.+]] = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]])
+//  CHECK-NEXT:   return [[V0]]#0, [[V0]]#1, [[V0]]#1, [[V1]], [[V1]] :


        


More information about the Mlir-commits mailing list