[Mlir-commits] [mlir] [mlir][Linalg] Add transform to convert linalg.copy into memref.copy (PR #132422)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 21 10:00:14 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Pablo Antonio Martinez (pabloantoniom)

<details>
<summary>Changes</summary>

Targeted rewrite of a linalg.copy on memrefs to a memref.copy.

This is useful when bufferizing copies to a linalg.copy, later applying some transformations (for instance, tiling), and then rewriting the copy into a memref.copy. If the source and destination of the input linalg.copy have different element types, the transformation is rejected.

---
Full diff: https://github.com/llvm/llvm-project/pull/132422.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+34) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+54) 
- (added) mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir (+70) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 12080cee85c9d..8406d170d882e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -559,6 +559,40 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+def LinalgCopyToMemrefOp :
+    Op<Transform_Dialect, "structured.linalg_copy_to_memref",
+      [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+       TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
+    This is useful when bufferizing copies to a linalg.copy, later applying some
+    transformations, and then rewriting the copy into a memref.copy.
+    If the input has different element type in the source and destination,
+    the transformation is rejected.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target); // ops to rewrite
+  let results = (outs TransformHandleTypeInterface:$transformed); // new ops
+
+  let assemblyFormat = "$target attr-dict `:` "
+                       "functional-type(operands, results) ";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>, // NOTE(review): definition not shown
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // LowerPackOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d65e7e4666c3..bfebb0fbbf938 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1176,6 +1176,60 @@ LogicalResult transform::InterchangeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+/// Rewrites one buffer-semantics `linalg.copy` into a `memref.copy`; emits a
+/// silenceable error with a note on the target when it cannot be rewritten.
+DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *targetOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Check if the target can be converted.
+  auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
+  if (!copyOp) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "only linalg.copy target ops are supported";
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  if (!copyOp.hasPureBufferSemantics()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "linalg.copy on tensors cannot be transformed into memref.copy";
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  SmallVector<Value> inputs = copyOp.getInputs();
+  SmallVector<Value> outputs = copyOp.getOutputs();
+  assert(inputs.size() == 1 && "expected linalg copy op with one input");
+  assert(outputs.size() == 1 && "expected linalg copy op with one output");
+  Value input = inputs.front();
+  Value output = outputs.front();
+
+  // linalg.copy supports a scalar source and different element types on
+  // source/dest whereas memref.copy does not, so reject the transformation
+  // unless the source is shaped and both element types are the same.
+  auto inputType = dyn_cast<ShapedType>(input.getType());
+  if (!inputType || inputType.getElementType() !=
+                        cast<ShapedType>(output.getType()).getElementType()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "linalg.copy with different source and "
+                                  "destination element types is not supported";
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  // Target can be converted, do it.
+  auto memrefCopyOp =
+      rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
+  results.push_back(memrefCopyOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LowerPackOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
new file mode 100644
index 0000000000000..cd376ef1eb337
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -transform-interpreter %s --split-input-file --allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+
+// CHECK:  func.func @linalg_copy_to_memref_copy(%[[INPUT:.*]]: memref<128x64xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
+// CHECK:    memref.copy %[[INPUT]], %[[OUTPUT]] : memref<128x64xf32> to memref<128x64xf32>
+// CHECK:    return
+// CHECK:  }
+// Same-typed memref operands: the linalg.copy becomes a memref.copy.
+func.func @linalg_copy_to_memref_copy(%input : memref<128x64xf32>, %output : memref<128x64xf32>) {
+  linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_tensors(%input : tensor<128x64xf32>, %output : tensor<128x64xf32>) -> tensor<128x64xf32> {
+  // expected-note @below {{target op}}
+  %0 = linalg.copy ins(%input : tensor<128x64xf32>) outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32>
+  return %0 : tensor<128x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{linalg.copy on tensors cannot be transformed into memref.copy}}
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_different_element(%input : memref<128x64xf32>, %output : memref<128x64xf64>) {
+  // expected-note @below {{target op}}
+  linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf64>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_scalar(%input : f64, %output : memref<128x64xf64>) {
+  // expected-note @below {{target op}}
+  linalg.copy ins(%input : f64) outs(%output : memref<128x64xf64>)
+  return
+}
+// NOTE(review): both element types are f64; rejected because the source is a scalar.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/132422


More information about the Mlir-commits mailing list