[Mlir-commits] [mlir] a338f80 - [mlir][Linalg] Add transform to convert linalg.copy into memref.copy (#132422)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 1 05:39:36 PDT 2025


Author: Pablo Antonio Martinez
Date: 2025-04-01T13:39:33+01:00
New Revision: a338f80ddcb97edd275c8bf949b1fab0c7d1049e

URL: https://github.com/llvm/llvm-project/commit/a338f80ddcb97edd275c8bf949b1fab0c7d1049e
DIFF: https://github.com/llvm/llvm-project/commit/a338f80ddcb97edd275c8bf949b1fab0c7d1049e.diff

LOG: [mlir][Linalg] Add transform to convert linalg.copy into memref.copy (#132422)

Targeted rewrite of a linalg.copy on memrefs to a memref.copy.

This is useful when bufferizing copies to a linalg.copy, applying some
transformations, and then rewriting the copy into a memref.copy.
If the element types of the source and destination differ, or if the
source is a scalar, the transform produces a silenceable failure.

Added: 
    mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 12080cee85c9d..15ea5e7bf7159 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -559,6 +559,40 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+def LinalgCopyToMemrefOp :
+    Op<Transform_Dialect, "structured.linalg_copy_to_memref",
+      [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+       TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Targeted rewrite of a linalg.copy on memrefs to a memref.copy. This is
+    useful when bufferizing copies to a linalg.copy, later applying some
+    transformations, and then rewriting the copy into a memref.copy. If the copy
+    is on tensors, if the element types of the source and destination differ, or
+    if the source is a scalar, the transform produces a silenceable failure.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+  let assemblyFormat = "$target attr-dict `:` "
+                       "functional-type(operands, results)";
+
+  let builders = [OpBuilder<(ins "Value":$target), [{
+    build($_builder, $_state,
+          $_builder.getType<::mlir::transform::AnyOpType>(), target);
+  }]>];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // LowerPackOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d65e7e4666c3..c90ebe4487ca4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1176,6 +1176,71 @@ LogicalResult transform::InterchangeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+/// Rewrites a linalg.copy with pure buffer semantics into a memref.copy.
+/// Produces a silenceable failure, leaving the IR untouched, when the target
+/// is not a linalg.copy, operates on tensors, has a non-memref operand (e.g. a
+/// scalar source), or converts between element types, as memref.copy can
+/// express none of these. On success, `results` holds the new memref.copy.
+DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *targetOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Every rejection points back at the offending op with a "target op" note.
+  auto reject = [&](StringRef reason) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError() << reason;
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  };
+
+  // Check if the target can be converted.
+  auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
+  if (!copyOp)
+    return reject("only linalg.copy target ops are supported");
+
+  if (!copyOp.hasPureBufferSemantics())
+    return reject(
+        "cannot transform a linalg.copy on tensors into a memref.copy");
+
+  // View the single input/output without copying the operand ranges.
+  ValueRange inputs = copyOp.getInputs();
+  ValueRange outputs = copyOp.getOutputs();
+  assert(inputs.size() == 1 && "expected linalg copy op with one input");
+  assert(outputs.size() == 1 && "expected linalg copy op with one output");
+  Value input = inputs.front();
+  Value output = outputs.front();
+
+  // linalg.copy accepts a scalar (non-shaped) source, which memref.copy
+  // cannot express.
+  if (!isa<ShapedType>(input.getType()))
+    return reject("cannot transform a linalg.copy which input has no shape");
+
+  // memref.copy requires memref operands on both sides; do not rely on
+  // hasPureBufferSemantics() alone to exclude other shaped types (e.g.
+  // vectors), which would otherwise produce an invalid memref.copy.
+  auto inputType = dyn_cast<BaseMemRefType>(input.getType());
+  auto outputType = dyn_cast<BaseMemRefType>(output.getType());
+  if (!inputType || !outputType)
+    return reject("cannot transform a linalg.copy with non-memref operands "
+                  "into a memref.copy");
+
+  // linalg.copy supports different element types on source/dest whereas
+  // memref.copy does not, so reject the transformation in that case.
+  if (inputType.getElementType() != outputType.getElementType())
+    return reject("cannot transform a linalg.copy with different source and "
+                  "destination element types");
+
+  // Target can be converted, do it.
+  auto memrefCopyOp =
+      rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
+
+  results.push_back(memrefCopyOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LowerPackOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
new file mode 100644
index 0000000000000..7280ccbea2563
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt -transform-interpreter %s --split-input-file --allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+
+// CHECK:  func.func @linalg_copy_to_memref_copy(%[[INPUT:.*]]: memref<128x64xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
+// CHECK-NEXT:    memref.copy %[[INPUT]], %[[OUTPUT]] : memref<128x64xf32> to memref<128x64xf32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// A same-typed memref-to-memref linalg.copy is rewritten 1:1 into memref.copy.
+func.func @linalg_copy_to_memref_copy(%input : memref<128x64xf32>, %output : memref<128x64xf32>) {
+  linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK:  func.func @linalg_copy_to_memref_copy_strides(%[[INPUT:.*]]: memref<128x32xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
+// CHECK-NEXT:    %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf32>
+// CHECK-NEXT:    %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][0, 32] [128, 32] [1, 1] : memref<128x64xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
+// CHECK-NEXT:    memref.copy %[[INPUT]], %[[SUBVIEW]] : memref<128x32xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// A destination with a strided (non-identity) layout is still rewritten.
+func.func @linalg_copy_to_memref_copy_strides(%input : memref<128x32xf32>, %output : memref<128x64xf32>) {
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<128x64xf32>
+  %subview = memref.subview %alloc[0, 32] [128, 32] [1, 1] : memref<128x64xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
+  linalg.copy ins(%input : memref<128x32xf32>) outs(%subview : memref<128x32xf32, strided<[64, 1], offset: 32>>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_tensors(%src : tensor<128x64xf32>, %dst : tensor<128x64xf32>) -> tensor<128x64xf32> {
+  // expected-note @below {{target op}}
+  %copied = linalg.copy ins(%src : tensor<128x64xf32>) outs(%dst : tensor<128x64xf32>) -> tensor<128x64xf32>
+  return %copied : tensor<128x64xf32>
+}
+// A copy on tensors has no memref.copy counterpart, so the rewrite is refused.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %copies = transform.structured.match ops{["linalg.copy"]} in %root : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{cannot transform a linalg.copy on tensors into a memref.copy}}
+    %converted = transform.structured.linalg_copy_to_memref %copies : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_different_element(%input : memref<128x64xf32>, %output : memref<128x64xf64>) {
+  // expected-note @below {{target op}}
+  linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf64>)
+  return
+}
+// f32 -> f64 needs an element conversion, which memref.copy cannot perform.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{cannot transform a linalg.copy with different source and destination element types}}
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_scalar(%src : f64, %dst : memref<128x64xf64>) {
+  // expected-note @below {{target op}}
+  linalg.copy ins(%src : f64) outs(%dst : memref<128x64xf64>)
+  return
+}
+// A scalar source has no memref.copy form, so the rewrite is refused.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %copies = transform.structured.match ops{["linalg.copy"]} in %root : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{cannot transform a linalg.copy which input has no shape}}
+    %converted = transform.structured.linalg_copy_to_memref %copies : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}


        


More information about the Mlir-commits mailing list