[Mlir-commits] [mlir] 3f6c0fb - [mlir][linalg][bufferize] Add MemCpyFn to AllocationCallbacks struct

Matthias Springer llvmlistbot at llvm.org
Thu Nov 4 18:49:22 PDT 2021


Author: Matthias Springer
Date: 2021-11-05T10:44:12+09:00
New Revision: 3f6c0fb2ff750c9246aee41eb8ad086518752edf

URL: https://github.com/llvm/llvm-project/commit/3f6c0fb2ff750c9246aee41eb8ad086518752edf
DIFF: https://github.com/llvm/llvm-project/commit/3f6c0fb2ff750c9246aee41eb8ad086518752edf.diff

LOG: [mlir][linalg][bufferize] Add MemCpyFn to AllocationCallbacks struct

This is in preparation for decoupling BufferizableOpInterface, Comprehensive Bufferize, and dialects.

The goal of this CL is to make `getResultBuffer` (and other `bufferize` functions) independent of `LinalgOps`.

Differential Revision: https://reviews.llvm.org/D112907

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index e3b59d5daa60..94cb52b4bca5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -172,16 +172,28 @@ Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
 /// `defaultAllocationFn`.
 void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
 
+/// Default memory copy function that is used by the comprehensive bufferization
+/// pass. Creates a `linalg.copy` op.
+void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to);
+
 /// Callback functions that are used by the comprehensive bufferization pass to
 /// allocate/deallocate memory. These default to use the
 /// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the
 /// caller. The `deallocationFn` is gauranteed to recieve the `Value` returned
 /// by the `allocationFn`.
 struct AllocationCallbacks {
-  std::function<Optional<Value>(OpBuilder &b, Location loc, Value shapedValue)>
-      allocationFn = defaultAllocationFn;
-  std::function<void(OpBuilder &b, Location loc, Value v)> deallocationFn =
-      defaultDeallocationFn;
+  using AllocationFn =
+      std::function<Optional<Value>(OpBuilder &, Location, Value)>;
+  using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
+  using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
+
+  AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
+                      MemCpyFn copyFn)
+      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
+
+  AllocationFn allocationFn;
+  DeallocationFn deallocationFn;
+  MemCpyFn memCpyFn;
 };
 
 /// Bufferize one particular op.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index dad19a94080d..a6b6e0131d19 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -1274,7 +1274,7 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
     if (!skipCopy) {
       // Set insertion point now that potential alloc/dealloc are introduced.
       b.setInsertionPoint(op);
-      b.create<CopyOp>(loc, operandBuffer, resultBuffer);
+      allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
     }
     return resultBuffer;
   }
@@ -1669,6 +1669,11 @@ void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
   b.create<memref::DeallocOp>(loc, allocatedBuffer);
 }
 
+void mlir::linalg::defaultMemCpyFn(OpBuilder &b, Location loc, Value from,
+                                   Value to) {
+  b.create<CopyOp>(loc, from, to);
+}
+
 LogicalResult mlir::linalg::bufferizeOp(
     Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
     AllocationCallbacks allocationFns,
@@ -2258,11 +2263,13 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     // command line option. So this is set up at the start of the pass.
     if (useAlloca) {
       AllocationCallbacks allocaAllocationFns = {
-          allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}};
+          allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {},
+          defaultMemCpyFn};
       allocationFns =
           std::make_unique<AllocationCallbacks>(std::move(allocaAllocationFns));
     } else {
-      allocationFns = std::make_unique<AllocationCallbacks>();
+      allocationFns = std::make_unique<AllocationCallbacks>(
+          defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
     }
   }
   ModuleOp moduleOp = getOperation();
@@ -3222,7 +3229,7 @@ struct ExtractSliceOpInterface
     if (alloc) {
       // Do not copy if the copied data is never read.
       if (isValueRead(extractSliceOp.result()))
-        b.create<CopyOp>(extractSliceOp.getLoc(), subView, alloc);
+        allocationFn.memCpyFn(b, extractSliceOp.getLoc(), subView, alloc);
       subView = alloc;
     }
 
@@ -3344,7 +3351,7 @@ struct InsertSliceOpInterface
           insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
       // Insert new alias.
       aliasInfo.insertNewBufferAlias(subView, dstMemref);
-      b.create<CopyOp>(insertSliceOp.getLoc(), srcMemref, subView);
+      allocationFn.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, subView);
     }
 
     map(bvm, insertSliceOp.result(), dstMemref);


        


More information about the Mlir-commits mailing list