[Mlir-commits] [mlir] [mlir][linalg] Fix getSourceSkipUnary to only skip cast-like ops (PR #198725)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 27 20:17:38 PDT 2026
https://github.com/Chennesxu updated https://github.com/llvm/llvm-project/pull/198725
From b8a353050ee9945dd99b20fca582ac7c1c1f0127 Mon Sep 17 00:00:00 2001
From: Chennes Xu <xuchen359 at gmail.com>
Date: Wed, 20 May 2026 16:01:26 +0800
Subject: [PATCH] [mlir][linalg] Only skip supported casts in contraction matching
getSourceSkipUnary skipped arbitrary unary side-effect-free ops when matching contraction bodies. This allowed semantics-changing ops such as arith.negf to be ignored, causing linalg.generic ops to be incorrectly specialized to named/category contraction ops.
Restrict the matcher to a conservative allowlist of scalar arith casts modeled by current linalg contraction cast semantics, and rename the helper to getSourceSkipCast. Only skip one cast layer so cast chains such as fptosi -> sitofp are not treated as transparent.
Update the contraction-body comments to make clear that this is structural matching only; callers remain responsible for validating cast placement for named-op round-trip semantics.
Add negative coverage for arith.negf, bitcast, and cast chains, and positive coverage for fptosi/fptoui casts.
Fixes #197178.
---
.../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 18 +++--
.../Dialect/Linalg/IR/LinalgInterfaces.cpp | 43 +++++------
.../Linalg/specialize-generic-ops-fail.mlir | 45 +++++++++++
.../Linalg/specialize-generic-ops.mlir | 75 +++++++++++++++----
4 files changed, 140 insertions(+), 41 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 3c7ebd8277dbd..6c7e26762f993 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -165,17 +165,21 @@ namespace detail {
/// Returns true if the block contains a contraction of the following form:
///
-/// %0 = <elemwise>(permutation-of(cu(block-argument-0),
-/// cu(block-argument-1)))
-/// %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
-/// return-like cu(%1)
+/// %0 = <elemwise>(permutation-of(c(block-argument-0),
+/// c(block-argument-1)))
+/// %1 = <reduce>(permutation-of(c(%0), c(block-argument-2)))
+/// return-like c(%1)
///
/// where <elemwise> and <reduce> are binary operations constituting a
/// contraction (in the canonical case, <elemwise> is a multiplication and
/// <reduce> is an addition). The name and other properties of these operations
-/// are checked by `isaPair`. All operands of all operations may be supplied
-/// through a chain of side effect-free unary operations, such as casts, which
-/// is denoted as `cu` above.
+/// are checked by `isaPair`. The notation `c(...)` denotes an optional
+/// supported scalar arith cast.
+///
+/// Note: This is structural matching only. Callers must separately validate
+/// that cast placement matches forms produced by linalg.generalize for named-op
+/// round-trip semantics (e.g., linalg.generalize does not produce output-side
+/// casts, so bodies with output-side casts cannot round-trip correctly).
///
/// When the body does not contain a contraction, a more precise description of
/// the failed precondition is send to the `errs` stream, if provided.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 238bddcb3b2bd..65397d65c408c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -296,19 +296,20 @@ bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op,
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//
-/// If the value is defined by a chain of unary side effect-free, go up the
-/// use-def chain until the first value that isn't defined by such an op.
-// TODO: relax to multi-operands with constants, which are technically unary ops
-// as needed (e.g. add5).
-static Value getSourceSkipUnary(Value value) {
+/// Return true for scalar arith cast ops modeled by current linalg contraction
+/// cast semantics.
+static bool isSupportedContractionCast(Operation *op) {
+ return isa<arith::ExtFOp, arith::TruncFOp, arith::ExtSIOp, arith::ExtUIOp,
+ arith::TruncIOp, arith::SIToFPOp, arith::UIToFPOp, arith::FPToSIOp,
+ arith::FPToUIOp>(op);
+}
+
+/// If the value is defined by a supported contraction cast op, return its
+/// source. Otherwise, return the value unchanged.
+static Value getSourceSkipCast(Value value) {
Operation *op = value.getDefiningOp();
- while (op && op->getNumOperands() == 1) {
- auto iface = dyn_cast<MemoryEffectOpInterface>(op);
- if (!iface || !iface.hasNoEffect())
- break;
- value = op->getOperand(0);
- op = value.getDefiningOp();
- }
+ if (op && op->getNumOperands() == 1 && isSupportedContractionCast(op))
+ return op->getOperand(0);
return value;
}
@@ -331,7 +332,7 @@ bool mlir::linalg::detail::isContractionBody(
return false;
}
- Value yielded = getSourceSkipUnary(terminator->getOperand(0));
+ Value yielded = getSourceSkipCast(terminator->getOperand(0));
Operation *reductionOp = yielded.getDefiningOp();
if (!reductionOp || reductionOp->getNumResults() != 1 ||
reductionOp->getNumOperands() != 2) {
@@ -339,17 +340,17 @@ bool mlir::linalg::detail::isContractionBody(
return false;
}
- Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
- Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
+ Value reductionLHS = getSourceSkipCast(reductionOp->getOperand(0));
+ Value reductionRHS = getSourceSkipCast(reductionOp->getOperand(1));
if (reductionLHS != block.getArgument(2) &&
reductionRHS != block.getArgument(2)) {
errs << "expected reduction to take block argument #2 as one of the "
- "operands (modulo unary casts)";
+ "operands (modulo supported contraction casts)";
return false;
}
- Value contributed = getSourceSkipUnary(
+ Value contributed = getSourceSkipCast(
isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
Operation *elementwiseOp = contributed.getDefiningOp();
if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
@@ -363,8 +364,8 @@ bool mlir::linalg::detail::isContractionBody(
return false;
}
- Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
- Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
+ Value elementwiseLHS = getSourceSkipCast(elementwiseOp->getOperand(0));
+ Value elementwiseRHS = getSourceSkipCast(elementwiseOp->getOperand(1));
if ((elementwiseLHS == block.getArgument(0) &&
elementwiseRHS == block.getArgument(1)) ||
(elementwiseLHS == block.getArgument(1) &&
@@ -372,8 +373,8 @@ bool mlir::linalg::detail::isContractionBody(
return true;
}
- errs << "expected elementwise op to apply to block arguments (modulo unary "
- "casts)";
+ errs << "expected elementwise op to apply to block arguments (modulo "
+ "supported contraction casts)";
return false;
}
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 5d66837fca510..f474b3a347594 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -46,3 +46,48 @@ func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32
} -> tensor<8xi32>
return %res : tensor<8xi32>
}
+
+// -----
+
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// This test checks that linalg.generic with a negf between mulf and addf
+// does not get incorrectly specialized to matmul.
+// CHECK-LABEL: @contraction_with_negf
+// CHECK-NOT: linalg.matmul
+// CHECK: linalg.generic
+func.func @contraction_with_negf(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3x3xf32>) -> tensor<3x3xf32> {
+ %0 = linalg.generic {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<3x3xf32>, tensor<3x3xf32>) outs(%arg2 : tensor<3x3xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ %2 = arith.negf %1 : f32
+ %3 = arith.addf %out, %2 : f32
+ linalg.yield %3 : f32
+ } -> tensor<3x3xf32>
+ return %0 : tensor<3x3xf32>
+}
+
+// -----
+
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// This test checks that a cast chain changing the input semantics does not get
+// ignored when matching contractions.
+// CHECK-LABEL: @contraction_with_rounding_cast_chain
+// CHECK-NOT: linalg.matmul
+// CHECK: linalg.generic
+func.func @contraction_with_rounding_cast_chain(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3x3xf32>) -> tensor<3x3xf32> {
+ %0 = linalg.generic {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<3x3xf32>, tensor<3x3xf32>) outs(%arg2 : tensor<3x3xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.fptosi %in : f32 to i32
+ %2 = arith.sitofp %1 : i32 to f32
+ %3 = arith.fptosi %in_0 : f32 to i32
+ %4 = arith.sitofp %3 : i32 to f32
+ %5 = arith.mulf %2, %4 : f32
+ %6 = arith.addf %out, %5 : f32
+ linalg.yield %6 : f32
+ } -> tensor<3x3xf32>
+ return %0 : tensor<3x3xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 3d6c2962731c9..45b2740510017 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -964,14 +964,11 @@ func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32x
// CATEGORY: linalg.generic
// CATEGORY-NOT: linalg.contract
-// Bitcasts are not modeled by the cast attribute, but should not block
-// specialization.
-// NOTE: Bitcasts are not preserved by the matmul named op during
-// roundtrip, so this is potentially loosing information here.
-// See #177593 for more details.
-func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
- %B: tensor<8x32xi32>,
- %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
+// Bitcasts are not modeled by the cast attribute, so specializing this would
+// not round-trip through the matmul named op without losing information.
+func.func @negative_op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
+ %B: tensor<8x32xi32>,
+ %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.generic
{indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"]}
@@ -987,13 +984,13 @@ func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
return %0 : tensor<16x32xf32>
}
-// ALL-LABEL: op_matmul_bitcast_int_to_float
+// ALL-LABEL: negative_op_matmul_bitcast_int_to_float
-// NAMED-NOT: linalg.generic
-// NAMED: linalg.matmul
+// NAMED: linalg.generic
+// NAMED-NOT: linalg.matmul
-// CATEGORY-NOT: linalg.generic
-// CATEGORY: linalg.contract
+// CATEGORY: linalg.generic
+// CATEGORY-NOT: linalg.contract
// Signed float casts only use sitofp, which defaults to signed semantics.
func.func @op_matmul_signed_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
@@ -1050,6 +1047,58 @@ func.func @op_matmul_unsigned_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi
// CATEGORY-NOT: linalg.generic
// CATEGORY: linalg.contract{{.*}}{cast = #linalg.type_fn<cast_unsigned>}
+// Float-to-int casts (fptosi) should be recognized and specialized.
+func.func @op_matmul_fptosi_cast(%A: tensor<16x8xf32>, %B: tensor<8x32xf32>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
+ outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: i32):
+ %1 = arith.fptosi %in : f32 to i32
+ %2 = arith.fptosi %in_0 : f32 to i32
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// ALL-LABEL: op_matmul_fptosi_cast
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.matmul
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// Float-to-unsigned-int casts (fptoui) should be recognized with unsigned cast attr.
+func.func @op_matmul_fptoui_cast(%A: tensor<16x8xf32>, %B: tensor<8x32xf32>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
+ outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: i32):
+ %1 = arith.fptoui %in : f32 to i32
+ %2 = arith.fptoui %in_0 : f32 to i32
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// ALL-LABEL: op_matmul_fptoui_cast
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract{{.*}}{cast = #linalg.type_fn<cast_unsigned>}
+
// -----
///----------------------------------------------------------------------------------------
More information about the Mlir-commits
mailing list