[Mlir-commits] [mlir] [mlir][linalg] Reject matmul specialization when generic uses bitcast (PR #182705)

Prathamesh Tagore llvmlistbot at llvm.org
Sat Feb 21 14:10:43 PST 2026


https://github.com/meshtag created https://github.com/llvm/llvm-project/pull/182705

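linalg-specialize-generic-ops currently allows matmul-like specialization even when the generic body contains arith.bitcast. The matmul cast attribute cannot represent bit-level reinterpretation semantics, so a specialization/generalization round trip can lose information: generalizing the resulting matmul re-emits the bitcast as a value-converting cast (for example, sitofp for an i32 -> f32 bitcast). This patch makes specialization bail out when the generic body contains a bitcast.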

Fixes https://github.com/llvm/llvm-project/issues/177593
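
To make the difference between bit-level reinterpretation and value conversion concrete, here is a small standalone C++ sketch (plain C++, not MLIR code). It contrasts what arith.bitcast and arith.sitofp compute for the same i32 input. The function names are hypothetical, and the sketch assumes a 32-bit IEEE-754 float.

```cpp
#include <cstdint>
#include <cstring>

// Models arith.bitcast i32 -> f32: the 32 bits are reinterpreted unchanged.
float bitcastI32ToF32(std::int32_t v) {
  static_assert(sizeof(float) == sizeof(std::int32_t), "expects a 32-bit float");
  float f;
  std::memcpy(&f, &v, sizeof f);
  return f;
}

// Models arith.sitofp i32 -> f32: the signed integer value is converted numerically.
float sitofpI32ToF32(std::int32_t v) { return static_cast<float>(v); }
```

For the input 1065353216 (0x3F800000, the bit pattern of 1.0f), the first function yields 1.0f while the second yields 1065353216.0f. If a round trip silently replaces the bitcast with sitofp, the computed matmul changes.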

From 808d59fb9a78e0580da32fd5c9f7a4c6be91be92 Mon Sep 17 00:00:00 2001
From: Prathamesh Tagore <prathameshtagore at gmail.com>
Date: Sat, 21 Feb 2026 22:49:14 +0100
Subject: [PATCH] [mlir][linalg] Reject matmul specialization when generic uses
 bitcast

linalg-specialize-generic-ops currently allows matmul-like specialization
even when the generic body contains arith.bitcast. The matmul cast attribute
cannot represent bit-level reinterpretation semantics, so this can lose
information across specialization/generalization.
---
 mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp    | 11 ++++++++++-
 mlir/test/Dialect/Linalg/specialize-generic-ops.mlir | 12 +++++-------
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a71f84dee3bb0..bdb33c833c829 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -157,9 +157,18 @@ static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op,
 // contains casts that cannot be represented (e.g. output casts or mixed
 // signedness), return std::nullopt.
 static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
+  // In addition to output casts, matmul-like named ops cannot represent
+  // bit-level casts.
+  bool foundBitCastOp = false;
   bool foundCastForMatmulOutput = false;
   SmallVector<TypeFn> castTyFns;
   genericOp.getBody()->walk([&](CastOpInterface castOp) {
+    // Early return if we encounter a bitcast op.
+    if (isa<arith::BitcastOp>(castOp)) {
+      foundBitCastOp = true;
+      return WalkResult::interrupt();
+    }
+
     // Collect forward slice of the cast op to check if it is for the matmul
     // output.
     SetVector<Operation *> forwardSlice;
@@ -186,7 +195,7 @@ static std::optional<TypeFn> getCastTypeForMatmulLikeOp(GenericOp genericOp) {
     return WalkResult::advance();
   });
 
-  if (foundCastForMatmulOutput)
+  if (foundBitCastOp || foundCastForMatmulOutput)
     return std::nullopt;
 
   if (!castTyFns.empty()) {
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
index 6acf1ca0d4e30..da4b307f12fa7 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops.mlir
@@ -171,11 +171,9 @@ func.func @negative_op_matmul_output_cast(%A: tensor<16x8xi32>, %B: tensor<8x32x
 // CHECK: linalg.generic
 // CHECK-NOT: linalg.matmul
 
-// Bitcasts are not modeled by the cast attribute, but should not block
-// specialization.
-// NOTE: Bitcasts are not preserved by the matmul named op during
-// roundtrip, so this is potentially loosing information here.
-// See #177593 for more details.
+// Bitcasts are not modeled by the cast attribute, and would lose information
+// when roundtripped through the matmul named op (sitofp will be emitted in
+// this case), so we do not allow them for specialization.
 func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
                                           %B: tensor<8x32xi32>,
                                           %Out: tensor<16x32xf32>) -> tensor<16x32xf32> {
@@ -193,8 +191,8 @@ func.func @op_matmul_bitcast_int_to_float(%A: tensor<16x8xi32>,
 }
 
 // CHECK-LABEL: op_matmul_bitcast_int_to_float
-// CHECK-NOT: linalg.generic
-// CHECK: linalg.matmul
+// CHECK:     linalg.generic
+// CHECK-NOT: linalg.matmul
 
 // Signed float casts only use sitofp, which defaults to signed semantics.
 func.func @op_matmul_signed_cast_float(%A: tensor<16x8xi16>, %B: tensor<8x32xi16>,
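
The rejection logic in getCastTypeForMatmulLikeOp can be summarized with a standalone C++ sketch. This is not the real implementation. CastKind, CastInfo, castTypeForMatmulLike and the TypeFn stand-in are hypothetical, and the IR walk and forward-slice analysis are replaced by a precomputed list. The no-cast default and the mixed-signedness check are assumptions drawn from the function's doc comment, not from the elided part of the patch.

```cpp
#include <optional>
#include <vector>

// Hypothetical stand-ins for the casts found in the generic body; NOT the MLIR API.
enum class CastKind { Bitcast, SignExtend, UnsignedExtend, SignedToFloat, UnsignedToFloat };

// Stand-in mirroring the two linalg::TypeFn cast variants (cast_signed, cast_unsigned).
enum class TypeFn { cast_signed, cast_unsigned };

struct CastInfo {
  CastKind kind;
  bool appliesToOutput; // stands in for the forward-slice "feeds the matmul output" check
};

std::optional<TypeFn> castTypeForMatmulLike(const std::vector<CastInfo> &casts) {
  std::vector<TypeFn> fns;
  for (const CastInfo &c : casts) {
    // A bit-level reinterpretation has no TypeFn equivalent, so stop at the
    // first one (the patch interrupts the walk at this point).
    if (c.kind == CastKind::Bitcast)
      return std::nullopt;
    // Casts on the matmul output cannot be expressed by the cast attribute.
    if (c.appliesToOutput)
      return std::nullopt;
    bool isSigned = c.kind == CastKind::SignExtend || c.kind == CastKind::SignedToFloat;
    fns.push_back(isSigned ? TypeFn::cast_signed : TypeFn::cast_unsigned);
  }
  // Assumption: with no casts, fall back to the attribute's default (signed) cast.
  if (fns.empty())
    return TypeFn::cast_signed;
  // Assumption: mixed signedness cannot be folded into a single cast attribute.
  for (TypeFn fn : fns)
    if (fn != fns.front())
      return std::nullopt;
  return fns.front();
}
```

Returning at the first bitcast has the same effect as the patch's WalkResult::interrupt(). Once a bitcast is seen, the result is std::nullopt no matter what else the body contains, so the rest of the body does not need to be inspected.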