[Mlir-commits] [mlir] [MLIR][Linalg] Diagnose unsupported types in Linalg named op region builders (PR #181616)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 16 06:05:53 PST 2026
https://github.com/shubhamnarlawar updated https://github.com/llvm/llvm-project/pull/181616
>From 29958bd0ff3f78be411caa6ff34000ef2f20aa83 Mon Sep 17 00:00:00 2001
From: Shubham Narlawar <shubham.narlawar at rrlogic.co.in>
Date: Mon, 16 Feb 2026 14:31:30 +0530
Subject: [PATCH] [MLIR][Linalg] Diagnose unsupported types in Linalg named op
region builders
Plumb emitError callbacks through Linalg named op region builders so
RegionBuilderHelper emits diagnostics instead of hitting llvm_unreachable
for unsupported operand element types (e.g. amx.tile).
Add tests to linalg/invalid.mlir that apply linalg.batch_matmul, linalg.batch_reduce_matmul,
and linalg.matmul to amx.tile operands, checking that mlir-opt reports a diagnostic instead of crashing.
Fixes #179677
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 38 ++++++++++++++++--------
mlir/test/Dialect/Linalg/invalid.mlir | 30 +++++++++++++++++++
2 files changed, 55 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index eba3fa6db2126..36b7da3d311ef 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3873,11 +3873,14 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
}
Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
- block.getArgument(0));
+ block.getArgument(0), emitError);
Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
- block.getArgument(1));
+ block.getArgument(1), emitError);
+ // A failed cast yields a null Value; stop before using it as an operand.
+ if (!value1 || !value2)
+ return;
Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
if (!value3)
return;
Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
value3, emitError);
@@ -4646,11 +4646,21 @@ void BatchMatmulOp::regionBuilder(
}
auto toType = block.getArgument(2).getType();
- Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
- Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
- Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
- Value addVal =
- helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+ Value castValA =
+ helper.buildTypeFn(castVal, toType, block.getArgument(0), emitError);
+ Value castValB =
+ helper.buildTypeFn(castVal, toType, block.getArgument(1), emitError);
+ // A failed cast yields a null Value; stop before using it as an operand.
+ if (!castValA || !castValB)
+ return;
+ Value mulVal =
+ helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
+ if (!mulVal)
+ return;
+ Value addVal = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+ mulVal, emitError);
+ if (!addVal)
+ return;
yields.push_back(addVal);
helper.yieldOutputs(yields);
}
@@ -6582,13 +6589,21 @@ void BatchReduceMatmulOp::regionBuilder(
SmallVector<Value> yields;
auto toType = block.getArgument(2).getType();
- Value castValA =
- helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
- Value castValB =
- helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
- Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
- Value addVal =
- helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+ Value castValA = helper.buildTypeFn(TypeFn::cast_signed, toType,
+ block.getArgument(0), emitError);
+ Value castValB = helper.buildTypeFn(TypeFn::cast_signed, toType,
+ block.getArgument(1), emitError);
+ // A failed cast yields a null Value; stop before using it as an operand.
+ if (!castValA || !castValB)
+ return;
+ Value mulVal =
+ helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
+ if (!mulVal)
+ return;
+ Value addVal = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+ mulVal, emitError);
+ if (!addVal)
+ return;
yields.push_back(addVal);
helper.yieldOutputs(yields);
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 355d801f8732c..c42bc19055146 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -2128,3 +2128,33 @@ func.func @matmul_invalid_mixed_types(%t: tensor<?xf16>, %f: vector<4xf16>)
outs(%f : vector<4xf16>) -> tensor<?xf16>
func.return %0, %f : tensor<?xf16>, vector<4xf16>
}
+
+// -----
+// linalg.batch_matmul on !amx.tile operands must produce a diagnostic, not a crash (#179677).
+func.func @batch_matmul_invalid_type()
+{
+ %0 = amx.tile_zero : !amx.tile<16x16xbf16>
+ // expected-error @below {{custom op 'linalg.batch_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
+ %1 = linalg.batch_matmul ins(%0, %0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+ return
+}
+
+// -----
+// linalg.batch_reduce_matmul on !amx.tile operands must produce a diagnostic, not a crash (#179677).
+func.func @batch_reduce_matmul_invalid_type()
+{
+ %0 = amx.tile_zero : !amx.tile<16x16xbf16>
+ // expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
+ %1 = linalg.batch_reduce_matmul ins(%0, %0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+ return
+}
+
+// -----
+// linalg.matmul on !amx.tile operands must produce a diagnostic, not a crash (#179677).
+func.func @matmul_invalid_type()
+{
+ %0 = amx.tile_zero : !amx.tile<16x16xbf16>
+ // expected-error @below {{custom op 'linalg.matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
+ %1 = linalg.matmul ins(%0, %0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+ return
+}
More information about the Mlir-commits
mailing list