[Mlir-commits] [mlir] 88f4292 - [mlir][bufferization] OneShotBufferizeOp: Add options to use linalg.copy

Matthias Springer llvmlistbot at llvm.org
Fri Jul 14 04:38:56 PDT 2023


Author: Matthias Springer
Date: 2023-07-14T13:34:22+02:00
New Revision: 88f4292a165cf0b65aca8632840d73e2a094b05f

URL: https://github.com/llvm/llvm-project/commit/88f4292a165cf0b65aca8632840d73e2a094b05f
DIFF: https://github.com/llvm/llvm-project/commit/88f4292a165cf0b65aca8632840d73e2a094b05f.diff

LOG: [mlir][bufferization] OneShotBufferizeOp: Add options to use linalg.copy

This new `memcpy_op` option allows users to choose the op used for buffer copies: `memref.copy` (the default) or `linalg.copy`.

Differential Revision: https://reviews.llvm.org/D155280

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
    mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
    mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
    mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
index 807a63d36f3936..0a32afd0e19fe9 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -58,10 +58,12 @@ def OneShotBufferizeOp
       DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
       DefaultValuedAttr<BoolAttr, "true">:$create_deallocs,
       DefaultValuedAttr<BoolAttr, "false">:$test_analysis_only,
-      DefaultValuedAttr<BoolAttr, "false">:$print_conflicts);
+      DefaultValuedAttr<BoolAttr, "false">:$print_conflicts,
+      DefaultValuedAttr<StrAttr, "\"memref.copy\"">:$memcpy_op);
 
   let results = (outs TransformHandleTypeInterface:$transformed);
 
+  let hasVerifier = 1;
   let assemblyFormat = [{
     (`layout` `{` $function_boundary_type_conversion^ `}`)?
     $target attr-dict `:` functional-type($target, results)

diff  --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index 9c23ad6bfd9022..f866484f485678 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -25,6 +26,12 @@ using namespace mlir::transform;
 // OneShotBufferizeOp
 //===----------------------------------------------------------------------===//
 
+LogicalResult transform::OneShotBufferizeOp::verify() {
+  if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
+    return emitOpError() << "unsupported memcpy op: " << getMemcpyOp();
+  return success();
+}
+
 DiagnosedSilenceableFailure
 transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
                                      TransformResults &transformResults,
@@ -39,6 +46,19 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
   if (getFunctionBoundaryTypeConversion().has_value())
     options.setFunctionBoundaryTypeConversion(
         *getFunctionBoundaryTypeConversion());
+  if (getMemcpyOp() == "memref.copy") {
+    options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
+      b.create<memref::CopyOp>(loc, from, to);
+      return success();
+    };
+  } else if (getMemcpyOp() == "linalg.copy") {
+    options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
+      b.create<linalg::CopyOp>(loc, from, to);
+      return success();
+    };
+  } else {
+    llvm_unreachable("invalid copy op");
+  }
 
   auto payloadOps = state.getPayloadOps(getTarget());
   for (Operation *target : payloadOps) {

diff  --git a/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
index 51e5b0a099280b..10ddabd7d84015 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRBufferizationTransformOps
   MLIRIR
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
+  MLIRLinalgDialect
   MLIRParser
   MLIRPDLDialect
   MLIRSideEffectInterfaces

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
index c4a40448919437..94550c8d4374a5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -28,6 +28,35 @@ func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf3
 
 // -----
 
+// With memcpy_op = "linalg.copy", buffer copies are emitted as linalg.copy.
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.bufferization.one_shot_bufferize %0 {memcpy_op = "linalg.copy"} : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @test_function(
+//  CHECK-SAME:     %[[A:.*]]: tensor<?xf32>
+//   CHECK-NOT:   memref.copy
+func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+
+  // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
+  // CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
+  // CHECK: linalg.copy ins(%[[A_memref]] : memref<{{.*}}>) outs(%[[alloc]]
+  // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
+  %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+  // CHECK: memref.dealloc %[[alloc]]
+  // CHECK: return %[[res_tensor]]
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
 // Test analysis of One-Shot Bufferize only.
 
 transform.sequence failures(propagate) {

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 2097da8a1e0b18..14027fcb038ce5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11477,6 +11477,7 @@ cc_library(
         ":BufferizationTransformOpsIncGen",
         ":BufferizationTransforms",
         ":IR",
+        ":LinalgDialect",
         ":MemRefDialect",
         ":Parser",
         ":SideEffectInterfaces",


        


More information about the Mlir-commits mailing list