[Mlir-commits] [mlir] 3610c82 - [mlir][bufferization] Enable BufferizationToMemRef to run on functions

Martin Erhart llvmlistbot at llvm.org
Mon Aug 28 08:09:19 PDT 2023


Author: Martin Erhart
Date: 2023-08-28T15:08:48Z
New Revision: 3610c82ce398c8ed46cbeba3cfa0a0c0d100c02f

URL: https://github.com/llvm/llvm-project/commit/3610c82ce398c8ed46cbeba3cfa0a0c0d100c02f
DIFF: https://github.com/llvm/llvm-project/commit/3610c82ce398c8ed46cbeba3cfa0a0c0d100c02f.diff

LOG: [mlir][bufferization] Enable BufferizationToMemRef to run on functions

Modify the `convert-bufferization-to-memref` pass so that it can be applied to
both `builtin.module` and `FunctionOpInterface` operations. In many cases the
lowering that calls the library function is not needed, because enough static
information is available to simplify the dealloc operations to the point where
the more lightweight lowerings apply. In those situations, it is better to allow
the pass to run on functions, both for additional parallelization and for easier
use with the transform dialect (e.g., when the transform ops are in the same
`builtin.module` as the function to be transformed). When the pass runs on a
function, the library helper cannot be inserted, so any dealloc that still
requires the generic lowering is reported as an error.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158969

Added: 
    mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir

Modified: 
    mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
index 90d299181aaef8..c7146239f34b18 100644
--- a/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
+++ b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
@@ -18,7 +18,7 @@ class ModuleOp;
 #define GEN_PASS_DECL_CONVERTBUFFERIZATIONTOMEMREF
 #include "mlir/Conversion/Passes.h.inc"
 
-std::unique_ptr<OperationPass<ModuleOp>> createBufferizationToMemRefPass();
+std::unique_ptr<Pass> createBufferizationToMemRefPass();
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cb06b917b6a9b9..39dd075bc46f17 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -187,8 +187,7 @@ def ConvertAsyncToLLVMPass : Pass<"convert-async-to-llvm", "ModuleOp"> {
 // BufferizationToMemRef
 //===----------------------------------------------------------------------===//
 
-def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref",
-                                        "mlir::ModuleOp"> {
+def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
   let summary = "Convert operations from the Bufferization dialect to the "
                 "MemRef dialect";
   let description = [{

diff  --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 6225e010784789..5c83ae2b36bb28 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -409,6 +409,11 @@ class DeallocOpConversion
     if (adaptor.getMemrefs().size() == 1)
       return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
 
+    if (!deallocHelperFunc)
+      return op->emitError(
+          "library function required for generic lowering, but cannot be "
+          "automatically inserted when operating on functions");
+
     return rewriteGeneralCase(op, adaptor, rewriter);
   }
 
@@ -620,21 +625,29 @@ struct BufferizationToMemRefPass
   BufferizationToMemRefPass() = default;
 
   void runOnOperation() override {
-    ModuleOp module = cast<ModuleOp>(getOperation());
-    OpBuilder builder =
-        OpBuilder::atBlockBegin(&module.getBodyRegion().front());
-    SymbolTable symbolTable(module);
+    if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
+      emitError(getOperation()->getLoc(),
+                "root operation must be a builtin.module or a function");
+      signalPassFailure();
+      return;
+    }
 
-    // Build dealloc helper function if there are deallocs.
     func::FuncOp helperFuncOp;
-    getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
-      if (deallocOp.getMemrefs().size() > 1) {
-        helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction(
-            builder, getOperation()->getLoc(), symbolTable);
-        return WalkResult::interrupt();
-      }
-      return WalkResult::advance();
-    });
+    if (auto module = dyn_cast<ModuleOp>(getOperation())) {
+      OpBuilder builder =
+          OpBuilder::atBlockBegin(&module.getBodyRegion().front());
+      SymbolTable symbolTable(module);
+
+      // Build dealloc helper function if there are deallocs.
+      getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
+        if (deallocOp.getMemrefs().size() > 1) {
+          helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction(
+              builder, getOperation()->getLoc(), symbolTable);
+          return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
+      });
+    }
 
     RewritePatternSet patterns(&getContext());
     patterns.add<CloneOpConversion>(patterns.getContext());
@@ -652,7 +665,6 @@ struct BufferizationToMemRefPass
 };
 } // namespace
 
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createBufferizationToMemRefPass() {
+std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
   return std::make_unique<BufferizationToMemRefPass>();
 }

diff  --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir
new file mode 100644
index 00000000000000..a6dc2c76184cdb
--- /dev/null
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -verify-diagnostics --pass-pipeline="builtin.module(func.func(convert-bufferization-to-memref))" -split-input-file %s | FileCheck %s
+// Scheduled on func.func, the pass cannot insert the dealloc library helper.
+// CHECK-NOT: func @deallocHelper
+// CHECK-LABEL: func @conversion_dealloc_simple
+// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>
+// CHECK-SAME: [[ARG1:%.+]]: i1
+func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) {
+  bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+  return
+}
+
+//      CHECK: scf.if [[ARG1]] {
+// CHECK-NEXT:   memref.dealloc [[ARG0]] : memref<2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+// -----
+// This dealloc needs the generic lowering, whose library helper is unavailable here.
+func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+  // expected-error @below {{library function required for generic lowering, but cannot be automatically inserted when operating on functions}}
+  // expected-error @below {{failed to legalize operation 'bufferization.dealloc' that was explicitly marked illegal}}
+  %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+  return %0#0, %0#1 : i1, i1
+}


        


More information about the Mlir-commits mailing list