[llvm] [InstCombine] Select of Symmetric Selects (PR #99245)

Tim Gymnich via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 17 05:59:04 PDT 2024


https://github.com/tgymnich updated https://github.com/llvm/llvm-project/pull/99245

>From d5068974fcac22cc024a5c995a28e1e79a37bc4c Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tgymnich at icloud.com>
Date: Sun, 14 Jul 2024 15:47:11 +0200
Subject: [PATCH 1/4] [InstCombine] Add tests for select of symmetric selects

---
 .../select-of-symmetric-selects.ll            | 131 ++++++++++++++++++
 1 file changed, 131 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/select-of-symmetric-selects.ll

diff --git a/llvm/test/Transforms/InstCombine/select-of-symmetric-selects.ll b/llvm/test/Transforms/InstCombine/select-of-symmetric-selects.ll
new file mode 100644
index 0000000000000..8c54563be6f71
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/select-of-symmetric-selects.ll
@@ -0,0 +1,131 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i32 @select_of_symmetric_selects(i32 %a, i32 %b, i1 %c1, i1 %c2) {
+; CHECK-LABEL: @select_of_symmetric_selects(
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[C1:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[C1]], i32 [[B]], i32 [[A]]
+; CHECK-NEXT:    [[RET:%.*]] = select i1 [[C2:%.*]], i32 [[SEL1]], i32 [[SEL2]]
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %sel1 = select i1 %c1, i32 %a, i32 %b
+  %sel2 = select i1 %c1, i32 %b, i32 %a
+  %ret = select i1 %c2, i32 %sel1, i32 %sel2
+  ret i32 %ret
+}
+
+; Negative test: the two inner selects use different conditions.
+define i32 @select_of_symmetric_selects_negative1(i32 %a, i32 %b, i1 %c1, i1 %c2) {
+; CHECK-LABEL: @select_of_symmetric_selects_negative1(
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[C1:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]
+; CHECK-NEXT:    [[RET:%.*]] = select i1 [[C2:%.*]], i32 [[SEL1]], i32 [[A]]
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %sel1 = select i1 %c1, i32 %a, i32 %b
+  %sel2 = select i1 %c2, i32 %b, i32 %a
+  %ret = select i1 %c2, i32 %sel1, i32 %sel2
+  ret i32 %ret
+}
+
+; Negative test: %sel2 does not mirror %sel1 (%c instead of %a).
+define i32 @select_of_symmetric_selects_negative2(i32 %a, i32 %b, i32 %c, i1 %c1, i1 %c2) {
+; CHECK-LABEL: @select_of_symmetric_selects_negative2(
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[C1:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[C1]], i32 [[B]], i32 [[C:%.*]]
+; CHECK-NEXT:    [[RET:%.*]] = select i1 [[C2:%.*]], i32 [[SEL1]], i32 [[SEL2]]
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %sel1 = select i1 %c1, i32 %a, i32 %b
+  %sel2 = select i1 %c1, i32 %b, i32 %c
+  %ret = select i1 %c2, i32 %sel1, i32 %sel2
+  ret i32 %ret
+}
+
+declare void @use(i32)
+
+; Negative test: %sel2 has an extra use, so it would stay live.
+define i32 @select_of_symmetric_selects_multi_use1(i32 %a, i32 %b, i1 %c1, i1 %c2) {
+; CHECK-LABEL: @select_of_symmetric_selects_multi_use1(
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[C1:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[C1]], i32 [[B]], i32 [[A]]
+; CHECK-NEXT:    call void @use(i32 [[SEL2]])
+; CHECK-NEXT:    [[RET:%.*]] = select i1 [[C2:%.*]], i32 [[SEL1]], i32 [[SEL2]]
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %sel1 = select i1 %c1, i32 %a, i32 %b
+  %sel2 = select i1 %c1, i32 %b, i32 %a
+  call void @use(i32 %sel2)
+  %ret = select i1 %c2, i32 %sel1, i32 %sel2
+  ret i32 %ret
+}
+
+; Negative test: both inner selects have extra uses.
+define i32 @select_of_symmetric_selects_multi_use2(i32 %a, i32 %b, i1 %c1, i1 %c2) {
+; CHECK-LABEL: @select_of_symmetric_selects_multi_use2(
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[C1:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]
+; CHECK-NEXT:    call void @use(i32 [[SEL1]])
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[C1]], i32 [[B]], i32 [[A]]
+; CHECK-NEXT:    call void @use(i32 [[SEL2]])
+; CHECK-NEXT:    [[RET:%.*]] = select i1 [[C2:%.*]], i32 [[SEL1]], i32 [[SEL2]]
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %sel1 = select i1 %c1, i32 %a, i32 %b
+  call void @use(i32 %sel1)
+  %sel2 = select i1 %c1, i32 %b, i32 %a
+  call void @use(i32 %sel2)
+  %ret = select i1 %c2, i32 %sel1, i32 %sel2
+  ret i32 %ret
+}
+
+define i32 @select_of_symmetric_selects_commuted(i32 %a, i32 %b, i1 %c1, i1 %c2) {
+; CHECK-LABEL: @select_of_symmetric_selects_commuted(
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[C1:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[C1]], i32 [[B]], i32 [[A]]
+; CHECK-NEXT:    [[RET:%.*]] = select i1 [[C2:%.*]], i32 [[SEL2]], i32 [[SEL1]]
+; CHECK-NEXT:    ret i32 [[RET]]
+;
+  %sel1 = select i1 %c1, i32 %a, i32 %b
+  %sel2 = select i1 %c1, i32 %b, i32 %a
+  %ret = select i1 %c2, i32 %sel2, i32 %sel1
+  ret i32 %ret
+}
+
+define <4 x i32> @select_of_symmetric_selects_vector1(<4 x i32> %a, <4 x i32> %b, i1 %c1, i1 %c2) {
+; CHECK-LABEL: @select_of_symmetric_selects_vector1(
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[C1:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[C1]], <4 x i32> [[B]], <4 x i32> [[A]]
+; CHECK-NEXT:    [[RET:%.*]] = select i1 [[C2:%.*]], <4 x i32> [[SEL2]], <4 x i32> [[SEL1]]
+; CHECK-NEXT:    ret <4 x i32> [[RET]]
+;
+  %sel1 = select i1 %c1, <4 x i32> %a, <4 x i32> %b
+  %sel2 = select i1 %c1, <4 x i32> %b, <4 x i32> %a
+  %ret = select i1 %c2, <4 x i32> %sel2, <4 x i32> %sel1
+  ret <4 x i32> %ret
+}
+
+define <4 x i32> @select_of_symmetric_selects_vector2(<4 x i32> %a, <4 x i32> %b, <4 x i1> %c1, <4 x i1> %c2) {
+; CHECK-LABEL: @select_of_symmetric_selects_vector2(
+; CHECK-NEXT:    [[SEL1:%.*]] = select <4 x i1> [[C1:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]
+; CHECK-NEXT:    [[SEL2:%.*]] = select <4 x i1> [[C1]], <4 x i32> [[B]], <4 x i32> [[A]]
+; CHECK-NEXT:    [[RET:%.*]] = select <4 x i1> [[C2:%.*]], <4 x i32> [[SEL2]], <4 x i32> [[SEL1]]
+; CHECK-NEXT:    ret <4 x i32> [[RET]]
+;
+  %sel1 = select <4 x i1> %c1, <4 x i32> %a, <4 x i32> %b
+  %sel2 = select <4 x i1> %c1, <4 x i32> %b, <4 x i32> %a
+  %ret = select <4 x i1> %c2, <4 x i32> %sel2, <4 x i32> %sel1
+  ret <4 x i32> %ret
+}
+
+; Negative test: the i1 outer condition and the <2 x i1> inner condition differ in type.
+define <2 x i32> @select_of_symmetric_selects_vector3(<2 x i32> %a, <2 x i32> %b, <2 x i1> %c1, i1 %c2) {
+; CHECK-LABEL: @select_of_symmetric_selects_vector3(
+; CHECK-NEXT:    [[SEL1:%.*]] = select <2 x i1> [[C1:%.*]], <2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]]
+; CHECK-NEXT:    [[SEL2:%.*]] = select <2 x i1> [[C1]], <2 x i32> [[B]], <2 x i32> [[A]]
+; CHECK-NEXT:    [[RET:%.*]] = select i1 [[C2:%.*]], <2 x i32> [[SEL1]], <2 x i32> [[SEL2]]
+; CHECK-NEXT:    ret <2 x i32> [[RET]]
+;
+  %sel1 = select <2 x i1> %c1, <2 x i32> %a, <2 x i32> %b
+  %sel2 = select <2 x i1> %c1, <2 x i32> %b, <2 x i32> %a
+  %ret = select i1 %c2, <2 x i32> %sel1, <2 x i32> %sel2
+  ret <2 x i32> %ret
+}

>From a53ff1c6f60781a8cc34d4635f8ae4ee7ab74adb Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tgymnich at icloud.com>
Date: Sun, 14 Jul 2024 15:33:31 +0200
Subject: [PATCH 2/4] [InstCombine] Add fold for select of symmetric selects

Alive2 proofs:
https://alive2.llvm.org/ce/z/4QAm4K
https://alive2.llvm.org/ce/z/vTVRnC
---
 .../InstCombine/InstCombineSelect.cpp         | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 394dfca262e13..6a7db8cc61397 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3012,6 +3012,41 @@ struct DecomposedSelect {
 };
 } // namespace
 
+/// Folds patterns like:
+///   select c2 (select c1 a b) (select c1 b a)
+/// into:
+///   select (xor c1 c2) b a
+///
+/// If c2 is true the result is (select c1 a b), i.e. (select !c1 b a); if c2
+/// is false it is (select c1 b a). Both equal (select (xor c1 c2) b a). The
+/// commuted outer select is matched by the same pattern with a and b swapped.
+///
+/// Both inner selects must have one use, so that they become dead and the fold
+/// does not add instructions. c1 and c2 must also have the same type to be
+/// xor'ed; they can differ because a select of vectors may use either a scalar
+/// i1 or a vector-of-i1 condition.
+static Instruction *
+foldSelectOfSymmetricSelect(SelectInst &OuterSelVal,
+                            InstCombiner::BuilderTy &Builder) {
+
+  DecomposedSelect OuterSel, InnerSel;
+  if (!match(&OuterSelVal,
+             m_Select(m_Value(OuterSel.Cond),
+                      m_OneUse(m_Select(m_Value(InnerSel.Cond),
+                                        m_Value(InnerSel.TrueVal),
+                                        m_Value(InnerSel.FalseVal))),
+                      m_OneUse(m_Select(m_Deferred(InnerSel.Cond),
+                                        m_Deferred(InnerSel.FalseVal),
+                                        m_Deferred(InnerSel.TrueVal))))))
+    return nullptr;
+
+  if (OuterSel.Cond->getType() != InnerSel.Cond->getType())
+    return nullptr;
+
+  Value *Xor = Builder.CreateXor(InnerSel.Cond, OuterSel.Cond);
+  return SelectInst::Create(Xor, InnerSel.FalseVal, InnerSel.TrueVal);
+}
+
 /// Look for patterns like
 ///   %outer.cond = select i1 %inner.cond, i1 %alt.cond, i1 false
 ///   %inner.sel = select i1 %inner.cond, i8 %inner.sel.t, i8 %inner.sel.f
@@ -3987,6 +4022,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     }
   }
 
+  if (Instruction *I = foldSelectOfSymmetricSelect(SI, Builder))
+    return I;
+
   if (Instruction *I = foldNestedSelects(SI, Builder))
     return I;
 

>From 94abd42dde225f2009d3fb0eecb49cb49a208cad Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tgymnich at icloud.com>
Date: Sun, 14 Jul 2024 15:47:11 +0200
Subject: [PATCH 3/4] [InstCombine] Update test checks for select of symmetric selects

---
 .../select-of-symmetric-selects.ll            | 20 ++++++++-----------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/llvm/test/Transforms/InstCombine/select-of-symmetric-selects.ll b/llvm/test/Transforms/InstCombine/select-of-symmetric-selects.ll
index 8c54563be6f71..0936f58ac9443 100644
--- a/llvm/test/Transforms/InstCombine/select-of-symmetric-selects.ll
+++ b/llvm/test/Transforms/InstCombine/select-of-symmetric-selects.ll
@@ -3,9 +3,8 @@
 
 define i32 @select_of_symmetric_selects(i32 %a, i32 %b, i1 %c1, i1 %c2) {
 ; CHECK-LABEL: @select_of_symmetric_selects(
-; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[C1:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]
-; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[C1]], i32 [[B]], i32 [[A]]
-; CHECK-NEXT:    [[RET:%.*]] = select i1 [[C2:%.*]], i32 [[SEL1]], i32 [[SEL2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i1 [[C1:%.*]], [[C2:%.*]]
+; CHECK-NEXT:    [[RET:%.*]] = select i1 [[TMP1]], i32 [[B:%.*]], i32 [[A:%.*]]
 ; CHECK-NEXT:    ret i32 [[RET]]
 ;
   %sel1 = select i1 %c1, i32 %a, i32 %b
@@ -75,9 +74,8 @@ define i32 @select_of_symmetric_selects_multi_use2(i32 %a, i32 %b, i1 %c1, i1 %c
 
 define i32 @select_of_symmetric_selects_commuted(i32 %a, i32 %b, i1 %c1, i1 %c2) {
 ; CHECK-LABEL: @select_of_symmetric_selects_commuted(
-; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[C1:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]
-; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[C1]], i32 [[B]], i32 [[A]]
-; CHECK-NEXT:    [[RET:%.*]] = select i1 [[C2:%.*]], i32 [[SEL2]], i32 [[SEL1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i1 [[C1:%.*]], [[C2:%.*]]
+; CHECK-NEXT:    [[RET:%.*]] = select i1 [[TMP1]], i32 [[A:%.*]], i32 [[B:%.*]]
 ; CHECK-NEXT:    ret i32 [[RET]]
 ;
   %sel1 = select i1 %c1, i32 %a, i32 %b
@@ -88,9 +86,8 @@ define i32 @select_of_symmetric_selects_commuted(i32 %a, i32 %b, i1 %c1, i1 %c2)
 
 define <4 x i32> @select_of_symmetric_selects_vector1(<4 x i32> %a, <4 x i32> %b, i1 %c1, i1 %c2) {
 ; CHECK-LABEL: @select_of_symmetric_selects_vector1(
-; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[C1:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]
-; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[C1]], <4 x i32> [[B]], <4 x i32> [[A]]
-; CHECK-NEXT:    [[RET:%.*]] = select i1 [[C2:%.*]], <4 x i32> [[SEL2]], <4 x i32> [[SEL1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i1 [[C1:%.*]], [[C2:%.*]]
+; CHECK-NEXT:    [[RET:%.*]] = select i1 [[TMP1]], <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]
 ; CHECK-NEXT:    ret <4 x i32> [[RET]]
 ;
   %sel1 = select i1 %c1, <4 x i32> %a, <4 x i32> %b
@@ -101,9 +98,8 @@ define <4 x i32> @select_of_symmetric_selects_vector1(<4 x i32> %a, <4 x i32> %b
 
 define <4 x i32> @select_of_symmetric_selects_vector2(<4 x i32> %a, <4 x i32> %b, <4 x i1> %c1, <4 x i1> %c2) {
 ; CHECK-LABEL: @select_of_symmetric_selects_vector2(
-; CHECK-NEXT:    [[SEL1:%.*]] = select <4 x i1> [[C1:%.*]], <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]
-; CHECK-NEXT:    [[SEL2:%.*]] = select <4 x i1> [[C1]], <4 x i32> [[B]], <4 x i32> [[A]]
-; CHECK-NEXT:    [[RET:%.*]] = select <4 x i1> [[C2:%.*]], <4 x i32> [[SEL2]], <4 x i32> [[SEL1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor <4 x i1> [[C1:%.*]], [[C2:%.*]]
+; CHECK-NEXT:    [[RET:%.*]] = select <4 x i1> [[TMP1]], <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]
 ; CHECK-NEXT:    ret <4 x i32> [[RET]]
 ;
   %sel1 = select <4 x i1> %c1, <4 x i32> %a, <4 x i32> %b

>From 714fad2df424dbbf6861ec4b30c16a595e6acd97 Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tgymnich at icloud.com>
Date: Wed, 17 Jul 2024 14:58:50 +0200
Subject: [PATCH 4/4] [InstCombine] Use plain Values instead of DecomposedSelect in foldSelectOfSymmetricSelect

---
 .../InstCombine/InstCombineSelect.cpp         | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 6a7db8cc61397..e387034110df9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3020,22 +3020,22 @@ static Instruction *
 foldSelectOfSymmetricSelect(SelectInst &OuterSelVal,
                             InstCombiner::BuilderTy &Builder) {
 
-  DecomposedSelect OuterSel, InnerSel;
-  if (!match(&OuterSelVal,
-             m_Select(m_Value(OuterSel.Cond),
-                      m_OneUse(m_Select(m_Value(InnerSel.Cond),
-                                        m_Value(InnerSel.TrueVal),
-                                        m_Value(InnerSel.FalseVal))),
-                      m_OneUse(m_Select(m_Deferred(InnerSel.Cond),
-                                        m_Deferred(InnerSel.FalseVal),
-                                        m_Deferred(InnerSel.TrueVal))))))
+  Value *OuterCond, *InnerCond, *InnerTrueVal, *InnerFalseVal;
+  if (!match(
+          &OuterSelVal,
+          m_Select(m_Value(OuterCond),
+                   m_OneUse(m_Select(m_Value(InnerCond), m_Value(InnerTrueVal),
+                                     m_Value(InnerFalseVal))),
+                   m_OneUse(m_Select(m_Deferred(InnerCond),
+                                     m_Deferred(InnerFalseVal),
+                                     m_Deferred(InnerTrueVal))))))
     return nullptr;
 
-  if (OuterSel.Cond->getType() != InnerSel.Cond->getType())
+  if (OuterCond->getType() != InnerCond->getType())
     return nullptr;
 
-  Value *Xor = Builder.CreateXor(InnerSel.Cond, OuterSel.Cond);
-  return SelectInst::Create(Xor, InnerSel.FalseVal, InnerSel.TrueVal);
+  Value *Xor = Builder.CreateXor(InnerCond, OuterCond);
+  return SelectInst::Create(Xor, InnerFalseVal, InnerTrueVal);
 }
 
 /// Look for patterns like



More information about the llvm-commits mailing list