[Mlir-commits] [mlir] a338f80 - [mlir][Linalg] Add transform to convert linalg.copy into memref.copy (#132422)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 1 05:39:36 PDT 2025
Author: Pablo Antonio Martinez
Date: 2025-04-01T13:39:33+01:00
New Revision: a338f80ddcb97edd275c8bf949b1fab0c7d1049e
URL: https://github.com/llvm/llvm-project/commit/a338f80ddcb97edd275c8bf949b1fab0c7d1049e
DIFF: https://github.com/llvm/llvm-project/commit/a338f80ddcb97edd275c8bf949b1fab0c7d1049e.diff
LOG: [mlir][Linalg] Add transform to convert linalg.copy into memref.copy (#132422)
Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
This is useful when bufferizing copies to a linalg.copy, applying some
transformations, and then rewriting the copy into a memref.copy.
If the element types of the source and destination differ, or if the
source is a scalar, the transform produces a silenceable failure.
Added:
mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 12080cee85c9d..15ea5e7bf7159 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -559,6 +559,40 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
}];
}
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+def LinalgCopyToMemrefOp :
+    Op<Transform_Dialect, "structured.linalg_copy_to_memref",
+      [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+       TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
+    This is useful when bufferizing copies to a linalg.copy, later applying some
+    transformations, and then rewriting the copy into a memref.copy.
+    If the element types of the source and destination differ, or if the source
+    is a scalar, the transform produces a silenceable failure.
+
+    #### Return modes
+
+    This operation consumes the `target` handle. It produces a silenceable
+    failure if a target op is not a linalg.copy, if the copy has tensor
+    semantics, if its source is not a shaped type, or if the source and
+    destination element types differ. Otherwise each target is replaced by a
+    memref.copy on the same operands, and the `transformed` handle points to
+    the newly created memref.copy ops.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` "
+                       "functional-type(operands, results) ";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>,
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d65e7e4666c3..c90ebe4487ca4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1176,6 +1176,71 @@ LogicalResult transform::InterchangeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+/// Rewrites a `linalg.copy` with pure buffer semantics into an equivalent
+/// `memref.copy` and returns the new op through `results`.
+///
+/// Produces a silenceable failure (leaving the IR untouched) when:
+///   - the target is not a `linalg.copy`;
+///   - the copy does not have pure buffer semantics (e.g. it is on tensors);
+///   - the input is not a shaped type (e.g. a scalar);
+///   - the source and destination element types differ, since `memref.copy`
+///     cannot perform the element conversion that `linalg.copy` allows.
+DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *targetOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Check if the target can be converted.
+  auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
+  if (!copyOp) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "only linalg.copy target ops are supported";
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  if (!copyOp.hasPureBufferSemantics()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "cannot transform a linalg.copy on tensors into a memref.copy";
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  // linalg.copy has exactly one input and one output; read them directly
+  // rather than copying the operand ranges into temporary vectors.
+  assert(copyOp.getInputs().size() == 1 &&
+         "expected linalg copy op with one input");
+  assert(copyOp.getOutputs().size() == 1 &&
+         "expected linalg copy op with one output");
+  Value input = copyOp.getInputs().front();
+  Value output = copyOp.getOutputs().front();
+
+  // linalg.copy supports different element types on source/dest whereas
+  // memref.copy does not, so we must check that the source and dest types can
+  // be handled by memref.copy and otherwise reject the transformation.
+  auto inputType = dyn_cast<ShapedType>(input.getType());
+  if (!inputType) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "cannot transform a linalg.copy which input has no shape";
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  // linalg.copy destination must be a shaped type; `cast` asserts this.
+  auto outputType = cast<ShapedType>(output.getType());
+
+  if (inputType.getElementType() != outputType.getElementType()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "cannot transform a linalg.copy with different source and "
+           "destination element types";
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  // Target can be converted, do it.
+  auto memrefCopyOp =
+      rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
+
+  results.push_back(memrefCopyOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
new file mode 100644
index 0000000000000..7280ccbea2563
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt -transform-interpreter %s --split-input-file --allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+
+// Positive case: a linalg.copy between two memrefs of identical type is
+// replaced by a memref.copy on the same operands.
+// CHECK: func.func @linalg_copy_to_memref_copy(%[[INPUT:.*]]: memref<128x64xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
+// CHECK: memref.copy %[[INPUT]], %[[OUTPUT]] : memref<128x64xf32> to memref<128x64xf32>
+// CHECK: return
+// CHECK: }
+
+func.func @linalg_copy_to_memref_copy(%input : memref<128x64xf32>, %output : memref<128x64xf32>) {
+ linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// The destination is a strided subview of a fresh allocation. The CHECK lines
+// expect a memref.copy whose source and destination keep their differing
+// layouts. Note that %output is not used by the payload function.
+// CHECK: func.func @linalg_copy_to_memref_copy_strides(%[[INPUT:.*]]: memref<128x32xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf32>
+// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][0, 32] [128, 32] [1, 1] : memref<128x64xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
+// CHECK: memref.copy %[[INPUT]], %[[SUBVIEW]] : memref<128x32xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
+// CHECK: return
+// CHECK: }
+
+func.func @linalg_copy_to_memref_copy_strides(%input : memref<128x32xf32>, %output : memref<128x64xf32>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<128x64xf32>
+ %subview = memref.subview %alloc[0, 32] [128, 32] [1, 1] : memref<128x64xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
+ linalg.copy ins(%input : memref<128x32xf32>) outs(%subview : memref<128x32xf32, strided<[64, 1], offset: 32>>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// A linalg.copy with tensor semantics cannot become a memref.copy. The
+// transform op is expected to report an error, with a note on the payload op.
+func.func @linalg_copy_to_memref_copy_tensors(%input : tensor<128x64xf32>, %output : tensor<128x64xf32>) -> tensor<128x64xf32> {
+ // expected-note @below {{target op}}
+ %0 = linalg.copy ins(%input : tensor<128x64xf32>) outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32>
+ return %0 : tensor<128x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{cannot transform a linalg.copy on tensors into a memref.copy}}
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// Differing element types (f32 -> f64) cannot be expressed by memref.copy, so
+// the transform op is expected to report an error, with a note on the payload op.
+func.func @linalg_copy_to_memref_copy_different_element(%input : memref<128x64xf32>, %output : memref<128x64xf64>) {
+  // expected-note @below {{target op}}
+  linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf64>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{cannot transform a linalg.copy with different source and destination element types}}
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// A scalar (non-shaped) input has no memref.copy equivalent. The transform op
+// is expected to report an error, with a note on the payload op.
+func.func @linalg_copy_to_memref_copy_scalar(%input : f64, %output : memref<128x64xf64>) {
+ // expected-note @below {{target op}}
+ linalg.copy ins(%input : f64) outs(%output : memref<128x64xf64>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{cannot transform a linalg.copy which input has no shape}}
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list