[Mlir-commits] [mlir] [mlir][Linalg] Add transform to convert linalg.copy into memref.copy (PR #132422)
Pablo Antonio Martinez
llvmlistbot at llvm.org
Fri Mar 21 09:59:41 PDT 2025
https://github.com/pabloantoniom created https://github.com/llvm/llvm-project/pull/132422
Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
This is useful when copies are bufferized into a linalg.copy, some transformations (for instance, tiling) are applied afterwards, and the copy is then rewritten into a memref.copy. If the source and destination of the input linalg.copy have different element types, the transformation is rejected.
>From 7cf5809d26a5e6093a4e5c8e353a37164ea22c55 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Fri, 21 Mar 2025 16:40:18 +0000
Subject: [PATCH] [mlir][Linalg] Add transform to convert linalg.copy into
memref.copy
Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
This is useful when bufferizing copies to a linalg.copy, later applying
some transformations (for instance, tiling), and then rewriting the
copy into a memref.copy. If the source and destination of the input
linalg.copy have different element types, the transformation is rejected.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 34 +++++++++
.../TransformOps/LinalgTransformOps.cpp | 54 ++++++++++++++
.../transform-op-linalg-copy-to-memref.mlir | 70 +++++++++++++++++++
3 files changed, 158 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 12080cee85c9d..8406d170d882e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -559,6 +559,40 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
}];
}
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+def LinalgCopyToMemrefOp :
+    Op<Transform_Dialect, "structured.linalg_copy_to_memref",
+       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+        TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
+    This is useful when bufferizing copies to a linalg.copy, later applying some
+    transformations, and then rewriting the copy into a memref.copy.
+
+    #### Return modes
+
+    This transform produces a silenceable failure if the target is not a
+    linalg.copy, if the copy operates on tensors, if its source is a scalar,
+    or if its source and destination have different element types, since none
+    of these can be expressed as a memref.copy.
+    On success, `transformed` is a handle to the newly created memref.copy.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` "
+                       "functional-type(operands, results) ";
+
+  let builders = [
+    // Builds the op with a generic `!transform.any_op` result handle. The
+    // result points to memref.copy ops, so reusing the (possibly
+    // linalg.copy-specific) type of `target` would be incorrect.
+    OpBuilder<(ins "Value":$target), [{
+      build($_builder, $_state,
+            ::mlir::transform::AnyOpType::get($_builder.getContext()), target);
+    }]>,
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d65e7e4666c3..bfebb0fbbf938 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1176,6 +1176,60 @@ LogicalResult transform::InterchangeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *targetOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Reports a silenceable failure carrying `message`, with a note attached to
+  // the rejected payload op so the user can see which op was skipped.
+  auto rejectTarget = [&](StringRef message) -> DiagnosedSilenceableFailure {
+    DiagnosedSilenceableFailure diag = emitSilenceableError() << message;
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  };
+
+  // Only linalg.copy payload ops can be rewritten.
+  auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
+  if (!copyOp)
+    return rejectTarget("only linalg.copy target ops are supported");
+
+  // memref.copy operates on buffers only; tensor copies are rejected.
+  if (!copyOp.hasPureBufferSemantics())
+    return rejectTarget(
+        "linalg.copy on tensors cannot be transformed into memref.copy");
+
+  // linalg.copy carries exactly one input and one output operand.
+  assert(copyOp.getInputs().size() == 1 &&
+         "expected linalg copy op with one input");
+  assert(copyOp.getOutputs().size() == 1 &&
+         "expected linalg copy op with one output");
+  Value input = copyOp.getInputs().front();
+  Value output = copyOp.getOutputs().front();
+
+  // linalg.copy accepts non-memref sources (e.g. scalars) and differing
+  // source/destination element types, whereas memref.copy requires two memrefs
+  // with the same element type. Check for memrefs explicitly: a plain
+  // ShapedType check would also admit non-memref shaped values and produce an
+  // invalid memref.copy.
+  // TODO: a non-memref source currently shares the element-type diagnostic;
+  // consider giving it a dedicated message.
+  auto inputType = dyn_cast<BaseMemRefType>(input.getType());
+  auto outputType = dyn_cast<BaseMemRefType>(output.getType());
+  if (!inputType || !outputType ||
+      inputType.getElementType() != outputType.getElementType())
+    return rejectTarget("linalg.copy with different source and destination "
+                        "element types is not supported");
+
+  // Target can be converted, do it.
+  auto memrefCopyOp =
+      rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
+  results.push_back(memrefCopyOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
new file mode 100644
index 0000000000000..cd376ef1eb337
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -transform-interpreter %s --split-input-file --allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+
+// Basic case: a linalg.copy between two memrefs of identical type is replaced
+// by a memref.copy on the same operands, and no linalg.copy remains in the
+// function body.
+// CHECK-LABEL: func.func @linalg_copy_to_memref_copy(
+// CHECK-SAME:      %[[INPUT:.*]]: memref<128x64xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
+// CHECK-NOT:     linalg.copy
+// CHECK:         memref.copy %[[INPUT]], %[[OUTPUT]] : memref<128x64xf32> to memref<128x64xf32>
+// CHECK-NOT:     linalg.copy
+// CHECK:         return
+// CHECK:       }
+
+func.func @linalg_copy_to_memref_copy(%input : memref<128x64xf32>, %output : memref<128x64xf32>) {
+  linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Negative case: a linalg.copy on tensors is rejected. The error is reported on
+// the transform op and a "target op" note is attached to the rejected copy.
+func.func @linalg_copy_to_memref_copy_tensors(%input : tensor<128x64xf32>, %output : tensor<128x64xf32>) -> tensor<128x64xf32> {
+ // expected-note @below {{target op}}
+ %0 = linalg.copy ins(%input : tensor<128x64xf32>) outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32>
+ return %0 : tensor<128x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{linalg.copy on tensors cannot be transformed into memref.copy}}
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Negative case: linalg.copy allows f32 -> f64 element conversion but
+// memref.copy does not, so this copy is rejected with a note on the copy op.
+func.func @linalg_copy_to_memref_copy_different_element(%input : memref<128x64xf32>, %output : memref<128x64xf64>) {
+ // expected-note @below {{target op}}
+ linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf64>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Negative case: a scalar source cannot be an operand of memref.copy, so the
+// copy is rejected. It is currently reported under the element-type diagnostic.
+func.func @linalg_copy_to_memref_copy_scalar(%input : f64, %output : memref<128x64xf64>) {
+  // expected-note @below {{target op}}
+  linalg.copy ins(%input : f64) outs(%output : memref<128x64xf64>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Negative case: any payload op other than linalg.copy is rejected.
+func.func @linalg_copy_to_memref_copy_not_a_copy(%output : memref<128x64xf32>) {
+  // expected-note @below {{target op}}
+  "test.not_a_copy"(%output) : (memref<128x64xf32>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["test.not_a_copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{only linalg.copy target ops are supported}}
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
More information about the Mlir-commits
mailing list