[Mlir-commits] [mlir] [mlir][Linalg] Add transform to convert linalg.copy into memref.copy (PR #132422)
Pablo Antonio Martinez
llvmlistbot at llvm.org
Fri Mar 28 10:38:58 PDT 2025
https://github.com/pabloantoniom updated https://github.com/llvm/llvm-project/pull/132422
>From 7cf5809d26a5e6093a4e5c8e353a37164ea22c55 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Fri, 21 Mar 2025 16:40:18 +0000
Subject: [PATCH 1/2] [mlir][Linalg] Add transform to convert linalg.copy into
memref.copy
Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
This is useful when bufferizing copies to a linalg.copy, later applying
some transformations (for instance, tiling), and then rewriting the
copy into a memref.copy. If the source and destination of the input
linalg.copy have different element types, the transformation is rejected.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 34 +++++++++
.../TransformOps/LinalgTransformOps.cpp | 54 ++++++++++++++
.../transform-op-linalg-copy-to-memref.mlir | 70 +++++++++++++++++++
3 files changed, 158 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 12080cee85c9d..8406d170d882e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -559,6 +559,40 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
}];
}
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+def LinalgCopyToMemrefOp :
+ Op<Transform_Dialect, "structured.linalg_copy_to_memref",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let description = [{
+ Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
+ This is useful when bufferizing copies to a linalg.copy, later applying some
+ transformations, and then rewriting the copy into a memref.copy.
+ If the input has different element type in the source and destination,
+ the transformation is rejected.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat = "$target attr-dict `:` "
+ "functional-type(operands, results) ";
+
+ // Convenience builder taking only the target handle. The body is given
+ // inline so the builder is actually defined; the result handle gets the
+ // generic `!transform.any_op` type (AnyOpType).
+ let builders = [
+ OpBuilder<(ins "Value":$target), [{
+ build($_builder, $_state,
+ ::mlir::transform::AnyOpType::get($_builder.getContext()), target);
+ }]>,
+ ];
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d65e7e4666c3..bfebb0fbbf938 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1176,6 +1176,60 @@ LogicalResult transform::InterchangeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+/// Rewrites `targetOp`, which must be a linalg.copy with pure buffer
+/// semantics, into an equivalent memref.copy and appends the new op to
+/// `results`. Emits a silenceable failure (returning before any rewrite, so
+/// the payload is left unchanged) when the target is not a linalg.copy, has
+/// non-buffer (tensor) semantics, has a non-shaped input, or copies between
+/// different element types.
+DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
+ transform::TransformRewriter &rewriter, Operation *targetOp,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+
+ // Check if the target can be converted
+ if (!isa<linalg::CopyOp>(targetOp)) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "only linalg.copy target ops are supported";
+ diag.attachNote(targetOp->getLoc()) << "target op";
+ return diag;
+ }
+
+ // The isa<> check above guarantees this cast succeeds.
+ auto copyOp = cast<linalg::CopyOp>(targetOp);
+ if (!copyOp.hasPureBufferSemantics()) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "linalg.copy on tensors cannot be transformed into memref.copy";
+ diag.attachNote(targetOp->getLoc()) << "target op";
+ return diag;
+ }
+
+ // linalg.copy is expected to have exactly one input and one output operand.
+ // Views over the operand lists are enough; no need to copy them.
+ ValueRange inputs = copyOp.getInputs();
+ ValueRange outputs = copyOp.getOutputs();
+ assert(inputs.size() == 1 && "expected linalg copy op with one input");
+ assert(outputs.size() == 1 && "expected linalg copy op with one output");
+ Value input = inputs.front();
+ Value output = outputs.front();
+
+ // linalg.copy supports different element types on source/dest whereas
+ // memref.copy does not, so we must check here that the types are the same,
+ // otherwise reject the transformation.
+ if (!dyn_cast<ShapedType>(input.getType()) ||
+ cast<ShapedType>(input.getType()).getElementType() !=
+ cast<ShapedType>(output.getType()).getElementType()) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "linalg.copy with different source and "
+ "destination element types is not supported";
+ diag.attachNote(targetOp->getLoc()) << "target op";
+ return diag;
+ }
+
+ // Target can be converted, do it. Create the memref.copy exactly where the
+ // linalg.copy was rather than relying on the caller's insertion point.
+ rewriter.setInsertionPoint(copyOp);
+ auto memrefCopyOp =
+ rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
+
+ results.push_back(memrefCopyOp);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
new file mode 100644
index 0000000000000..cd376ef1eb337
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -transform-interpreter %s --split-input-file --allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+
+// A linalg.copy between memrefs of identical type is replaced by memref.copy.
+// CHECK-LABEL: func.func @linalg_copy_to_memref_copy(
+// CHECK-SAME: %[[INPUT:.*]]: memref<128x64xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
+// CHECK-NEXT: memref.copy %[[INPUT]], %[[OUTPUT]] : memref<128x64xf32> to memref<128x64xf32>
+// CHECK-NEXT: return
+func.func @linalg_copy_to_memref_copy(%input : memref<128x64xf32>, %output : memref<128x64xf32>) {
+ linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_tensors(%input : tensor<128x64xf32>, %output : tensor<128x64xf32>) -> tensor<128x64xf32> {
+ // expected-note @below {{target op}}
+ %0 = linalg.copy ins(%input : tensor<128x64xf32>) outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32>
+ return %0 : tensor<128x64xf32>
+}
+
+// Negative case: the linalg.copy above works on tensors, so the transform must
+// reject it and report the failure on the transform op below.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{linalg.copy on tensors cannot be transformed into memref.copy}}
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_different_element(%input : memref<128x64xf32>, %output : memref<128x64xf64>) {
+ // expected-note @below {{target op}}
+ linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf64>)
+ return
+}
+
+// Negative case: source (f32) and destination (f64) element types differ,
+// which memref.copy cannot express, so the transform must fail.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_scalar(%input : f64, %output : memref<128x64xf64>) {
+ // expected-note @below {{target op}}
+ linalg.copy ins(%input : f64) outs(%output : memref<128x64xf64>)
+ return
+}
+
+// Negative case: the source is a scalar (non-shaped) value, which the
+// transform rejects because it cannot become a memref.copy operand.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
>From 428715b1add8ddb5624d7d42b7ae4c96317cb8d7 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Fri, 28 Mar 2025 17:36:03 +0000
Subject: [PATCH 2/2] Fix comments and add new test with strides
---
.../Linalg/TransformOps/LinalgTransformOps.td | 4 +--
.../TransformOps/LinalgTransformOps.cpp | 29 ++++++++++++------
.../transform-op-linalg-copy-to-memref.mlir | 30 +++++++++++++++++--
3 files changed, 49 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8406d170d882e..15ea5e7bf7159 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -571,8 +571,8 @@ def LinalgCopyToMemrefOp :
Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
This is useful when bufferizing copies to a linalg.copy, later applying some
transformations, and then rewriting the copy into a memref.copy.
- If the input has different element type in the source and destination,
- the transformation is rejected.
+ If the element types of the source and destination differ, or if the source
+ is a scalar, the transform produces a silenceable failure.
}];
let arguments = (ins TransformHandleTypeInterface:$target);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bfebb0fbbf938..c90ebe4487ca4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1185,7 +1185,7 @@ DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
- // Check if the target can be converted
+ // Check if the target can be converted.
if (!isa<linalg::CopyOp>(targetOp)) {
DiagnosedSilenceableFailure diag =
emitSilenceableError() << "only linalg.copy target ops are supported";
@@ -1197,7 +1197,7 @@ DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
if (!copyOp.hasPureBufferSemantics()) {
DiagnosedSilenceableFailure diag =
emitSilenceableError()
- << "linalg.copy on tensors cannot be transformed into memref.copy";
+ << "cannot transform a linalg.copy on tensors into a memref.copy";
diag.attachNote(targetOp->getLoc()) << "target op";
return diag;
}
@@ -1210,14 +1210,25 @@ DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
Value output = outputs.front();
// linalg.copy supports different element types on source/dest whereas
- // memref.copy does not, so we must check here that the types are the same,
- // otherwise reject the transformation.
- if (!dyn_cast<ShapedType>(input.getType()) ||
- cast<ShapedType>(input.getType()).getElementType() !=
- cast<ShapedType>(output.getType()).getElementType()) {
+ // memref.copy does not, so we must check that the source and dest types can
+ // be handled by memref.copy and otherwise reject the transformation.
+ if (!isa<ShapedType>(input.getType())) {
DiagnosedSilenceableFailure diag =
- emitSilenceableError() << "linalg.copy with different source and "
- "destination element types is not supported";
+ emitSilenceableError()
+ << "cannot transform a linalg.copy which input has no shape";
+ diag.attachNote(targetOp->getLoc()) << "target op";
+ return diag;
+ }
+
+ // linalg.copy destination must be a shaped type.
+ assert(isa<ShapedType>(output.getType()) && "expected linalg.copy destination to be a shaped type");
+
+ if (cast<ShapedType>(input.getType()).getElementType() !=
+ cast<ShapedType>(output.getType()).getElementType()) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "cannot transform a linalg.copy with different source and "
+ "destination element types";
diag.attachNote(targetOp->getLoc()) << "target op";
return diag;
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
index cd376ef1eb337..7280ccbea2563 100644
--- a/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
@@ -20,6 +20,30 @@ module attributes {transform.with_named_sequence} {
// -----
+// Strided (non-identity layout) destination: still rewritten into memref.copy.
+// CHECK-LABEL: func.func @linalg_copy_to_memref_copy_strides(
+// CHECK-SAME: %[[INPUT:.*]]: memref<128x32xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf32>
+// CHECK-NEXT: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][0, 32] [128, 32] [1, 1] : memref<128x64xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
+// CHECK-NEXT: memref.copy %[[INPUT]], %[[SUBVIEW]] : memref<128x32xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
+// CHECK-NEXT: return
+func.func @linalg_copy_to_memref_copy_strides(%input : memref<128x32xf32>, %output : memref<128x64xf32>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<128x64xf32>
+ %subview = memref.subview %alloc[0, 32] [128, 32] [1, 1] : memref<128x64xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
+ linalg.copy ins(%input : memref<128x32xf32>) outs(%subview : memref<128x32xf32, strided<[64, 1], offset: 32>>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @linalg_copy_to_memref_copy_tensors(%input : tensor<128x64xf32>, %output : tensor<128x64xf32>) -> tensor<128x64xf32> {
// expected-note @below {{target op}}
%0 = linalg.copy ins(%input : tensor<128x64xf32>) outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32>
@@ -29,7 +53,7 @@ func.func @linalg_copy_to_memref_copy_tensors(%input : tensor<128x64xf32>, %outp
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{linalg.copy on tensors cannot be transformed into memref.copy}}
+ // expected-error @below {{cannot transform a linalg.copy on tensors into a memref.copy}}
%1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
@@ -46,7 +70,7 @@ func.func @linalg_copy_to_memref_copy_different_element(%input : memref<128x64xf
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+ // expected-error @below {{cannot transform a linalg.copy with different source and destination element types}}
%1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
@@ -63,7 +87,7 @@ func.func @linalg_copy_to_memref_copy_scalar(%input : f64, %output : memref<128x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+ // expected-error @below {{cannot transform a linalg.copy which input has no shape}}
%1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
More information about the Mlir-commits
mailing list