[Mlir-commits] [mlir] 88f4292 - [mlir][bufferization] OneShotBufferizeOp: Add options to use linalg.copy
Matthias Springer
llvmlistbot at llvm.org
Fri Jul 14 04:38:56 PDT 2023
Author: Matthias Springer
Date: 2023-07-14T13:34:22+02:00
New Revision: 88f4292a165cf0b65aca8632840d73e2a094b05f
URL: https://github.com/llvm/llvm-project/commit/88f4292a165cf0b65aca8632840d73e2a094b05f
DIFF: https://github.com/llvm/llvm-project/commit/88f4292a165cf0b65aca8632840d73e2a094b05f.diff
LOG: [mlir][bufferization] OneShotBufferizeOp: Add options to use linalg.copy
This new `memcpy_op` option lets users choose which op is emitted for buffer copies: `memref.copy` (the default) or `linalg.copy`.
Differential Revision: https://reviews.llvm.org/D155280
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
index 807a63d36f3936..0a32afd0e19fe9 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -58,10 +58,12 @@ def OneShotBufferizeOp
DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
DefaultValuedAttr<BoolAttr, "true">:$create_deallocs,
DefaultValuedAttr<BoolAttr, "false">:$test_analysis_only,
- DefaultValuedAttr<BoolAttr, "false">:$print_conflicts);
+ DefaultValuedAttr<BoolAttr, "false">:$print_conflicts,
+ DefaultValuedAttr<StrAttr, "\"memref.copy\"">:$memcpy_op);
let results = (outs TransformHandleTypeInterface:$transformed);
+ let hasVerifier = 1;
let assemblyFormat = [{
(`layout` `{` $function_boundary_type_conversion^ `}`)?
$target attr-dict `:` functional-type($target, results)
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index 9c23ad6bfd9022..f866484f485678 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -25,6 +26,12 @@ using namespace mlir::transform;
// OneShotBufferizeOp
//===----------------------------------------------------------------------===//
+LogicalResult transform::OneShotBufferizeOp::verify() {
+ if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
+ return emitOpError() << "unsupported memcpy op";
+ return success();
+}
+
DiagnosedSilenceableFailure
transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
TransformResults &transformResults,
@@ -39,6 +46,19 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
if (getFunctionBoundaryTypeConversion().has_value())
options.setFunctionBoundaryTypeConversion(
*getFunctionBoundaryTypeConversion());
+ if (getMemcpyOp() == "memref.copy") {
+ options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
+ b.create<memref::CopyOp>(loc, from, to);
+ return success();
+ };
+ } else if (getMemcpyOp() == "linalg.copy") {
+ options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
+ b.create<linalg::CopyOp>(loc, from, to);
+ return success();
+ };
+ } else {
+ llvm_unreachable("invalid copy op");
+ }
auto payloadOps = state.getPayloadOps(getTarget());
for (Operation *target : payloadOps) {
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
index 51e5b0a099280b..10ddabd7d84015 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRBufferizationTransformOps
MLIRIR
MLIRBufferizationDialect
MLIRBufferizationTransforms
+ MLIRLinalgDialect
MLIRParser
MLIRPDLDialect
MLIRSideEffectInterfaces
diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
index c4a40448919437..94550c8d4374a5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -28,6 +28,35 @@ func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf3
// -----
+// Emit linalg.copy instead of memref.copy.
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.bufferization.one_shot_bufferize %0 {memcpy_op = "linalg.copy"} : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @test_function(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
+// CHECK-NOT: memref.copy
+func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0 : index
+
+ // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+ // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+ // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+ // CHECK: linalg.copy ins(%[[A_memref]] : memref<{{.*}}>) outs(%[[alloc]]
+ // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+ // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+ %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+ // CHECK: memref.dealloc %[[alloc]]
+ // CHECK: return %[[res_tensor]]
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
// Test analysis of One-Shot Bufferize only.
transform.sequence failures(propagate) {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2097da8a1e0b18..14027fcb038ce5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11477,6 +11477,7 @@ cc_library(
":BufferizationTransformOpsIncGen",
":BufferizationTransforms",
":IR",
+ ":LinalgDialect",
":MemRefDialect",
":Parser",
":SideEffectInterfaces",
More information about the Mlir-commits mailing list