[Mlir-commits] [mlir] [mlir][linalg] Fix linalg.select crash with index type operands (PR #179056)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 31 09:54:17 PST 2026
https://github.com/mugiwaraluffy56 updated https://github.com/llvm/llvm-project/pull/179056
From 9f43f458510c9630a0bd48e12afd2cae667df79e Mon Sep 17 00:00:00 2001
From: mugiwaraluffy56 <myakampuneeth at gmail.com>
Date: Sat, 31 Jan 2026 23:08:50 +0530
Subject: [PATCH] [mlir][linalg] Fix linalg.select crash with index type
operands
The buildTernaryFn function in RegionBuilderHelper only accepted integer
and floating-point operand types, so building the region of a
linalg.select whose operands are of index type hit llvm_unreachable
("unsupported non numeric type").
This patch adds support for index types by:
1. Adding an isIndex() helper to RegionBuilderHelper.
2. Computing a tailIndex flag in buildTernaryFn (true when all three
   operands are of index type).
3. Accepting tailIndex in the TernaryFn::select operand-type validation.
Fixes #179046
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 +++-
mlir/test/Dialect/Linalg/named-ops.mlir | 13 +++++++++++++
2 files changed, 16 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0f0e308bba78e..e26c08f4cd6ac 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -617,11 +617,12 @@ class RegionBuilderHelper {
bool tailFloatingPoint =
isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
+ bool tailIndex = isIndex(arg0) && isIndex(arg1) && isIndex(arg2);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
switch (ternaryFn) {
case TernaryFn::select:
- if (!headBool && !(tailFloatingPoint || tailInteger))
+ if (!headBool && !(tailFloatingPoint || tailInteger || tailIndex))
llvm_unreachable("unsupported non numeric type");
return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
}
@@ -705,6 +706,7 @@ class RegionBuilderHelper {
bool isInteger(Value value) {
return llvm::isa<IntegerType>(value.getType());
}
+ bool isIndex(Value value) { return llvm::isa<IndexType>(value.getType()); }
OpBuilder &builder;
Block &block;
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 1e356c8fb4e72..ee4128570d3c0 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -2718,6 +2718,19 @@ func.func @select_tensor(%arg0: tensor<4x8x16xi1>, %arg1: tensor<4x8x16xf32>, %a
return %1 : tensor<4x8x16xf32>
}
+// -----
+
+// GH#179046: Test linalg.select with index type values (condition must be i1).
+// CHECK-LABEL: func @select_index
+func.func @select_index(%arg0: tensor<4x8x16xi1>, %arg1: tensor<4x8x16xindex>, %arg2: tensor<4x8x16xindex>) -> tensor<4x8x16xindex> {
+ %0 = tensor.empty() : tensor<4x8x16xindex>
+ // CHECK: linalg.select
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<4x8x16xi1>, tensor<4x8x16xindex>, tensor<4x8x16xindex>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xindex>)
+ %1 = linalg.select ins(%arg0, %arg1, %arg2 : tensor<4x8x16xi1>, tensor<4x8x16xindex>, tensor<4x8x16xindex>) outs(%0: tensor<4x8x16xindex>) -> tensor<4x8x16xindex>
+ return %1 : tensor<4x8x16xindex>
+}
+
//===----------------------------------------------------------------------===//
// linalg.pack + linalg.unpack
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list