[Mlir-commits] [mlir] [mlir][Linalg] Fix crash in buildBinaryFn on non-numeric types (PR #180594)
Guilherme Oliveira de Campos
llvmlistbot at llvm.org
Mon Feb 16 10:57:35 PST 2026
https://github.com/guiolidc updated https://github.com/llvm/llvm-project/pull/180594
>From ba3110ac160b7b445c70be21ec7c24f1d89ee095 Mon Sep 17 00:00:00 2001
From: Guilherme Oliveira de Campos <oliveira.gui at hotmail.com.br>
Date: Mon, 9 Feb 2026 16:08:31 -0300
Subject: [PATCH] [mlir][Linalg] Fix crash in buildBinaryFn on non-numeric
types
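This replaces the llvm_unreachable in RegionBuilderHelper::buildBinaryFn with an error diagnostic, so building a binary Linalg region for non-numeric operand types (such as AMX tiles) no longer crashes. Instead, the helper returns a null Value, and the updated region builders (BatchMatmulOp, ElementwiseOp, BatchReduceMatmulOp) bail out when they see it.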
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 34 ++++++++++++++----------
mlir/test/Dialect/Linalg/invalid.mlir | 20 ++++++++++++++
2 files changed, 40 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index eba3fa6db2126..9e4dcbcf0dcd4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -506,14 +506,11 @@ class RegionBuilderHelper {
bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
arg1.getType().getIntOrFloatBitWidth() == 1;
if (!allComplex && !allFloatingPoint && !allInteger) {
- if (emitError) {
- emitError()
- << "Cannot build binary Linalg operation: expects allComplex, "
- "allFloatingPoint, or allInteger, got "
- << arg0.getType() << " and " << arg1.getType();
- return nullptr;
- }
- llvm_unreachable("unsupported non numeric type");
+ auto diag = emitError ? emitError() : mlir::emitError(arg0.getLoc());
+ diag << "Cannot build binary Linalg operation: expects allComplex, "
+ << "allFloatingPoint, or allInteger, got " << arg0.getType()
+ << " and " << arg1.getType();
+ return {};
}
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
@@ -4648,9 +4645,13 @@ void BatchMatmulOp::regionBuilder(
auto toType = block.getArgument(2).getType();
Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
- Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+ Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
+ if (!mulVal)
+ return;
Value addVal =
- helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal, emitError);
+ if (!addVal)
+ return;
yields.push_back(addVal);
helper.yieldOutputs(yields);
}
@@ -4933,7 +4934,7 @@ void ElementwiseOp::regionBuilder(
} else if (arityGroup == ElementwiseArityGroup::Binary) {
result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
- block.getArgument(1));
+ block.getArgument(1), emitError);
} else if (arityGroup == ElementwiseArityGroup::Ternary) {
result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
@@ -4942,7 +4943,8 @@ void ElementwiseOp::regionBuilder(
} else {
assert(false && "found unhandled category in elemwise");
}
-
+ if (!result)
+ return;
yields.push_back(result);
helper.yieldOutputs(yields);
}
@@ -6586,9 +6588,13 @@ void BatchReduceMatmulOp::regionBuilder(
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
Value castValB =
helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
- Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+ Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
+ if (!mulVal)
+ return;
Value addVal =
- helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+ helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal, emitError);
+ if (!addVal)
+ return;
yields.push_back(addVal);
helper.yieldOutputs(yields);
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 355d801f8732c..32d274f1951bf 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -2128,3 +2128,23 @@ func.func @matmul_invalid_mixed_types(%t: tensor<?xf16>, %f: vector<4xf16>)
outs(%f : vector<4xf16>) -> tensor<?xf16>
func.return %0, %f : tensor<?xf16>, vector<4xf16>
}
+
+// -----
+
+func.func @batch_reduce_matmul_invalid_types() {
+ %0 = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %1 = "test.op"() {attr = #test.custom_float<"float" : 2.>} : () -> tensor<1xvector<4xf32>>
+ // expected-warning @unknown {{could not cast operand of type 'i32' to 'vector<4xf32>'}}
+ // expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got 'vector<4xf32>' and 'i32'}}
+ %2 = linalg.batch_reduce_matmul ins(%0, %0 : tensor<1xi32>, tensor<1xi32>) outs(%1 : tensor<1xvector<4xf32>>) -> tensor<1xvector<4xf32>>
+ return
+}
+
+// -----
+
+func.func @elemwise_invalid_types(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> tensor<4xi32> {
+ %0 = "test.op"() {attr = #test.custom_float<"float" : 2.>} : () -> tensor<1xvector<4xi32>>
+ // expected-error @below {{custom op 'linalg.add' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got 'i32' and 'f32'}}
+ %1 = linalg.add ins(%arg0, %arg1 : tensor<4xi32>, tensor<4xf32>) outs(%0 : tensor<1xvector<4xi32>>) -> tensor<1xvector<4xi32>>
+ return %1
+}
\ No newline at end of file
More information about the Mlir-commits mailing list