[Mlir-commits] [mlir] [mlir][linalg] Preserve cast semantics during generic to matmul (PR #174757)

Prathamesh Tagore llvmlistbot at llvm.org
Fri Jan 23 06:20:56 PST 2026


https://github.com/meshtag updated https://github.com/llvm/llvm-project/pull/174757

>From b7fcffda37e1d9b331ac6315db559741af84d93a Mon Sep 17 00:00:00 2001
From: Prathamesh Tagore <prathameshtagore at gmail.com>
Date: Wed, 7 Jan 2026 17:37:22 +0530
Subject: [PATCH] [mlir][linalg] Preserve cast semantics during generic to
 matmul

Infer signed/unsigned cast intent from cast ops in linalg.generic bodies and
propagate it via the matmul cast attribute. This could otherwise lead to
silent overflow/underflow errors in e2e execution.

TODO: Extend this to other named ops that support cast attribute.
---
 .../Dialect/Linalg/Transforms/Specialize.cpp  |  81 ++++++-
 .../Linalg/roundtrip-linalg-named-ops.mlir    |  20 ++
 .../Linalg/specialize-generic-ops.mlir        | 207 ++++++++++++++++++
 3 files changed, 301 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 0c7b998ffcab9..4927905f14a0b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -131,17 +132,76 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
 }
 
 // Replaces genericOp with `NamedOpTy` op, supplied as a template arg.
-//  All the variants expressed as pseudo regular expression:
-//      `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
-//  have same number of ins/out, so its easy to stamp different versions.
+// All the variants expressed as pseudo regular expression:
+// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
+// have the same number of ins/outs, so it's easy to stamp different versions.
+// `castTy` is an optional type function that indicates whether (and which) cast
+// attribute is needed for the named matmul op variant.
 template <typename NamedOpTy>
-static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
+                                         std::optional<TypeFn> castTy) {
+  SmallVector<NamedAttribute> castAttrVec;
+  // Only explicitly specify the cast attribute for unsigned cast; signed is
+  // the default for linalg.matmul/linalg.batch_matmul.
+  if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
+    castAttrVec = {rewriter.getNamedAttr(
+        "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
+
   LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
       op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
-      ValueRange{op.getDpsInits()[0]});
+      ValueRange{op.getDpsInits()[0]}, castAttrVec);
   return namedOp;
 }
 
+// Returns the cast type to use for a matmul-like named op. If the generic
+// contains casts that cannot be represented (e.g. output casts or mixed
+// signedness), return std::nullopt.
+static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
+  bool foundCastForMatmulOutput = false;
+  SmallVector<TypeFn> castTyFns;
+  genericOp.getBody()->walk([&](CastOpInterface castOp) {
+    // Collect forward slice of the cast op to check if it is for the matmul
+    // output.
+    SetVector<Operation *> forwardSlice;
+    getForwardSlice(castOp, &forwardSlice);
+
+    // If there is no multiplication op in the forward slice, then this cast
+    // op is for the matmul output. Cast ops on matmul output cannot be
+    // expressed by the matmul op variant.
+    if (!llvm::any_of(forwardSlice, [](Operation *op) {
+          // We check explicitly for these multiplication ops in
+          // `specializeLinalgContractions()` to infer matmul-like ops.
+          return isa<arith::MulIOp, arith::MulFOp, complex::MulOp>(op);
+        })) {
+      foundCastForMatmulOutput = true;
+      return WalkResult::interrupt();
+    }
+
+    // Determine the cast type.
+    if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp))
+      castTyFns.push_back(TypeFn::cast_unsigned);
+    else if (isa<arith::ExtSIOp, arith::SIToFPOp, arith::FPToSIOp>(castOp))
+      castTyFns.push_back(TypeFn::cast_signed);
+
+    return WalkResult::advance();
+  });
+
+  if (foundCastForMatmulOutput)
+    return std::nullopt;
+
+  if (!castTyFns.empty()) {
+    // If there were multiple different cast types found, then we can't express
+    // them using matmul-like ops. They only allow a single cast type for all
+    // inputs.
+    if (!llvm::all_equal(castTyFns))
+      return std::nullopt;
+    return castTyFns.front();
+  }
+
+  // Default to signed cast for matmul-like ops.
+  return TypeFn::cast_signed;
+}
+
 // Converts linalg.generic to named linalg.*matmul* where possible.
 static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                         GenericOp genericOp) {
@@ -230,11 +290,18 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
       (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
     return failure();
 
+  // Determine the cast type for the named matmul op, or bail out if casts
+  // cannot be represented by the named op.
+  std::optional<TypeFn> castTy = getCastTypeForMatmulLikeOp(genericOp);
+  if (!castTy)
+    return rewriter.notifyMatchFailure(
+        genericOp, "contains invalid cast ops for the named matmul op");
+
   /// Codegen the different matmul variants.
   if (numOfBatchDims) {
-    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
   }
-  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
 }
 
 /// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
index 1fb520c5982e6..f15ae646e5765 100644
--- a/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -26,6 +26,11 @@ func.func @binary_add(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?
 
 // -----
 
+
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.matmul
+///----------------------------------------------------------------------------------------
+
 func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
@@ -38,6 +43,21 @@ func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32
 
 // -----
 
+// Check matmul with unsigned cast is correctly raised back to named op.
+func.func @matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+                                %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+                     ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>)
+                     outs(%Out : tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: @matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// -----
+
 func.func @mixed_named_ops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
                                    %C: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index cf495a7d29b70..6acf1ca0d4e30 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -38,6 +38,10 @@ func.func @binary_op_div(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.matmul
+///----------------------------------------------------------------------------------------
+
 #map = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -58,8 +62,187 @@ func.func @op_matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?x
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 
+// Cast-auditing tests: ensure we only specialize when the cast semantics can
+// be expressed by linalg.matmul, and use the cast attribute when needed.
+
+// Check matmul with unsigned cast is correctly raised back to named op.
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi32>,
+                                   %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi32>) outs(%Out : tensor<16x32xi32>) {
+  ^bb0(%in: i16, %in_0: i32, %out: i32):
+    %1 = arith.extui %in : i16 to i32
+    %3 = arith.muli %1, %in_0 : i32
+    %4 = arith.addi %out, %3 : i32
+    linalg.yield %4 : i32
+  } -> tensor<16x32xi32>
+  return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// Ensures an integer truncation is tolerated alongside an unsigned cast.
+// Note: We only consider casts as conflicting if they have different
+// signedness behaviours, and then we do not specialize if they do
+// conflict. Since this is not such a case, we do not block specialization.
+// Also the roundtrip lowering back to linalg.generic for such an op is
+// expected to produce the same thing again, so we are not losing
+// information here.
+func.func @op_matmul_unsigned_cast_and_truncate(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+                                                %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+  ^bb0(%in: i16, %in_0: i64, %out: i32):
+    %1 = arith.extui %in : i16 to i32
+    %2 = arith.trunci %in_0 : i64 to i32
+    %3 = arith.muli %1, %2 : i32
+    %4 = arith.addi %out, %3 : i32
+    linalg.yield %4 : i32
+  } -> tensor<16x32xi32>
+  return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast_and_truncate
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
+// Signed casts are the default, no cast attribute is required.
+func.func @op_matmul_signed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                 %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i32):
+     %1 = arith.extsi %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_signed_cast
+// CHECK-NOT: linalg.generic
+// CHECK-NOT: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+// CHECK: linalg.matmul
+
+// Mixed signed/unsigned inputs cannot be encoded with a single cast attribute.
+func.func @negative_op_matmul_mixed_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xi32>) {
+   ^bb0(%in: i16, %in_0: i16, %out: i32):
+     %1 = arith.extui %in : i16 to i32
+     %2 = arith.extsi %in_0 : i16 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<16x32xi32>
+   return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: negative_op_matmul_mixed_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// Output-side casts are not representable by the named matmul ops.
+func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32xi32>,
+                                 %Out: tensor<16x32xi64>) -> tensor<16x32xi64> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi32>, tensor<8x32xi32>) outs(%Out : tensor<16x32xi64>) {
+   ^bb0(%in: i32, %in_0: i32, %out: i64):
+     %3 = arith.trunci %out : i64 to i32
+     %4 = arith.muli %in, %in_0 : i32
+     %5 = arith.addi %3, %4 : i32
+     %6 = arith.extsi %5 : i32 to i64
+     linalg.yield %6 : i64
+   } -> tensor<16x32xi64>
+   return %0 : tensor<16x32xi64>
+}
+
+// CHECK-LABEL: negative_op_matmul_output_cast
+// CHECK: linalg.generic
+// CHECK-NOT: linalg.matmul
+
+// Bitcasts are not modeled by the cast attribute, but should not block
+// specialization.
+// NOTE: Bitcasts are not preserved by the matmul named op during
+// roundtrip, so this potentially loses information.
+// See #177593 for more details.
+func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
+                                          %B: tensor<8x32xi32>,
+                                          %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi32>, tensor<8x32xi32>) outs(%Out : tensor<16x32xf32>) {
+  ^bb0(%in: i32, %in_0: i32, %out: f32):
+    %1 = arith.bitcast %in : i32 to f32
+    %2 = arith.bitcast %in_0 : i32 to f32
+    %3 = arith.mulf %1, %2 : f32
+    %4 = arith.addf %out, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<16x32xf32>
+  return %0 : tensor<16x32xf32>
+}
+
+// CHECK-LABEL: op_matmul_bitcast_int_to_float
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul
+
+// Signed float casts only use sitofp, which defaults to signed semantics.
+func.func @op_matmul_signed_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                       %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xf32>) {
+  ^bb0(%in: i16, %in_0: i16, %out: f32):
+    %1 = arith.sitofp %in : i16 to f32
+    %2 = arith.sitofp %in_0 : i16 to f32
+    %3 = arith.mulf %1, %2 : f32
+    %4 = arith.addf %out, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<16x32xf32>
+  return %0 : tensor<16x32xf32>
+}
+
+// CHECK-LABEL: op_matmul_signed_cast_float
+// CHECK-NOT: linalg.generic
+// CHECK-NOT: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+// CHECK: linalg.matmul
+
+// Unsigned float casts are expressed via uitofp and use the unsigned cast attr.
+func.func @op_matmul_unsigned_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
+                                         %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi16>) outs(%Out : tensor<16x32xf32>) {
+  ^bb0(%in: i16, %in_0: i16, %out: f32):
+    %1 = arith.uitofp %in : i16 to f32
+    %2 = arith.uitofp %in_0 : i16 to f32
+    %3 = arith.mulf %1, %2 : f32
+    %4 = arith.addf %out, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<16x32xf32>
+  return %0 : tensor<16x32xf32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast_float
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
 // -----
 
+///----------------------------------------------------------------------------------------
+/// Tests for linalg.batch_matmul
+///----------------------------------------------------------------------------------------
+
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
@@ -80,6 +263,30 @@ func.func @op_batch_matmul(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>, %Out:
 // CHECK-NOT: linalg.generic
 // CHECK: linalg.batch_matmul ins(%[[A]], %[[B]] : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%[[Out]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
 
+// Ensure that the unsigned cast path for cast detection is exercised for
+// batch_matmul as well.
+func.func @op_batch_matmul_unsigned_cast(%A: tensor<2x16x8xi16>,
+                                         %B: tensor<2x8x16xi64>,
+                                         %Out: tensor<2x16x16xi32>) -> tensor<2x16x16xi32> {
+  %0 = linalg.generic
+         {indexing_maps = [#map, #map1, #map2],
+          iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+         ins(%A, %B : tensor<2x16x8xi16>, tensor<2x8x16xi64>)
+         outs(%Out : tensor<2x16x16xi32>) {
+   ^bb0(%in: i16, %in_0: i64, %out: i32):
+     %1 = arith.extui %in : i16 to i32
+     %2 = arith.trunci %in_0 : i64 to i32
+     %3 = arith.muli %1, %2 : i32
+     %4 = arith.addi %out, %3 : i32
+     linalg.yield %4 : i32
+   } -> tensor<2x16x16xi32>
+   return %0 : tensor<2x16x16xi32>
+}
+
+// CHECK-LABEL: op_batch_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.batch_matmul {cast = #linalg.type_fn<cast_unsigned>}
+
 // -----
 
 // This is a multi-reduction linalg.generic and cannot be lifted to matrix multiply



More information about the Mlir-commits mailing list