[Mlir-commits] [mlir] [mlir][linalg] Fix getSourceSkipUnary to only skip cast-like ops (PR #198725)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 27 20:17:38 PDT 2026


https://github.com/Chennesxu updated https://github.com/llvm/llvm-project/pull/198725

>From b8a353050ee9945dd99b20fca582ac7c1c1f0127 Mon Sep 17 00:00:00 2001
From: Chennes Xu <xuchen359 at gmail.com>
Date: Wed, 20 May 2026 16:01:26 +0800
Subject: [PATCH] [mlir][linalg] Only skip supported casts in contraction
 matching

getSourceSkipUnary skipped arbitrary unary side-effect-free ops when matching contraction bodies. This allowed semantics-changing ops such as arith.negf to be ignored, causing linalg.generic ops to be incorrectly specialized to named/category contraction ops.

Restrict the matcher to a conservative allowlist of scalar arith casts modeled by current linalg contraction cast semantics, and rename the helper to getSourceSkipCast. Only skip one cast layer so cast chains such as fptosi -> sitofp are not treated as transparent.

Update the contraction-body comments to make clear that this is structural matching only; callers remain responsible for validating cast placement for named-op round-trip semantics.

Add negative coverage for arith.negf, bitcast, and cast chains, and positive coverage for fptosi/fptoui casts.

Fixes #197178.
---
 .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 18 +++--
 .../Dialect/Linalg/IR/LinalgInterfaces.cpp    | 43 +++++------
 .../Linalg/specialize-generic-ops-fail.mlir   | 45 +++++++++++
 .../Linalg/specialize-generic-ops.mlir        | 75 +++++++++++++++----
 4 files changed, 140 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 3c7ebd8277dbd..6c7e26762f993 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -165,17 +165,21 @@ namespace detail {
 
 /// Returns true if the block contains a contraction of the following form:
 ///
-///   %0 = <elemwise>(permutation-of(cu(block-argument-0),
-///                                  cu(block-argument-1)))
-///   %1 = <reduce>(permutation-of(cu(%0), cu(block-argument-2)))
-///   return-like cu(%1)
+///   %0 = <elemwise>(permutation-of(c(block-argument-0),
+///                                  c(block-argument-1)))
+///   %1 = <reduce>(permutation-of(c(%0), c(block-argument-2)))
+///   return-like c(%1)
 ///
 /// where <elemwise> and <reduce> are binary operations constituting a
 /// contraction (in the canonical case, <elemwise> is a multiplication and
 /// <reduce> is an addition). The name and other properties of these operations
-/// are checked by `isaPair`. All operands of all operations may be supplied
-/// through a chain of side effect-free unary operations, such as casts, which
-/// is denoted as `cu` above.
+/// are checked by `isaPair`. The notation `c(...)` denotes an optional
+/// supported scalar arith cast.
+///
+/// Note: This is structural matching only. Callers must separately validate
+/// that cast placement matches the forms produced by generalizing named ops
+/// (e.g., `-linalg-generalize-named-ops` never produces output-side casts, so
+/// bodies with output-side casts cannot round-trip through a named op).
 ///
 /// When the body does not contain a contraction, a more precise description of
 /// the failed precondition is send to the `errs` stream, if provided.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 238bddcb3b2bd..65397d65c408c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -296,19 +296,20 @@ bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op,
 // ContractionOpInterface implementation
 //===----------------------------------------------------------------------===//
 
-/// If the value is defined by a chain of unary side effect-free, go up the
-/// use-def chain until the first value that isn't defined by such an op.
-// TODO: relax to multi-operands with constants, which are technically unary ops
-// as needed (e.g. add5).
-static Value getSourceSkipUnary(Value value) {
+/// Return true for scalar arith cast ops modeled by current linalg contraction
+/// cast semantics.
+static bool isSupportedContractionCast(Operation *op) {
+  return isa<arith::ExtFOp, arith::TruncFOp, arith::ExtSIOp, arith::ExtUIOp,
+             arith::TruncIOp, arith::SIToFPOp, arith::UIToFPOp, arith::FPToSIOp,
+             arith::FPToUIOp>(op);
+}
+
+/// Skip at most one supported contraction cast, returning its source (or
+/// `value` unchanged). Re-applying this would make cast chains transparent.
+static Value getSourceSkipCast(Value value) {
   Operation *op = value.getDefiningOp();
-  while (op && op->getNumOperands() == 1) {
-    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
-    if (!iface || !iface.hasNoEffect())
-      break;
-    value = op->getOperand(0);
-    op = value.getDefiningOp();
-  }
+  if (op && op->getNumOperands() == 1 && isSupportedContractionCast(op))
+    return op->getOperand(0);
   return value;
 }
 
@@ -331,7 +332,7 @@ bool mlir::linalg::detail::isContractionBody(
     return false;
   }
 
-  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
+  Value yielded = getSourceSkipCast(terminator->getOperand(0));
   Operation *reductionOp = yielded.getDefiningOp();
   if (!reductionOp || reductionOp->getNumResults() != 1 ||
       reductionOp->getNumOperands() != 2) {
@@ -339,17 +340,19 @@ bool mlir::linalg::detail::isContractionBody(
     return false;
   }
 
-  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
-  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
+  Value reductionLHS = getSourceSkipCast(reductionOp->getOperand(0));
+  Value reductionRHS = getSourceSkipCast(reductionOp->getOperand(1));
 
   if (reductionLHS != block.getArgument(2) &&
       reductionRHS != block.getArgument(2)) {
     errs << "expected reduction to take block argument #2 as one of the "
-            "operands (modulo unary casts)";
+            "operands (modulo supported contraction casts)";
     return false;
   }
 
-  Value contributed = getSourceSkipUnary(
-      isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
+  // The reduction operands above already had one cast layer skipped; do not
+  // skip again, or cast chains such as fptosi -> sitofp become transparent.
+  Value contributed =
+      isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS;
   Operation *elementwiseOp = contributed.getDefiningOp();
   if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
@@ -363,8 +366,8 @@ bool mlir::linalg::detail::isContractionBody(
     return false;
   }
 
-  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
-  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
+  Value elementwiseLHS = getSourceSkipCast(elementwiseOp->getOperand(0));
+  Value elementwiseRHS = getSourceSkipCast(elementwiseOp->getOperand(1));
   if ((elementwiseLHS == block.getArgument(0) &&
        elementwiseRHS == block.getArgument(1)) ||
       (elementwiseLHS == block.getArgument(1) &&
@@ -372,8 +375,8 @@ bool mlir::linalg::detail::isContractionBody(
     return true;
   }
 
-  errs << "expected elementwise op to apply to block arguments (modulo unary "
-          "casts)";
+  errs << "expected elementwise op to apply to block arguments (modulo "
+          "supported contraction casts)";
   return false;
 }
 
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 5d66837fca510..f474b3a347594 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -46,3 +46,48 @@ func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32
   } -> tensor<8xi32>
   return %res : tensor<8xi32>
 }
+
+// -----
+
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// This test checks that linalg.generic with a negf between mulf and addf
+// does not get incorrectly specialized to matmul.
+// CHECK-LABEL: @contraction_with_negf
+//  CHECK-NOT:    linalg.matmul
+//      CHECK:    linalg.generic
+func.func @contraction_with_negf(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3x3xf32>) -> tensor<3x3xf32> {
+  %0 = linalg.generic {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<3x3xf32>, tensor<3x3xf32>) outs(%arg2 : tensor<3x3xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.negf %1 : f32
+    %3 = arith.addf %out, %2 : f32
+    linalg.yield %3 : f32
+  } -> tensor<3x3xf32>
+  return %0 : tensor<3x3xf32>
+}
+
+// -----
+
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// This test checks that a cast chain changing the input semantics does not get
+// ignored when matching contractions.
+// CHECK-LABEL: @contraction_with_rounding_cast_chain
+//  CHECK-NOT:    linalg.matmul
+//      CHECK:    linalg.generic
+func.func @contraction_with_rounding_cast_chain(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3x3xf32>) -> tensor<3x3xf32> {
+  %0 = linalg.generic {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<3x3xf32>, tensor<3x3xf32>) outs(%arg2 : tensor<3x3xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.fptosi %in : f32 to i32
+    %2 = arith.sitofp %1 : i32 to f32
+    %3 = arith.fptosi %in_0 : f32 to i32
+    %4 = arith.sitofp %3 : i32 to f32
+    %5 = arith.mulf %2, %4 : f32
+    %6 = arith.addf %out, %5 : f32
+    linalg.yield %6 : f32
+  } -> tensor<3x3xf32>
+  return %0 : tensor<3x3xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 3d6c2962731c9..45b2740510017 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -964,14 +964,11 @@ func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32x
 // CATEGORY: linalg.generic
 // CATEGORY-NOT: linalg.contract
 
-// Bitcasts are not modeled by the cast attribute, but should not block
-// specialization.
-// NOTE: Bitcasts are not preserved by the matmul named op during
-// roundtrip, so this is potentially loosing information here.
-// See #177593 for more details.
-func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
-                                          %B: tensor<8x32xi32>,
-                                          %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
+// Bitcasts are not modeled by the cast attribute, so specializing this would
+// not round-trip through the matmul named op without losing information.
+func.func @negative_op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
+                                                   %B: tensor<8x32xi32>,
+                                                   %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.generic
     {indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"]}
@@ -987,13 +984,13 @@ func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
   return %0 : tensor<16x32xf32>
 }
 
-// ALL-LABEL: op_matmul_bitcast_int_to_float
+// ALL-LABEL: negative_op_matmul_bitcast_int_to_float
 
-// NAMED-NOT: linalg.generic
-// NAMED: linalg.matmul
+// NAMED: linalg.generic
+// NAMED-NOT: linalg.matmul
 
-// CATEGORY-NOT: linalg.generic
-// CATEGORY: linalg.contract
+// CATEGORY: linalg.generic
+// CATEGORY-NOT: linalg.contract
 
 // Signed float casts only use sitofp, which defaults to signed semantics.
 func.func @op_matmul_signed_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
@@ -1050,6 +1047,58 @@ func.func @op_matmul_unsigned_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi
 // CATEGORY-NOT: linalg.generic
 // CATEGORY: linalg.contract{{.*}}{cast = #linalg.type_fn<cast_unsigned>}
 
+// Float-to-int casts (fptosi) should be recognized and specialized.
+func.func @op_matmul_fptosi_cast(%A: tensor<16x8xf32>, %B: tensor<8x32xf32>,
+                                 %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
+    outs(%Out : tensor<16x32xi32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: i32):
+    %1 = arith.fptosi %in : f32 to i32
+    %2 = arith.fptosi %in_0 : f32 to i32
+    %3 = arith.muli %1, %2 : i32
+    %4 = arith.addi %out, %3 : i32
+    linalg.yield %4 : i32
+  } -> tensor<16x32xi32>
+  return %0 : tensor<16x32xi32>
+}
+
+// ALL-LABEL: op_matmul_fptosi_cast
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.matmul
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract
+
+// Float-to-unsigned-int casts (fptoui) should specialize with cast_unsigned.
+func.func @op_matmul_fptoui_cast(%A: tensor<16x8xf32>, %B: tensor<8x32xf32>,
+                                 %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
+    outs(%Out : tensor<16x32xi32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: i32):
+    %1 = arith.fptoui %in : f32 to i32
+    %2 = arith.fptoui %in_0 : f32 to i32
+    %3 = arith.muli %1, %2 : i32
+    %4 = arith.addi %out, %3 : i32
+    linalg.yield %4 : i32
+  } -> tensor<16x32xi32>
+  return %0 : tensor<16x32xi32>
+}
+
+// ALL-LABEL: op_matmul_fptoui_cast
+
+// NAMED-NOT: linalg.generic
+// NAMED: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// CATEGORY-NOT: linalg.generic
+// CATEGORY: linalg.contract{{.*}}{cast = #linalg.type_fn<cast_unsigned>}
+
 // -----
 
 ///----------------------------------------------------------------------------------------



More information about the Mlir-commits mailing list