[Mlir-commits] [mlir] [mlir][linalg] Preserve cast semantics during linalg.generic to matmul (PR #174757)
Prathamesh Tagore
llvmlistbot at llvm.org
Wed Jan 7 04:11:44 PST 2026
https://github.com/meshtag created https://github.com/llvm/llvm-project/pull/174757
Infer signed/unsigned cast intent from the cast ops in linalg.generic bodies and propagate it via the matmul `cast` attribute (defaulting to signed and switching to unsigned as soon as an unsigned cast is found). Fixes the functional bug reported in #174517.
Fixes https://github.com/llvm/llvm-project/issues/174517
>From 4446ba774588cca0c30402c8bb765dbbb437af02 Mon Sep 17 00:00:00 2001
From: Prathamesh Tagore <prathameshtagore at gmail.com>
Date: Wed, 7 Jan 2026 17:37:22 +0530
Subject: [PATCH] [mlir][linalg] Preserve cast semantics during linalg.generic
to matmul
Infer signed/unsigned cast intent from the cast ops in linalg.generic bodies
and propagate it via the matmul `cast` attribute (defaulting to signed and
switching to unsigned as soon as an unsigned cast is found). Fixes the
functional bug reported in #174517.
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 39 +++++++++++++++++--
.../Linalg/specialize-generic-ops.mlir | 25 ++++++++++++
2 files changed, 60 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 0c7b998ffcab9..5d4feb8828732 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -135,13 +135,41 @@ static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
// have same number of ins/out, so its easy to stamp different versions.
template <typename NamedOpTy>
-static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
+static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
+ std::optional<TypeFn> castTy) {
+ SmallVector<NamedAttribute> castAttrVec; // Empty unless castTy is unsigned.
+ if (castTy.has_value() && *castTy == TypeFn::cast_unsigned)
+ castAttrVec = {rewriter.getNamedAttr(
+ "cast", TypeFnAttr::get(rewriter.getContext(), *castTy))};
+
LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
- ValueRange{op.getDpsInits()[0]});
+ ValueRange{op.getDpsInits()[0]}, castAttrVec);
return namedOp;
}
+// Infers the `cast` attribute for the named matmul from the cast ops in the
+// body of `genericOp`; returns std::nullopt if the body has no cast op at all.
+static std::optional<TypeFn> getMatmulCastTy(GenericOp genericOp) {
+ std::optional<TypeFn> castTy;
+ genericOp.getBody()->walk([&](CastOpInterface castOp) {
+ // The first cast op seen defaults the result to a signed cast; it is
+ // switched to unsigned below if an unsigned cast is encountered.
+ if (!castTy)
+ castTy = TypeFn::cast_signed;
+
+ // One unsigned cast switches the result to cast_unsigned and ends the walk.
+ // NOTE(review): mixed extsi/extui bodies may need to bail out -- confirm.
+ if (isa<arith::ExtUIOp, arith::UIToFPOp, arith::FPToUIOp>(castOp)) {
+ castTy = TypeFn::cast_unsigned;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+
+ return castTy;
+}
+
// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
GenericOp genericOp) {
@@ -230,11 +258,14 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
(a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
return failure();
+ // Get the cast attribute for the named matmul op (if any).
+ std::optional<TypeFn> castTy = getMatmulCastTy(genericOp);
+
/// Codegen the different matmul variants.
if (numOfBatchDims) {
- return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
+ return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp, castTy);
}
- return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
+ return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp, castTy);
}
/// Utility to specialize a `genericOp` with a convolution op of type `ConvOpTy`
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index cf495a7d29b70..3cff09cef6bba 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -124,3 +124,28 @@ func.func @op_matvec(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %Out: tensor<?xf32>
}
// CHECK-LABEL: op_matvec
// CHECK: linalg.generic
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @op_matmul_unsigned_cast(%A: tensor<16x8xi16>, %B: tensor<8x32xi64>,
+ %Out: tensor<16x32xi32>) -> tensor<16x32xi32> {
+ %0 = linalg.generic
+ {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+ ins(%A, %B : tensor<16x8xi16>, tensor<8x32xi64>) outs(%Out : tensor<16x32xi32>) {
+ ^bb0(%in: i16, %in_0: i64, %out: i32):
+ %1 = arith.extui %in : i16 to i32 // unsigned cast: expect cast_unsigned
+ %2 = arith.trunci %in_0 : i64 to i32 // a cast op, but not an unsigned one
+ %3 = arith.muli %1, %2 : i32
+ %4 = arith.addi %out, %3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<16x32xi32>
+ return %0 : tensor<16x32xi32>
+}
+
+// CHECK-LABEL: op_matmul_unsigned_cast
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.matmul {cast = #linalg.type_fn<cast_unsigned>}
+
More information about the Mlir-commits
mailing list