[llvm] [InstCombine] Avoid folding `select(umin(X, Y), X)` with non-constant mask (PR #143020)

Konstantin Bogdanov via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 5 12:04:13 PDT 2025


https://github.com/thevar1able updated https://github.com/llvm/llvm-project/pull/143020

>From 41a9445853d8b68a9936a4d94bbc48de1d06aca4 Mon Sep 17 00:00:00 2001
From: Konstantin Bogdanov <konstantin at clickhouse.com>
Date: Thu, 5 Jun 2025 20:47:32 +0200
Subject: [PATCH 1/2] Prototype fix

---
 .../InstCombine/InstCombineCalls.cpp          | 19 +++++++++++++++++++
 llvm/test/Transforms/InstCombine/select.ll    | 15 +++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index cfb4af391b540..930819d24393a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1654,6 +1654,25 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   if (Value *FreedOp = getFreedOperand(&CI, &TLI))
     return visitFree(CI, FreedOp);
 
+  if (Function *F = CI.getCalledFunction()) {
+    if (F->getIntrinsicID() == Intrinsic::umin || F->getIntrinsicID() == Intrinsic::umax) {
+      for (Value *Arg : CI.args()) {
+        auto *SI = dyn_cast<SelectInst>(Arg);
+        if (!SI)
+          continue;
+
+        auto *TrueC = dyn_cast<Constant>(SI->getTrueValue());
+        auto *FalseC = dyn_cast<Constant>(SI->getFalseValue());
+
+        // Block only if the select is masking, e.g. select(cond, val, -1)
+        if ((TrueC && TrueC->isAllOnesValue()) || (FalseC && FalseC->isAllOnesValue())) {
+          LLVM_DEBUG(dbgs() << "InstCombine: skipping umin/umax folding for masked select\n");
+          return nullptr;
+        }
+      }
+    }
+  }
+
   // If the caller function (i.e. us, the function that contains this CallInst)
   // is nounwind, mark the call as nounwind, even if the callee isn't.
   if (CI.getFunction()->doesNotThrow() && !CI.doesNotThrow()) {
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index e16f6ad2cfc9b..09cb84cde07ca 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -5047,3 +5047,18 @@ define <2 x ptr> @select_freeze_constant_expression_vector_gep(i1 %cond, <2 x pt
   %sel = select i1 %cond, <2 x ptr> %y, <2 x ptr> %freeze
   ret <2 x ptr> %sel
 }
+
+declare i8 @llvm.umin.i8(i8, i8)
+; Check that InstCombine leaves umin(%acc, select(%cond, %val, -1)) unfolded.
+define i8 @no_fold_masked_min(i8 %acc, i8 %val, i8 %mask) {
+; CHECK-LABEL: @no_fold_masked_min(
+; CHECK-NEXT:    [[COND:%.*]] = icmp eq i8 [[MASK:%.*]], 0
+; CHECK-NEXT:    [[MASKED_VAL:%.*]] = select i1 [[COND]], i8 [[VAL:%.*]], i8 -1
+; CHECK-NEXT:    [[RES:%.*]] = call i8 @llvm.umin.i8(i8 [[ACC:%.*]], i8 [[MASKED_VAL]])
+; CHECK-NEXT:    ret i8 [[RES]]
+;
+  %cond = icmp eq i8 %mask, 0
+  %masked_val = select i1 %cond, i8 %val, i8 -1
+  %res = call i8 @llvm.umin.i8(i8 %acc, i8 %masked_val)
+  ret i8 %res
+}

>From a19889c3952ed70d93aa40e36e38404e531f1ce1 Mon Sep 17 00:00:00 2001
From: Konstantin Bogdanov <konstantin at clickhouse.com>
Date: Thu, 5 Jun 2025 21:03:53 +0200
Subject: [PATCH 2/2] Fix formatting

---
 llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 930819d24393a..0e5c95c7445dd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1655,7 +1655,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     return visitFree(CI, FreedOp);
 
   if (Function *F = CI.getCalledFunction()) {
-    if (F->getIntrinsicID() == Intrinsic::umin || F->getIntrinsicID() == Intrinsic::umax) {
+    if (F->getIntrinsicID() == Intrinsic::umin ||
+        F->getIntrinsicID() == Intrinsic::umax) {
       for (Value *Arg : CI.args()) {
         auto *SI = dyn_cast<SelectInst>(Arg);
         if (!SI)
@@ -1665,8 +1666,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
         auto *FalseC = dyn_cast<Constant>(SI->getFalseValue());
 
         // Block only if the select is masking, e.g. select(cond, val, -1)
-        if ((TrueC && TrueC->isAllOnesValue()) || (FalseC && FalseC->isAllOnesValue())) {
-          LLVM_DEBUG(dbgs() << "InstCombine: skipping umin/umax folding for masked select\n");
+        if ((TrueC && TrueC->isAllOnesValue()) ||
+            (FalseC && FalseC->isAllOnesValue())) {
+          LLVM_DEBUG(
+              dbgs()
+              << "InstCombine: skipping umin/umax folding for masked select\n");
           return nullptr;
         }
       }



More information about the llvm-commits mailing list