[Mlir-commits] [mlir] [mlir][Linalg] Fix crash in buildBinaryFn on non-numeric types (PR #180594)

Guilherme oliveira de campos llvmlistbot at llvm.org
Mon Feb 16 10:57:35 PST 2026


https://github.com/guiolidc updated https://github.com/llvm/llvm-project/pull/180594

>From ba3110ac160b7b445c70be21ec7c24f1d89ee095 Mon Sep 17 00:00:00 2001
From: Guilherme Oliveira de Campos <oliveira.gui at hotmail.com.br>
Date: Mon, 9 Feb 2026 16:08:31 -0300
Subject: [PATCH] [mlir][Linalg] Fix crash in buildBinaryFn on non-numeric
 types

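This fixes a crash when Linalg builds a binary region for operand types that are not all complex, all floating-point, or all integer (e.g. AMX tiles). Previously, buildBinaryFn hit `llvm_unreachable` whenever no `emitError` callback was supplied. It now always emits an error diagnostic, falling back to the first operand's location when no callback is given, and returns a null Value. The BatchMatmulOp, BatchReduceMatmulOp, and ElementwiseOp region builders now forward their `emitError` callback and return early when the binary op cannot be built.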
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 34 ++++++++++++++----------
 mlir/test/Dialect/Linalg/invalid.mlir    | 20 ++++++++++++++
 2 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index eba3fa6db2126..9e4dcbcf0dcd4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -506,14 +506,11 @@ class RegionBuilderHelper {
     bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                    arg1.getType().getIntOrFloatBitWidth() == 1;
     if (!allComplex && !allFloatingPoint && !allInteger) {
-      if (emitError) {
-        emitError()
-            << "Cannot build binary Linalg operation: expects allComplex, "
-               "allFloatingPoint, or allInteger, got "
-            << arg0.getType() << " and " << arg1.getType();
-        return nullptr;
-      }
-      llvm_unreachable("unsupported non numeric type");
+      auto diag = emitError ? emitError() : mlir::emitError(arg0.getLoc());
+      diag << "Cannot build binary Linalg operation: expects allComplex, "
+           << "allFloatingPoint, or allInteger, got " << arg0.getType()
+           << " and " << arg1.getType();
+      return {};
     }
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToEnd(&block);
@@ -4648,9 +4645,13 @@ void BatchMatmulOp::regionBuilder(
   auto toType = block.getArgument(2).getType();
   Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
   Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
-  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
+  if (!mulVal)
+    return;
   Value addVal =
-      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal, emitError);
+  if (!addVal)
+    return;
   yields.push_back(addVal);
   helper.yieldOutputs(yields);
 }
@@ -4933,7 +4934,7 @@ void ElementwiseOp::regionBuilder(
 
   } else if (arityGroup == ElementwiseArityGroup::Binary) {
     result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
-                                  block.getArgument(1));
+                                  block.getArgument(1), emitError);
 
   } else if (arityGroup == ElementwiseArityGroup::Ternary) {
     result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
@@ -4942,7 +4943,8 @@ void ElementwiseOp::regionBuilder(
   } else {
     assert(false && "found unhandled category in elemwise");
   }
-
+  if (!result)
+    return;
   yields.push_back(result);
   helper.yieldOutputs(yields);
 }
@@ -6586,9 +6588,13 @@ void BatchReduceMatmulOp::regionBuilder(
       helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
   Value castValB =
       helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
-  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
+  if (!mulVal)
+    return;
   Value addVal =
-      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal, emitError);
+  if (!addVal)
+    return;
   yields.push_back(addVal);
   helper.yieldOutputs(yields);
 }
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 355d801f8732c..32d274f1951bf 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -2128,3 +2128,23 @@ func.func @matmul_invalid_mixed_types(%t: tensor<?xf16>, %f: vector<4xf16>)
                                 outs(%f : vector<4xf16>) -> tensor<?xf16>
   func.return %0, %f : tensor<?xf16>, vector<4xf16>
 }
+
+// -----
+
+func.func @batch_reduce_matmul_invalid_types() {
+  %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %1 = "test.op"() {attr = #test.custom_float<"float" : 2.>} : () -> tensor<1xvector<4xf32>>
+  // expected-warning @unknown {{could not cast operand of type 'i32' to 'vector<4xf32>'}}
+  // expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got 'vector<4xf32>' and 'i32'}}
+  %2 = linalg.batch_reduce_matmul ins(%0, %0 : tensor<1xi32>, tensor<1xi32>) outs(%1 : tensor<1xvector<4xf32>>) -> tensor<1xvector<4xf32>>
+  return
+}
+
+// -----
+
+func.func @elemwise_invalid_types(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> tensor<1xvector<4xi32>> {
+  %0 = "test.op"() {attr = #test.custom_float<"float" : 2.>} : () -> tensor<1xvector<4xi32>>
+  // expected-error @below {{custom op 'linalg.add' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got 'i32' and 'f32'}}
+  %1 = linalg.add ins(%arg0, %arg1 : tensor<4xi32>, tensor<4xf32>) outs(%0 : tensor<1xvector<4xi32>>) -> tensor<1xvector<4xi32>>
+  return %1 : tensor<1xvector<4xi32>>
+}
\ No newline at end of file



More information about the Mlir-commits mailing list