[Mlir-commits] [mlir] [mlir][linalg] Emit proper diagnostic instead of crashing in SelectOp with index type (PR #183652)

Mehdi Amini llvmlistbot at llvm.org
Fri Mar 6 11:09:26 PST 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/183652

>From ab66cae15c03e88e5f7c29931108209abb5d1650 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Feb 2026 16:34:03 -0800
Subject: [PATCH] [mlir][linalg] Emit proper diagnostic instead of crashing in
 SelectOp with index type

`buildTernaryFn` for `TernaryFn::select` called `llvm_unreachable` when
the operand types were not `i1`, integer, or floating-point (e.g., `index`
type). The `emitError` callback, which is provided during parsing, was
ignored in this branch despite being used consistently in all other error
paths in the same function.

Remove the operand-type pre-check (and its unconditional `llvm_unreachable`)
from the `TernaryFn::select` case and build `arith.select` directly. Invalid
operand types are then rejected by `arith.select`'s own verifier, which reports
e.g. "operand #0 must be bool-like, but got 'index'". This converts a hard
crash during `linalg.select` parsing with `index` type operands into a proper
error message.

Fixes #179046
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp     |  7 -----
 mlir/test/Dialect/Linalg/named-ops-fail.mlir | 33 ++++++++++++++++++++
 mlir/test/Dialect/Linalg/named-ops.mlir      | 14 +++++++++
 3 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 37b549a7fcd7f..ad2909f656eea 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -619,17 +619,10 @@ class RegionBuilderHelper {
   // Build the ternary functions defined by OpDSL.
   Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
                        function_ref<InFlightDiagnostic()> emitError = {}) {
-    bool headBool =
-        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
-    bool tailFloatingPoint =
-        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
-    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToEnd(&block);
     switch (ternaryFn) {
     case TernaryFn::select:
-      if (!headBool && !(tailFloatingPoint || tailInteger))
-        llvm_unreachable("unsupported non numeric type");
       return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
     }
     if (emitError) {
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 316ed8014ab7f..bf9c1b705f157 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -349,3 +349,36 @@ func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<
   linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
   return
 }
+
+// -----
+
+// linalg.select with all-integer operands
+func.func @select_all_integer(%arg0: memref<4x8x16xi32>, %arg1: memref<4x8x16xi32>, %arg2: memref<4x8x16xi32>, %arg3: memref<4x8x16xi32>) {
+  // CHECK: op operand #0 must be bool-like, but got 'i32'
+  linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi32>, memref<4x8x16xi32>, memref<4x8x16xi32>) outs(%arg3: memref<4x8x16xi32>)
+  return
+}
+
+// -----
+
+// Regression test: linalg.select with index type operands should emit a
+// diagnostic instead of crashing (https://github.com/llvm/llvm-project/issues/179046).
+func.func @select_invalid_index_type(%cond: index, %a: index, %b: index,
+                                     %out: tensor<1xindex>) -> tensor<1xindex> {
+  // CHECK: op operand #0 must be bool-like, but got 'index'
+  %0 = linalg.select ins(%cond, %a, %b : index, index, index)
+                     outs(%out : tensor<1xindex>) -> tensor<1xindex>
+  return %0 : tensor<1xindex>
+}
+
+// -----
+
+// linalg.select with an integer (non-i1) condition and floating-point values.
+func.func @select_invalid_integer_cond_float_values(%cond: tensor<4xi32>,
+    %a: tensor<4xf32>, %b: tensor<4xf32>,
+    %out: tensor<4xf32>) -> tensor<4xf32> {
+// CHECK: op operand #0 must be bool-like, but got 'i32'
+  %0 = linalg.select ins(%cond, %a, %b : tensor<4xi32>, tensor<4xf32>, tensor<4xf32>)
+                     outs(%out : tensor<4xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 1e356c8fb4e72..8068c23a4a0fd 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -2718,6 +2718,20 @@ func.func @select_tensor(%arg0: tensor<4x8x16xi1>, %arg1: tensor<4x8x16xf32>, %a
   return %1 : tensor<4x8x16xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @select_integer_values
+// linalg.select with an i1 condition and i32 values is valid: only the
+// condition operand of arith.select is required to be bool-like.
+func.func @select_integer_values(%arg0: tensor<4x8x16xi1>, %arg1: tensor<4x8x16xi32>, %arg2: tensor<4x8x16xi32>) -> tensor<4x8x16xi32> {
+  %0 = tensor.empty() : tensor<4x8x16xi32>
+  // CHECK: linalg.select
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<4x8x16xi1>, tensor<4x8x16xi32>, tensor<4x8x16xi32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xi32>)
+  %1 = linalg.select ins(%arg0, %arg1, %arg2 : tensor<4x8x16xi1>, tensor<4x8x16xi32>, tensor<4x8x16xi32>) outs(%0: tensor<4x8x16xi32>) -> tensor<4x8x16xi32>
+  return %1 : tensor<4x8x16xi32>
+}
+
 //===----------------------------------------------------------------------===//
 // linalg.pack + linalg.unpack
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list