[Mlir-commits] [mlir] [mlir][linalg] Add missing check for `isaCopyOpInterface` (PR #149313)
Longsheng Mou
llvmlistbot at llvm.org
Thu Jul 17 07:00:22 PDT 2025
https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/149313
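This PR adds a missing check to `isaCopyOpInterface`: the body's `linalg.yield` operand must be the first block argument (the input element), which is what makes the op a direct copy. Previously, any body containing a single operation was accepted, so a `linalg.generic` that yields a constant was misidentified as a copy. The patch also changes the function's parameter type from `LinalgOp` to `GenericOp`. Fixes #130002.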
>From b20984344d6dab4e1e59e1346d53099a9c7ddff0 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 17 Jul 2025 20:26:14 +0800
Subject: [PATCH] [mlir][linalg] Add missing check for `isaCopyOpInterface`
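This PR adds a missing check to `isaCopyOpInterface`: the body's `linalg.yield` operand must be the first block argument (the input element), which is what makes the op a direct copy. Previously, any body containing a single operation was accepted, so a `linalg.generic` that yields a constant was misidentified as a copy.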
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 4 ++--
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 14 ++++++++++----
.../Linalg/specialize-generic-ops-fail.mlir | 17 +++++++++++++++++
3 files changed, 29 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 0ebbeea937554..d50f4f5ca0726 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -118,8 +118,8 @@ FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
bool isaConvolutionOpInterface(LinalgOp linalgOp,
bool allowEmptyConvolvedDims = false);
-/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
-bool isaCopyOpInterface(LinalgOp linalgOp);
+/// Checks whether `genericOp` is semantically equivalent to a `linalg.copyOp`.
+bool isaCopyOpInterface(GenericOp genericOp);
/// Checks whether `genericOp` is semantically equivalent to a
/// `linalg.broadcast`. Returns broadcast dimensions if true.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 94f2002fc51fa..38c4bc5295eae 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -58,8 +58,8 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//
-bool linalg::isaCopyOpInterface(LinalgOp op) {
- // Check all loops are parallel and linalgOp is single input and output.
+bool linalg::isaCopyOpInterface(GenericOp op) {
+ // Check all loops are parallel and genericOp is single input and output.
if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
return false;
@@ -68,8 +68,14 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
!mapRange.back().isIdentity()) {
return false;
}
- // Region.
- return llvm::hasSingleElement(op.getBlock()->getOperations());
+ // Check yield first block argument.
+ Block *body = op.getBody();
+ if (body->getOperations().size() != 1)
+ return false;
+ auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+ if (!yieldOp || yieldOp.getNumOperands() != 1)
+ return false;
+ return yieldOp->getOperand(0) == body->getArgument(0);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 357f2c11a7936..5d66837fca510 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -29,3 +29,20 @@ func.func @neither_permutation_nor_broadcast(%init : tensor<8xi32>) -> tensor<8x
} -> tensor<8xi32>
return %res : tensor<8xi32>
}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @not_copy
+// CHECK-NOT: linalg.copy
+// CHECK: linalg.generic
+func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32> {
+  %c0_i32 = arith.constant 0 : i32  // Defined outside the region, so the body is a lone linalg.yield.
+  %res = linalg.generic {
+    indexing_maps = [#map, #map], iterator_types = ["parallel"]
+  } ins(%input: tensor<8xi32>) outs(%init: tensor<8xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    linalg.yield %c0_i32 : i32  // Yields a constant rather than %in, so this must not become linalg.copy.
+  } -> tensor<8xi32>
+  return %res : tensor<8xi32>
+}
More information about the Mlir-commits mailing list