[llvm] 07fa6d1 - [InstCombine] Avoid folding `select(umin(X, Y), X)` with min/max values in false arm (#143020)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 13 23:32:57 PDT 2025


Author: Konstantin Bogdanov
Date: 2025-06-14T14:32:54+08:00
New Revision: 07fa6d1d90c714fa269529c3e5004a063d814c4a

URL: https://github.com/llvm/llvm-project/commit/07fa6d1d90c714fa269529c3e5004a063d814c4a
DIFF: https://github.com/llvm/llvm-project/commit/07fa6d1d90c714fa269529c3e5004a063d814c4a.diff

LOG: [InstCombine] Avoid folding `select(umin(X, Y), X)` with min/max values in false arm (#143020)

Fixes https://github.com/llvm/llvm-project/issues/139050.

This patch adds a check to avoid folding a min/max reduction into a select, since that fold may block loop vectorization.

The issue is that the following snippet:
```
declare i8 @llvm.umin.i8(i8, i8)

define i8 @masked_min_fold_bug(i8 %acc, i8 %val, i8 %mask) {
; CHECK-LABEL: @masked_min_fold_bug(
; CHECK:       %cond = icmp eq i8 %mask, 0
; CHECK:       %masked_val = select i1 %cond, i8 %val, i8 255
; CHECK:       call i8 @llvm.umin.i8(i8 %acc, i8 %masked_val)
;
  %cond = icmp eq i8 %mask, 0
  %masked_val = select i1 %cond, i8 %val, i8 255
  %res = call i8 @llvm.umin.i8(i8 %acc, i8 %masked_val)
  ret i8 %res
}
```

is being optimized to the following code, which cannot be vectorized
later.
```
declare i8 @llvm.umin.i8(i8, i8) #0

define i8 @masked_min_fold_bug(i8 %acc, i8 %val, i8 %mask) {
  %cond = icmp eq i8 %mask, 0
  %1 = call i8 @llvm.umin.i8(i8 %acc, i8 %val)
  %res = select i1 %cond, i8 %1, i8 %acc
  ret i8 %res
}

attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
```

Expected:
```
declare i8 @llvm.umin.i8(i8, i8) #0

define i8 @masked_min_fold_bug(i8 %acc, i8 %val, i8 %mask) {
  %cond = icmp eq i8 %mask, 0
  %masked_val = select i1 %cond, i8 %val, i8 -1
  %res = call i8 @llvm.umin.i8(i8 %acc, i8 %masked_val)
  ret i8 %res
}

attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
```

https://godbolt.org/z/cYMheKE5r

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
    llvm/test/Transforms/InstCombine/select.ll
    llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 29582939fa06a..4fe900e9421f8 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1739,6 +1739,15 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
   if (SI->getType()->isIntOrIntVectorTy(1))
     return nullptr;
 
+  // Avoid breaking min/max reduction pattern,
+  // which is necessary for vectorization later.
+  if (isa<MinMaxIntrinsic>(&Op))
+    for (Value *IntrinOp : Op.operands())
+      if (auto *PN = dyn_cast<PHINode>(IntrinOp))
+        for (Value *PhiOp : PN->operands())
+          if (PhiOp == &Op)
+            return nullptr;
+
   // Test if a FCmpInst instruction is used exclusively by a select as
   // part of a minimum or maximum operation. If so, refrain from doing
   // any other folding. This helps out other analyses which understand

diff  --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index e16f6ad2cfc9b..ef5874ffd46ad 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -5047,3 +5047,50 @@ define <2 x ptr> @select_freeze_constant_expression_vector_gep(i1 %cond, <2 x pt
   %sel = select i1 %cond, <2 x ptr> %y, <2 x ptr> %freeze
   ret <2 x ptr> %sel
 }
+
+define void @no_fold_masked_min_loop(ptr nocapture readonly %vals, ptr nocapture readonly %masks, ptr nocapture %out, i64 %n) {
+; CHECK-LABEL: @no_fold_masked_min_loop(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[NEXT_INDEX:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[ACC:%.*]] = phi i8 [ -1, [[ENTRY]] ], [ [[RES:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[VAL_PTR:%.*]] = getelementptr inbounds i8, ptr [[VALS:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[MASK_PTR:%.*]] = getelementptr inbounds i8, ptr [[MASKS:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[VAL:%.*]] = load i8, ptr [[VAL_PTR]], align 1
+; CHECK-NEXT:    [[MASK:%.*]] = load i8, ptr [[MASK_PTR]], align 1
+; CHECK-NEXT:    [[COND:%.*]] = icmp eq i8 [[MASK]], 0
+; CHECK-NEXT:    [[MASKED_VAL:%.*]] = select i1 [[COND]], i8 [[VAL]], i8 -1
+; CHECK-NEXT:    [[RES]] = call i8 @llvm.umin.i8(i8 [[ACC]], i8 [[MASKED_VAL]])
+; CHECK-NEXT:    [[NEXT_INDEX]] = add i64 [[INDEX]], 1
+; CHECK-NEXT:    [[DONE:%.*]] = icmp eq i64 [[NEXT_INDEX]], [[N:%.*]]
+; CHECK-NEXT:    br i1 [[DONE]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK:       exit:
+; CHECK-NEXT:    store i8 [[RES]], ptr [[OUT:%.*]], align 1
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop
+
+loop:
+  %index = phi i64 [0, %entry], [%next_index, %loop]
+  %acc = phi i8 [255, %entry], [%res, %loop]
+
+  %val_ptr = getelementptr inbounds i8, ptr %vals, i64 %index
+  %mask_ptr = getelementptr inbounds i8, ptr %masks, i64 %index
+
+  %val = load i8, ptr %val_ptr, align 1
+  %mask = load i8, ptr %mask_ptr, align 1
+
+  %cond = icmp eq i8 %mask, 0
+  %masked_val = select i1 %cond, i8 %val, i8 -1
+  %res = call i8 @llvm.umin.i8(i8 %acc, i8 %masked_val)
+
+  %next_index = add i64 %index, 1
+  %done = icmp eq i64 %next_index, %n
+  br i1 %done, label %exit, label %loop
+
+exit:
+  store i8 %res, ptr %out, align 1
+  ret void
+}

diff  --git a/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll b/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
index f8450766037b2..2ec48a8637dae 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
@@ -326,26 +326,52 @@ cleanup:
   ret i1 %retval.0
 }
 
-; From https://github.com/llvm/llvm-project/issues/139050.
-; FIXME: This should be vectorized.
 define i8 @masked_min_reduction(ptr %data, ptr %mask) {
 ; CHECK-LABEL: @masked_min_reduction(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
-; CHECK:       loop:
+; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[ACC:%.*]] = phi i8 [ -1, [[ENTRY]] ], [ [[TMP21:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <32 x i8> [ splat (i8 -1), [[ENTRY]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI1:%.*]] = phi <32 x i8> [ splat (i8 -1), [[ENTRY]] ], [ [[TMP17:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI2:%.*]] = phi <32 x i8> [ splat (i8 -1), [[ENTRY]] ], [ [[TMP18:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_PHI3:%.*]] = phi <32 x i8> [ splat (i8 -1), [[ENTRY]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[DATA:%.*]] = getelementptr i8, ptr [[DATA1:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[VAL:%.*]] = load i8, ptr [[DATA]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i8, ptr [[DATA]], i64 32
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i8, ptr [[DATA]], i64 64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[DATA]], i64 96
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <32 x i8>, ptr [[DATA]], align 1
+; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <32 x i8>, ptr [[TMP1]], align 1
+; CHECK-NEXT:    [[WIDE_LOAD5:%.*]] = load <32 x i8>, ptr [[TMP2]], align 1
+; CHECK-NEXT:    [[WIDE_LOAD6:%.*]] = load <32 x i8>, ptr [[TMP3]], align 1
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[MASK:%.*]], i64 [[INDEX]]
-; CHECK-NEXT:    [[M:%.*]] = load i8, ptr [[TMP7]], align 1
-; CHECK-NEXT:    [[COND:%.*]] = icmp eq i8 [[M]], 0
-; CHECK-NEXT:    [[TMP0:%.*]] = tail call i8 @llvm.umin.i8(i8 [[ACC]], i8 [[VAL]])
-; CHECK-NEXT:    [[TMP21]] = select i1 [[COND]], i8 [[TMP0]], i8 [[ACC]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw nsw i64 [[INDEX]], 1
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr i8, ptr [[TMP7]], i64 32
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[TMP7]], i64 64
+; CHECK-NEXT:    [[TMP22:%.*]] = getelementptr i8, ptr [[TMP7]], i64 96
+; CHECK-NEXT:    [[WIDE_LOAD7:%.*]] = load <32 x i8>, ptr [[TMP7]], align 1
+; CHECK-NEXT:    [[WIDE_LOAD8:%.*]] = load <32 x i8>, ptr [[TMP5]], align 1
+; CHECK-NEXT:    [[WIDE_LOAD9:%.*]] = load <32 x i8>, ptr [[TMP6]], align 1
+; CHECK-NEXT:    [[WIDE_LOAD10:%.*]] = load <32 x i8>, ptr [[TMP22]], align 1
+; CHECK-NEXT:    [[TMP8:%.*]] = icmp eq <32 x i8> [[WIDE_LOAD7]], zeroinitializer
+; CHECK-NEXT:    [[TMP9:%.*]] = icmp eq <32 x i8> [[WIDE_LOAD8]], zeroinitializer
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq <32 x i8> [[WIDE_LOAD9]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp eq <32 x i8> [[WIDE_LOAD10]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = select <32 x i1> [[TMP8]], <32 x i8> [[WIDE_LOAD]], <32 x i8> splat (i8 -1)
+; CHECK-NEXT:    [[TMP13:%.*]] = select <32 x i1> [[TMP9]], <32 x i8> [[WIDE_LOAD4]], <32 x i8> splat (i8 -1)
+; CHECK-NEXT:    [[TMP14:%.*]] = select <32 x i1> [[TMP10]], <32 x i8> [[WIDE_LOAD5]], <32 x i8> splat (i8 -1)
+; CHECK-NEXT:    [[TMP15:%.*]] = select <32 x i1> [[TMP11]], <32 x i8> [[WIDE_LOAD6]], <32 x i8> splat (i8 -1)
+; CHECK-NEXT:    [[TMP16]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[VEC_PHI]], <32 x i8> [[TMP12]])
+; CHECK-NEXT:    [[TMP17]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[VEC_PHI1]], <32 x i8> [[TMP13]])
+; CHECK-NEXT:    [[TMP18]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[VEC_PHI2]], <32 x i8> [[TMP14]])
+; CHECK-NEXT:    [[TMP19]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[VEC_PHI3]], <32 x i8> [[TMP15]])
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 128
 ; CHECK-NEXT:    [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
-; CHECK-NEXT:    br i1 [[TMP20]], label [[EXIT:%.*]], label [[VECTOR_BODY]]
-; CHECK:       exit:
+; CHECK-NEXT:    br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       middle.block:
+; CHECK-NEXT:    [[RDX_MINMAX:%.*]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[TMP16]], <32 x i8> [[TMP17]])
+; CHECK-NEXT:    [[RDX_MINMAX11:%.*]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[RDX_MINMAX]], <32 x i8> [[TMP18]])
+; CHECK-NEXT:    [[RDX_MINMAX12:%.*]] = tail call <32 x i8> @llvm.umin.v32i8(<32 x i8> [[RDX_MINMAX11]], <32 x i8> [[TMP19]])
+; CHECK-NEXT:    [[TMP21:%.*]] = tail call i8 @llvm.vector.reduce.umin.v32i8(<32 x i8> [[RDX_MINMAX12]])
 ; CHECK-NEXT:    ret i8 [[TMP21]]
 ;
 entry:


        


More information about the llvm-commits mailing list