[Mlir-commits] [mlir] [MLIR][Linalg] Diagnose unsupported types in Linalg named op region builders (PR #181616)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 16 02:57:53 PST 2026


https://github.com/shubhamnarlawar updated https://github.com/llvm/llvm-project/pull/181616

>From 2c94e15476b75b7198b4c26c9340f23879afe509 Mon Sep 17 00:00:00 2001
From: Shubham Narlawar <shubham.narlawar at rrlogic.co.in>
Date: Mon, 16 Feb 2026 14:31:30 +0530
Subject: [PATCH] [MLIR][Linalg] Diagnose unsupported types in Linalg named op
 region builders

Plumb emitError callbacks through Linalg named op region builders so
RegionBuilderHelper emits diagnostics instead of hitting llvm_unreachable
for unsupported operand element types (e.g. amx.tile).

Update linalg/invalid.mlir to add test functions that use linalg.batch_matmul,
linalg.batch_reduce_matmul, and linalg.matmul with amx.tile operands, ensuring
that mlir-opt fails gracefully with a diagnostic instead of crashing.

Fixes #179677
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 42 +++++++++++++++++-------
 mlir/test/Dialect/Linalg/invalid.mlir    | 42 ++++++++++++++++++++++++
 2 files changed, 72 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index eba3fa6db2126..921a567dfe538 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3873,9 +3873,11 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   }
 
   Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
-                                    block.getArgument(0));
+                                    block.getArgument(0), emitError);
   Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
-                                    block.getArgument(1));
+                                    block.getArgument(1), emitError);
+  if (!value1 || !value2)
+    return;
   Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
   if (!value3)
     return;
@@ -4646,11 +4648,20 @@ void BatchMatmulOp::regionBuilder(
   }
 
   auto toType = block.getArgument(2).getType();
-  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
-  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
-  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
-  Value addVal =
-      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+  Value castValA =
+      helper.buildTypeFn(castVal, toType, block.getArgument(0), emitError);
+  Value castValB =
+      helper.buildTypeFn(castVal, toType, block.getArgument(1), emitError);
+  if (!castValA || !castValB)
+    return;
+  Value mulVal =
+      helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
+  if (!mulVal)
+    return;
+  Value addVal = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
+                                      mulVal, emitError);
+  if (!addVal)
+    return;
   yields.push_back(addVal);
   helper.yieldOutputs(yields);
 }
@@ -6582,13 +6593,20 @@ void BatchReduceMatmulOp::regionBuilder(
   SmallVector<Value> yields;
 
   auto toType = block.getArgument(2).getType();
-  Value castValA =
-      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
-  Value castValB =
-      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
-  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
+  Value castValA = helper.buildTypeFn(TypeFn::cast_signed, toType,
+                                      block.getArgument(0), emitError);
+  Value castValB = helper.buildTypeFn(TypeFn::cast_signed, toType,
+                                      block.getArgument(1), emitError);
+  if (!castValA || !castValB)
+    return;
+  Value mulVal =
+      helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
+  if (!mulVal)
+    return;
   Value addVal =
       helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
+  if (!addVal)
+    return;
   yields.push_back(addVal);
   helper.yieldOutputs(yields);
 }
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 355d801f8732c..a2d36dca0f0b1 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -2128,3 +2128,45 @@ func.func @matmul_invalid_mixed_types(%t: tensor<?xf16>, %f: vector<4xf16>)
                                 outs(%f : vector<4xf16>) -> tensor<?xf16>
   func.return %0, %f : tensor<?xf16>, vector<4xf16>
 }
+
+// -----
+
+func.func @batch_matmul_invalid_type()
+{
+  %0 = spirv.GroupNonUniformElect <Workgroup> : i1
+  %1 = amx.tile_zero : !amx.tile<16x16xbf16>
+  %2 = tensor.from_elements %0 : tensor<i1>
+  %3 = tosa.reciprocal %2 : (tensor<i1>) -> tensor<i1>
+  %4 = shape.const_shape [16, 16] : !shape.shape
+  // expected-error @below {{custom op 'linalg.batch_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
+  %5 = linalg.batch_matmul ins(%1, %1 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%1 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+  return
+}
+
+// -----
+
+func.func @batch_reduce_matmul_invalid_type()
+{
+  %0 = spirv.GroupNonUniformElect <Workgroup> : i1
+  %1 = amx.tile_zero : !amx.tile<16x16xbf16>
+  %2 = tensor.from_elements %0 : tensor<i1>
+  %3 = tosa.reciprocal %2 : (tensor<i1>) -> tensor<i1>
+  %4 = shape.const_shape [16, 16] : !shape.shape
+  // expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
+  %5 = linalg.batch_reduce_matmul ins(%1, %1 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%1 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+  return
+}
+
+// -----
+
+func.func @matmul_invalid_type()
+{
+  %0 = spirv.GroupNonUniformElect <Workgroup> : i1
+  %1 = amx.tile_zero : !amx.tile<16x16xbf16>
+  %2 = tensor.from_elements %0 : tensor<i1>
+  %3 = tosa.reciprocal %2 : (tensor<i1>) -> tensor<i1>
+  %4 = shape.const_shape [16, 16] : !shape.shape
+  // expected-error @below {{custom op 'linalg.matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}}
+  %5 = linalg.matmul ins(%1, %1 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%1 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16>
+  return
+}



More information about the Mlir-commits mailing list