[Mlir-commits] [mlir] [mlir][linalg] Preserve cast semantics during linalg.generic to matmul (PR #174757)

Prathamesh Tagore llvmlistbot at llvm.org
Wed Jan 7 04:11:44 PST 2026


https://github.com/meshtag created https://github.com/llvm/llvm-project/pull/174757

Infer the signed/unsigned intent of the cast ops in linalg.generic bodies and propagate it through the matmul `cast` attribute (defaulting to signed, and switching to unsigned as soon as an unsigned cast is found). Fixes a functional bug reported in #174517.

Fixes https://github.com/llvm/llvm-project/issues/174517

>From 4446ba774588cca0c30402c8bb765dbbb437af02 Mon Sep 17 00:00:00 2001
From: Prathamesh Tagore <prathameshtagore at gmail.com>
Date: Wed, 7 Jan 2026 17:37:22 +0530
Subject: [PATCH] [mlir][linalg] Preserve cast semantics during linalg.generic
 to matmul

Infer the signed/unsigned intent of the cast ops in linalg.generic bodies and
propagate it through the matmul `cast` attribute (defaulting to signed, and
switching to unsigned as soon as an unsigned cast is found). Fixes a
functional bug reported in #174517.
---
 .../Dialect/Linalg/Transforms/Specialize.cpp  | 39 +++++++++++++++++--
 .../Linalg/specialize-generic-ops.mlir        | 25 ++++++++++++
 2 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 0c7b998ffcab9..5d4feb8828732 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -135,13 +135,41 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
 //      `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
 //  have same number of ins/out, so its easy to stamp different versions.
 template <typename NamedOpTy>
-static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
+                                         std::optional<TypeFn> castTy) {
+  SmallVector<NamedAttribute> castAttrVec; // Empty unless cast_unsigned.
+  if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
+    castAttrVec = {rewriter.getNamedAttr(
+        "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
+
   LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
       op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
-      ValueRange{op.getDpsInits()[0]});
+      ValueRange{op.getDpsInits()[0]}, castAttrVec); // Empty => default cast.
   return namedOp;
 }
 
+// Infers the `cast` attribute for the named matmul from the cast ops in the
+// body of `genericOp`. Returns std::nullopt when the body has no cast ops.
+static std::optional<TypeFn> getMatmulCastTy(GenericOp genericOp) {
+  std::optional<TypeFn> inferredCast;
+  genericOp.getBody()->walk([&](CastOpInterface castOp) -> WalkResult {
+    // A single unsigned cast settles the answer, so the walk stops there.
+    // NOTE(review): a body that mixes signed and unsigned casts therefore
+    // collapses to one cast_unsigned attribute on the named op.
+    if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp)) {
+      inferredCast = TypeFn::cast_unsigned;
+      return WalkResult::interrupt();
+    }
+
+    // Any other cast op (e.g. arith.extsi or arith.trunci) selects the
+    // conservative signed default.
+    inferredCast = TypeFn::cast_signed;
+    return WalkResult::advance();
+  });
+
+  return inferredCast;
+}
+
 // Converts linalg.generic to named linalg.*matmul* where possible.
 static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
                                                         GenericOp genericOp) {
@@ -230,11 +258,14 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
       (a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
     return failure();
 
+  // Get the cast attribute for the named matmul op (if any).
+  std::optional<TypeFn> castTy = getMatmulCastTy(genericOp);
+
   /// Codegen the different matmul variants.
   if (numOfBatchDims) {
-    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+    return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
   }
-  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+  return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
 }
 
 /// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index cf495a7d29b70..3cff09cef6bba 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -124,3 +124,28 @@ func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>
 }
 // CHECK-LABEL: op_matvec
 // CHECK: linalg.generic
+
+// -----
+
+#lhs_map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#rhs_map = affine_map<(d0, d1, d2) -> (d2, d1)>
+#acc_map = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+                                   %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %res = linalg.generic
+      {indexing_maps = [#lhs_map, #rhs_map, #acc_map], iterator_types = ["parallel", "parallel", "reduction"]}
+      ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+    ^bb0(%lhs: i16, %rhs: i64, %acc: i32):
+      %lhs_ext = arith.extui %lhs : i16 to i32
+      %rhs_trunc = arith.trunci %rhs : i64 to i32
+      %prod = arith.muli %lhs_ext, %rhs_trunc : i32
+      %sum = arith.addi %acc, %prod : i32
+      linalg.yield %sum : i32
+  } -> tensor<16x32xi32>
+  return %res : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+



More information about the Mlir-commits mailing list