[Mlir-commits] [mlir] [mlir][linalg] Add missing check for `isaCopyOpInterface` (PR #149313)

Longsheng Mou llvmlistbot at llvm.org
Thu Jul 17 07:00:22 PDT 2025


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/149313

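This PR adds a missing check to `isaCopyOpInterface`: the `linalg.yield` operand must be the first block argument (the input element), which is what makes the body a direct copy. Previously any body consisting of a single `linalg.yield` was accepted, so a generic op that yields, e.g., a constant was wrongly treated as a copy. Fixes #130002.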

>From b20984344d6dab4e1e59e1346d53099a9c7ddff0 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 17 Jul 2025 20:26:14 +0800
Subject: [PATCH] [mlir][linalg] Add missing check for `isaCopyOpInterface`

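This PR adds a missing check to `isaCopyOpInterface`: the `linalg.yield` operand must be the first block argument (the input element), which is what makes the body a direct copy. Previously any body consisting of a single `linalg.yield` was accepted, so a generic op that yields, e.g., a constant was wrongly treated as a copy.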
---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h   |  4 ++--
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 14 ++++++++++----
 .../Linalg/specialize-generic-ops-fail.mlir     | 17 +++++++++++++++++
 3 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 0ebbeea937554..d50f4f5ca0726 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -118,8 +118,8 @@ FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
 bool isaConvolutionOpInterface(LinalgOp linalgOp,
                                bool allowEmptyConvolvedDims = false);
 
-/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
-bool isaCopyOpInterface(LinalgOp linalgOp);
+/// Checks whether `genericOp` is semantically equivalent to a `linalg.copyOp`.
+bool isaCopyOpInterface(GenericOp genericOp);
 
 /// Checks whether `genericOp` is semantically equivalent to a
 ///  `linalg.broadcast`. Returns broadcast dimensions if true.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 94f2002fc51fa..38c4bc5295eae 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -58,8 +58,8 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
 // CopyOpInterface implementation
 //===----------------------------------------------------------------------===//
 
-bool linalg::isaCopyOpInterface(LinalgOp op) {
-  // Check all loops are parallel and linalgOp is single input and output.
+bool linalg::isaCopyOpInterface(GenericOp op) {
+  // Check all loops are parallel and genericOp is single input and output.
   if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
     return false;
 
@@ -68,8 +68,14 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
       !mapRange.back().isIdentity()) {
     return false;
   }
-  // Region.
-  return llvm::hasSingleElement(op.getBlock()->getOperations());
+  // Check yield first block argument.
+  Block *body = op.getBody();
+  if (body->getOperations().size() != 1)
+    return false;
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1)
+    return false;
+  return yieldOp->getOperand(0) == body->getArgument(0);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 357f2c11a7936..5d66837fca510 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -29,3 +29,20 @@ func.func @neither_permutation_nor_broadcast(%init : tensor<8xi32>) -> tensor<8x
   } -> tensor<8xi32>
   return %res : tensor<8xi32>
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @not_copy
+//  CHECK-NOT:    linalg.copy
+//      CHECK:    linalg.generic
+func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32> {
+  %c0_i32 = arith.constant 0 : i32
+  %res = linalg.generic {  // Identity maps, one parallel loop, one input and one output.
+    indexing_maps = [#map, #map], iterator_types = ["parallel"]
+  } ins(%input: tensor<8xi32>) outs(%init: tensor<8xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    linalg.yield %c0_i32 : i32  // Yields a constant instead of %in, so this is not a copy.
+  } -> tensor<8xi32>
+  return %res : tensor<8xi32>
+}



More information about the Mlir-commits mailing list