[Mlir-commits] [mlir] [MLIR][Linalg] Diagnose unsupported types in Linalg named op region builders (PR #181616)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 16 02:57:53 PST 2026
https://github.com/shubhamnarlawar updated https://github.com/llvm/llvm-project/pull/181616
>From 2c94e15476b75b7198b4c26c9340f23879afe509 Mon Sep 17 00:00:00 2001
From: Shubham Narlawar <shubham.narlawar at rrlogic.co.in>
Date: Mon, 16 Feb 2026 14:31:30 +0530
Subject: [PATCH] [MLIR][Linalg] Diagnose unsupported types in Linalg named op
region builders
Plumb emitError callbacks through Linalg named op region builders so
RegionBuilderHelper emits diagnostics instead of hitting llvm_unreachable
for unsupported operand element types (e.g. amx.tile).
Add test functions to linalg/invalid.mlir that use linalg.batch_matmul, linalg.batch_reduce_matmul,
and linalg.matmul with amx.tile operands, to ensure mlir-opt fails gracefully instead of crashing.
Fixes #179677
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 42 +++++++++++++++++-------
mlir/test/Dialect/Linalg/invalid.mlir | 42 ++++++++++++++++++++++++
2 files changed, 72 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index eba3fa6db2126..921a567dfe538 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3873,9 +3873,11 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
}
Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
- block.getArgument(0));
+ block.getArgument(0), emitError);
Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
- block.getArgument(1));
+ block.getArgument(1), emitError);
+ if (!value1 || !value2)
+ return;
Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
if (!value3)
return;
@@ -4646,11 +4648,20 @@ void BatchMatmulOp::regionBuilder(
}
auto toType = block.getArgument(2).getType();
- Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
- Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
- Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
- Value addVal =
- helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+ Value castValA =
+ helper.buildTypeFn(castVal, toType, block.getArgument(0), emitError);
+ Value castValB =
+ helper.buildTypeFn(castVal, toType, block.getArgument(1), emitError);
+ if (!castValA || !castValB)
+ return;
+ Value mulVal =
+ helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
+ if (!mulVal)
+ return;
+ Value addVal = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+ mulVal, emitError);
+ if (!addVal)
+ return;
yields.push_back(addVal);
helper.yieldOutputs(yields);
}
@@ -6582,13 +6593,20 @@ void BatchReduceMatmulOp::regionBuilder(
SmallVector<Value> yields;
auto toType = block.getArgument(2).getType();
- Value castValA =
- helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
- Value castValB =
- helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
- Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+ Value castValA = helper.buildTypeFn(TypeFn::cast_signed, toType,
+ block.getArgument(0), emitError);
+ Value castValB = helper.buildTypeFn(TypeFn::cast_signed, toType,
+ block.getArgument(1), emitError);
+ if (!castValA || !castValB)
+ return;
+ Value mulVal =
+ helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
+ if (!mulVal)
+ return;
Value addVal =
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+ if (!addVal)
+ return;
yields.push_back(addVal);
helper.yieldOutputs(yields);
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 355d801f8732c..a2d36dca0f0b1 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -2128,3 +2128,45 @@ func.func @matmul_invalid_mixed_types(%t: tensor<?xf16>, %f: vector<4xf16>)
outs(%f : vector<4xf16>) -> tensor<?xf16>
func.return %0, %f : tensor<?xf16>, vector<4xf16>
}
+
+// -----
+
+func.func @batch_matmul_invalid_type()
+{
+ %0 = spirv.GroupNonUniformElect <Workgroup> : i1
+ %1 = amx.tile_zero : !amx.tile<16x16xbf16>
+ %2 = tensor.from_elements %0 : tensor<i1>
+ %3 = tosa.reciprocal %2 : (tensor<i1>) -> tensor<i1>
+ %4 = shape.const_shape [16, 16] : !shape.shape
+ // expected-error @below {{custom op 'linalg.batch_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
+ %5 = linalg.batch_matmul ins(%1, %1 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%1 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+ return
+}
+
+// -----
+
+func.func @batch_reduce_matmul_invalid_type()
+{
+ %0 = spirv.GroupNonUniformElect <Workgroup> : i1
+ %1 = amx.tile_zero : !amx.tile<16x16xbf16>
+ %2 = tensor.from_elements %0 : tensor<i1>
+ %3 = tosa.reciprocal %2 : (tensor<i1>) -> tensor<i1>
+ %4 = shape.const_shape [16, 16] : !shape.shape
+ // expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
+ %5 = linalg.batch_reduce_matmul ins(%1, %1 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%1 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+ return
+}
+
+// -----
+
+func.func @matmul_invalid_type()
+{
+ %0 = spirv.GroupNonUniformElect <Workgroup> : i1
+ %1 = amx.tile_zero : !amx.tile<16x16xbf16>
+ %2 = tensor.from_elements %0 : tensor<i1>
+ %3 = tosa.reciprocal %2 : (tensor<i1>) -> tensor<i1>
+ %4 = shape.const_shape [16, 16] : !shape.shape
+ // expected-error @below {{custom op 'linalg.matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
+ %5 = linalg.matmul ins(%1, %1 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%1 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+ return
+}
More information about the Mlir-commits
mailing list