[Mlir-commits] [mlir] [mlir][linalg] Add missing check for `isaCopyOpInterface` (PR #149313)

Longsheng Mou llvmlistbot at llvm.org
Thu Jul 17 07:33:29 PDT 2025


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/149313

From c2132c4096f92efc1ae2dd85f15af36a7e312472 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 17 Jul 2025 20:26:14 +0800
Subject: [PATCH] [mlir][linalg] Add missing check for `isaCopyOpInterface`

This PR adds a missing check to `isaCopyOpInterface`: the body must consist of a single `linalg.yield` whose only operand is the first block argument, which is what makes the op a direct copy. Previously, any body containing a single operation was accepted.
---
 mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 10 ++++++++--
 .../Linalg/specialize-generic-ops-fail.mlir     | 17 +++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 94f2002fc51fa..085ae4c93b829 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -68,8 +68,14 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
       !mapRange.back().isIdentity()) {
     return false;
   }
-  // Region.
-  return llvm::hasSingleElement(op.getBlock()->getOperations());
+  // Check that the body only yields the first block argument.
+  Block *body = op.getBlock();
+  if (body->getOperations().size() != 1)
+    return false;
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1)
+    return false;
+  return yieldOp->getOperand(0) == body->getArgument(0);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 357f2c11a7936..5d66837fca510 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -29,3 +29,20 @@ func.func @neither_permutation_nor_broadcast(%init : tensor<8xi32>) -> tensor<8x
   } -> tensor<8xi32>
   return %res : tensor<8xi32>
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @not_copy
+//  CHECK-NOT:    linalg.copy
+//      CHECK:    linalg.generic
+func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32> {
+  %c0_i32 = arith.constant 0 : i32
+  %res = linalg.generic {
+    indexing_maps = [#map, #map], iterator_types = ["parallel"]
+  } ins(%input: tensor<8xi32>) outs(%init: tensor<8xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    linalg.yield %c0_i32 : i32
+  } -> tensor<8xi32>
+  return %res : tensor<8xi32>
+}
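
The body check added in this patch can be modeled outside MLIR. The following is only a sketch with hypothetical stand-in types (`ToyOp`, `ToyBlock`, `ValueId`, `makeBody`, and `yieldsFirstArgument` are not MLIR names); it mirrors the three conditions from the patch: exactly one operation in the body, that operation is a yield with one operand, and the operand is the first block argument. The empty-arguments guard is an assumption of the toy model; in the patch the earlier checks in `isaCopyOpInterface` run before the block is inspected.

```cpp
// Toy model of the body check; none of these types are MLIR classes.
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an SSA value: just an integer id.
using ValueId = int;

struct ToyOp {
  std::string name;              // e.g. "linalg.yield"
  std::vector<ValueId> operands; // values consumed by the op
};

struct ToyBlock {
  std::vector<ValueId> arguments; // block arguments, e.g. {%in, %out}
  std::vector<ToyOp> operations;  // ops in order, terminator last
};

// Returns true only if the block is a single yield whose sole operand is
// the first block argument.
bool yieldsFirstArgument(const ToyBlock &body) {
  // Guard for the toy model only (assumption, not part of the patch).
  if (body.arguments.empty())
    return false;
  if (body.operations.size() != 1)
    return false;
  const ToyOp &last = body.operations.back();
  if (last.name != "linalg.yield" || last.operands.size() != 1)
    return false;
  return last.operands[0] == body.arguments[0];
}

// Builds a block with arguments {0, 1} (playing %in and %out) and a single
// yield of the given value id.
ToyBlock makeBody(ValueId yielded) {
  ToyBlock body;
  body.arguments = {0, 1};
  body.operations.push_back(ToyOp{"linalg.yield", {yielded}});
  return body;
}

void runExamples() {
  assert(yieldsFirstArgument(makeBody(0)));  // yields %in: a direct copy
  assert(!yieldsFirstArgument(makeBody(1))); // yields %out: not a copy
  assert(!yieldsFirstArgument(makeBody(2))); // yields an outside value,
                                             // like %c0_i32 in @not_copy
}
```

The last case corresponds to the `@not_copy` test above: the body has a single operation, so the old single-element check accepted it, but the yielded value is a constant defined outside the block rather than `%in`.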


