[llvm] [InstCombine] Simplify nested selects with implied condition (PR #83739)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 3 21:34:14 PST 2024


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/83739

From 2abbc125b2d8d9fedf9961eeef853792dd7aadd1 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 4 Mar 2024 03:05:51 +0800
Subject: [PATCH 1/3] [InstCombine] Add pre-commit tests. NFC.

---
 .../Transforms/InstCombine/nested-select.ll   | 92 +++++++++++++++++++
 1 file changed, 92 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/nested-select.ll b/llvm/test/Transforms/InstCombine/nested-select.ll
index 42a0f81e7b85a2..a7cdee1483b46a 100644
--- a/llvm/test/Transforms/InstCombine/nested-select.ll
+++ b/llvm/test/Transforms/InstCombine/nested-select.ll
@@ -498,3 +498,95 @@ define i1 @orcond.111.inv.all.conds(i1 %inner.cond, i1 %alt.cond, i1 %inner.sel.
   %outer.sel = select i1 %not.outer.cond, i1 true, i1 %inner.sel
   ret i1 %outer.sel
 }
+
+define i8 @test_implied_true(i8 %x) {
+; CHECK-LABEL: @test_implied_true(
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i8 [[X:%.*]], 10
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X]], 0
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[SEL1]], i8 20
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp slt i8 %x, 10
+  %cmp2 = icmp slt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 %sel1, i8 20 ; true arm: x < 0 implies x < 10, so %sel1 is always 0 here
+  ret i8 %sel2
+}
+
+define <2 x i8> @test_implied_true_vec(<2 x i8> %x) {
+; CHECK-LABEL: @test_implied_true_vec(
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt <2 x i8> [[X:%.*]], <i8 10, i8 10>
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt <2 x i8> [[X]], zeroinitializer
+; CHECK-NEXT:    [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
+; CHECK-NEXT:    [[SEL2:%.*]] = select <2 x i1> [[CMP2]], <2 x i8> [[SEL1]], <2 x i8> <i8 20, i8 20>
+; CHECK-NEXT:    ret <2 x i8> [[SEL2]]
+;
+  %cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10>
+  %cmp2 = icmp slt <2 x i8> %x, zeroinitializer
+  %sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
+  %sel2 = select <2 x i1> %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20> ; per lane: x < 0 implies x < 10, so %sel1 is zero in the true arm
+  ret <2 x i8> %sel2
+}
+
+define i8 @test_implied_true_falseval(i8 %x) {
+; CHECK-LABEL: @test_implied_true_falseval(
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i8 [[X:%.*]], 10
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i8 [[X]], 0
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 20, i8 [[SEL1]]
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp slt i8 %x, 10
+  %cmp2 = icmp sgt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 20, i8 %sel1 ; false arm: x <= 0 implies x < 10, so %sel1 is always 0 here
+  ret i8 %sel2
+}
+
+define i8 @test_implied_false(i8 %x) {
+; CHECK-LABEL: @test_implied_false(
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i8 [[X:%.*]], 10
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X]], 0
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[SEL1]], i8 20
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp sgt i8 %x, 10
+  %cmp2 = icmp slt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 %sel1, i8 20 ; true arm: x < 0 implies !(x > 10), so %sel1 is always 5 here
+  ret i8 %sel2
+}
+
+; Negative tests
+
+define i8 @test_imply_fail(i8 %x) {
+; CHECK-LABEL: @test_imply_fail(
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i8 [[X:%.*]], -10
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X]], 0
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[SEL1]], i8 20
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp slt i8 %x, -10
+  %cmp2 = icmp slt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 %sel1, i8 20 ; x < 0 decides neither x < -10 nor its negation, so %sel1 must stay
+  ret i8 %sel2
+}
+
+define <2 x i8> @test_imply_type_mismatch(<2 x i8> %x, i8 %y) {
+; CHECK-LABEL: @test_imply_type_mismatch(
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt <2 x i8> [[X:%.*]], <i8 10, i8 10>
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[Y:%.*]], 0
+; CHECK-NEXT:    [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], <2 x i8> [[SEL1]], <2 x i8> <i8 20, i8 20>
+; CHECK-NEXT:    ret <2 x i8> [[SEL2]]
+;
+  %cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10>
+  %cmp2 = icmp slt i8 %y, 0
+  %sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
+  %sel2 = select i1 %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20> ; scalar i1 outer condition vs <2 x i1> inner condition: the fold must bail out
+  ret <2 x i8> %sel2
+}

From 5bfe453202935d883c302045f2ba987a3b318f8b Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 4 Mar 2024 03:25:48 +0800
Subject: [PATCH 2/3] [InstCombine] Simplify nested selects with implied
 condition

---
 .../InstCombine/InstCombineSelect.cpp         | 17 +++++++++++++
 ...etween-negative-and-positive-thresholds.ll | 12 ++++------
 .../Transforms/InstCombine/nested-select.ll   | 24 +++++++------------
 3 files changed, 29 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 71fa9b9ba41ebb..dc4347fdd713c6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3867,5 +3867,22 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     }
   }
 
+  // Fold nested selects if the inner condition can be implied by the outer
+  // condition.
+  Value *InnerCondVal;
+  const DataLayout &DL = getDataLayout();
+  if (match(TrueVal,
+            m_Select(m_Value(InnerCondVal), m_Value(LHS), m_Value(RHS))) &&
+      CondVal->getType() == InnerCondVal->getType())
+    if (auto Implied =
+            isImpliedCondition(CondVal, InnerCondVal, DL, /*LHSIsTrue=*/true))
+      return replaceOperand(SI, 1, *Implied ? LHS : RHS);
+  if (match(FalseVal,
+            m_Select(m_Value(InnerCondVal), m_Value(LHS), m_Value(RHS))) &&
+      CondVal->getType() == InnerCondVal->getType())
+    if (auto Implied =
+            isImpliedCondition(CondVal, InnerCondVal, DL, /*LHSIsTrue=*/false))
+      return replaceOperand(SI, 2, *Implied ? LHS : RHS);
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll b/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll
index d03e22bc4c9fbf..b5ef1f466958d7 100644
--- a/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll
+++ b/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll
@@ -189,10 +189,8 @@ define i32 @n9_ult_slt_neg17(i32 %x, i32 %replacement_low, i32 %replacement_high
 ; Regression test for PR53252.
 define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
 ; CHECK-LABEL: @n10_ugt_slt(
-; CHECK-NEXT:    [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0
-; CHECK-NEXT:    [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = icmp ugt i32 [[X]], 128
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[T1]]
+; CHECK-NEXT:    [[T2:%.*]] = icmp ugt i32 [[X:%.*]], 128
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[REPLACEMENT_HIGH:%.*]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %t0 = icmp slt i32 %x, 0
@@ -204,10 +202,8 @@ define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
 
 define i32 @n11_uge_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
 ; CHECK-LABEL: @n11_uge_slt(
-; CHECK-NEXT:    [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0
-; CHECK-NEXT:    [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = icmp ult i32 [[X]], 129
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[T1]], i32 [[X]]
+; CHECK-NEXT:    [[T2:%.*]] = icmp ult i32 [[X:%.*]], 129
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[REPLACEMENT_HIGH:%.*]], i32 [[X]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %t0 = icmp slt i32 %x, 0
diff --git a/llvm/test/Transforms/InstCombine/nested-select.ll b/llvm/test/Transforms/InstCombine/nested-select.ll
index a7cdee1483b46a..d4bbf0ae48590a 100644
--- a/llvm/test/Transforms/InstCombine/nested-select.ll
+++ b/llvm/test/Transforms/InstCombine/nested-select.ll
@@ -501,10 +501,8 @@ define i1 @orcond.111.inv.all.conds(i1 %inner.cond, i1 %alt.cond, i1 %inner.sel.
 
 define i8 @test_implied_true(i8 %x) {
 ; CHECK-LABEL: @test_implied_true(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i8 [[X:%.*]], 10
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X]], 0
-; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5
-; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[SEL1]], i8 20
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 0, i8 20
 ; CHECK-NEXT:    ret i8 [[SEL2]]
 ;
   %cmp1 = icmp slt i8 %x, 10
@@ -516,10 +514,8 @@ define i8 @test_implied_true(i8 %x) {
 
 define <2 x i8> @test_implied_true_vec(<2 x i8> %x) {
 ; CHECK-LABEL: @test_implied_true_vec(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt <2 x i8> [[X:%.*]], <i8 10, i8 10>
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt <2 x i8> [[X]], zeroinitializer
-; CHECK-NEXT:    [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
-; CHECK-NEXT:    [[SEL2:%.*]] = select <2 x i1> [[CMP2]], <2 x i8> [[SEL1]], <2 x i8> <i8 20, i8 20>
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt <2 x i8> [[X:%.*]], zeroinitializer
+; CHECK-NEXT:    [[SEL2:%.*]] = select <2 x i1> [[CMP2]], <2 x i8> zeroinitializer, <2 x i8> <i8 20, i8 20>
 ; CHECK-NEXT:    ret <2 x i8> [[SEL2]]
 ;
   %cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10>
@@ -531,10 +527,8 @@ define <2 x i8> @test_implied_true_vec(<2 x i8> %x) {
 
 define i8 @test_implied_true_falseval(i8 %x) {
 ; CHECK-LABEL: @test_implied_true_falseval(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i8 [[X:%.*]], 10
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i8 [[X]], 0
-; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5
-; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 20, i8 [[SEL1]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 20, i8 0
 ; CHECK-NEXT:    ret i8 [[SEL2]]
 ;
   %cmp1 = icmp slt i8 %x, 10
@@ -546,10 +540,8 @@ define i8 @test_implied_true_falseval(i8 %x) {
 
 define i8 @test_implied_false(i8 %x) {
 ; CHECK-LABEL: @test_implied_false(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i8 [[X:%.*]], 10
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X]], 0
-; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5
-; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[SEL1]], i8 20
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 5, i8 20
 ; CHECK-NEXT:    ret i8 [[SEL2]]
 ;
   %cmp1 = icmp sgt i8 %x, 10

From 3a3fa92de1aecc79926c9b183da004879125380b Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 4 Mar 2024 13:19:31 +0800
Subject: [PATCH 3/3] fixup! [InstCombine] Simplify nested selects with implied
 condition

---
 .../InstCombine/InstCombineSelect.cpp         | 29 ++++++++++---------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index dc4347fdd713c6..85d0cecf7cea02 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3869,20 +3869,23 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
 
   // Fold nested selects if the inner condition can be implied by the outer
   // condition.
-  Value *InnerCondVal;
-  const DataLayout &DL = getDataLayout();
-  if (match(TrueVal,
-            m_Select(m_Value(InnerCondVal), m_Value(LHS), m_Value(RHS))) &&
-      CondVal->getType() == InnerCondVal->getType())
-    if (auto Implied =
-            isImpliedCondition(CondVal, InnerCondVal, DL, /*LHSIsTrue=*/true))
-      return replaceOperand(SI, 1, *Implied ? LHS : RHS);
-  if (match(FalseVal,
-            m_Select(m_Value(InnerCondVal), m_Value(LHS), m_Value(RHS))) &&
-      CondVal->getType() == InnerCondVal->getType())
+  auto SimplifyNestedSelect = [&](Value *Arm, bool CondIsTrue) -> Value * {
+    Value *InnerCondVal, *LHS, *RHS;
+    if (!match(Arm,
+               m_Select(m_Value(InnerCondVal), m_Value(LHS), m_Value(RHS))))
+      return nullptr;
+    if (CondVal->getType() != InnerCondVal->getType())
+      return nullptr;
+    const DataLayout &DL = getDataLayout();
     if (auto Implied =
-            isImpliedCondition(CondVal, InnerCondVal, DL, /*LHSIsTrue=*/false))
-      return replaceOperand(SI, 2, *Implied ? LHS : RHS);
+            isImpliedCondition(CondVal, InnerCondVal, DL, CondIsTrue))
+      return *Implied ? LHS : RHS;
+    return nullptr;
+  };
+  if (Value *V = SimplifyNestedSelect(TrueVal, /*CondIsTrue=*/true))
+    return replaceOperand(SI, 1, V);
+  if (Value *V = SimplifyNestedSelect(FalseVal, /*CondIsTrue=*/false))
+    return replaceOperand(SI, 2, V);
 
   return nullptr;
 }



More information about the llvm-commits mailing list