[Mlir-commits] [mlir] [mlir][linalg] Fix linalg.select crash with index type operands (PR #179056)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 31 09:45:03 PST 2026
https://github.com/mugiwaraluffy56 updated https://github.com/llvm/llvm-project/pull/179056
>From a8f21eb642377f75d33d2eb5e5b9c9d67cb32329 Mon Sep 17 00:00:00 2001
From: mugiwaraluffy56 <myakampuneeth at gmail.com>
Date: Sat, 31 Jan 2026 23:08:50 +0530
Subject: [PATCH] [mlir][linalg] Fix linalg.select crash with index type operands
The buildTernaryFn function in RegionBuilderHelper only handled integer
and floating-point types, so building the region of a linalg.select with
index-typed operands hit an llvm_unreachable and crashed.
This patch adds support for index types by:
1. Adding an isIndex() helper function.
2. Adding a tailIndex check in buildTernaryFn.
3. Including tailIndex in the select type-validation condition.
Fixes #179046
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 +++-
mlir/test/Dialect/Linalg/named-ops.mlir | 13 +++++++++++++
2 files changed, 16 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0f0e308bba78e..e26c08f4cd6ac 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -617,11 +617,12 @@ class RegionBuilderHelper {
bool tailFloatingPoint =
isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
+ bool tailIndex = isIndex(arg0) && isIndex(arg1) && isIndex(arg2);
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToEnd(&block);
switch (ternaryFn) {
case TernaryFn::select:
- if (!headBool && !(tailFloatingPoint || tailInteger))
+ if (!headBool && !(tailFloatingPoint || tailInteger || tailIndex))
llvm_unreachable("unsupported non numeric type");
return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
}
@@ -705,6 +706,7 @@ class RegionBuilderHelper {
bool isInteger(Value value) {
return llvm::isa<IntegerType>(value.getType());
}
+ bool isIndex(Value value) { return llvm::isa<IndexType>(value.getType()); }
OpBuilder &builder;
Block &block;
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 1e356c8fb4e72..a84519f31d368 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -2718,6 +2718,19 @@ func.func @select_tensor(%arg0: tensor<4x8x16xi1>, %arg1: tensor<4x8x16xf32>, %a
return %1 : tensor<4x8x16xf32>
}
+// -----
+
+// GH#179046: linalg.select with an i1 condition and index-typed true/false values.
+// CHECK-LABEL: func @select_index
+func.func @select_index(%arg0: tensor<4x8x16xi1>, %arg1: tensor<4x8x16xindex>, %arg2: tensor<4x8x16xindex>) -> tensor<4x8x16xindex> {
+ %0 = tensor.empty() : tensor<4x8x16xindex>
+ // CHECK: linalg.select
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<4x8x16xi1>, tensor<4x8x16xindex>, tensor<4x8x16xindex>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xindex>)
+ %1 = linalg.select ins(%arg0, %arg1, %arg2 : tensor<4x8x16xi1>, tensor<4x8x16xindex>, tensor<4x8x16xindex>) outs(%0: tensor<4x8x16xindex>) -> tensor<4x8x16xindex>
+ return %1 : tensor<4x8x16xindex>
+}
+
//===----------------------------------------------------------------------===//
// linalg.pack + linalg.unpack
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list