[llvm] [InstCombine] Add fold for select of symmetric selects (PR #98813)

Tim Gymnich via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 14 06:37:35 PDT 2024


https://github.com/tgymnich created https://github.com/llvm/llvm-project/pull/98813

Fixes #98800.

>From 12b5833be0c4bad2fee19b6db58b16b4c83faf96 Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tgymnich at icloud.com>
Date: Sun, 14 Jul 2024 15:33:31 +0200
Subject: [PATCH] [InstCombine] Add fold for select of symmetric selects

---
 .../InstCombine/InstCombineSelect.cpp         | 33 +++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 394dfca262e13..1bcbc3fdf884f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3012,6 +3012,36 @@ struct DecomposedSelect {
 };
 } // namespace
 
+/// Folds patterns like:
+///   select c2 (select c1 a b) (select c1 b a)
+/// into:
+///   select (xor c1 c2) b a
+/// Requires single-use inner selects (so no instructions are added) and
+/// c1/c2 of the same type (a scalar i1 cannot be xor'ed with a vector).
+static Instruction *
+foldSelectOfSymmetricSelect(SelectInst &OuterSelVal,
+                            InstCombiner::BuilderTy &Builder) {
+  Value *OuterCond, *InnerCond, *InnerTrueVal, *InnerFalseVal;
+  // The second inner select must reuse the first one's condition with the
+  // true/false operands swapped; m_Deferred enforces that while matching.
+  if (!match(&OuterSelVal,
+             m_Select(m_Value(OuterCond),
+                      m_OneUse(m_Select(m_Value(InnerCond),
+                                        m_Value(InnerTrueVal),
+                                        m_Value(InnerFalseVal))),
+                      m_OneUse(m_Select(m_Deferred(InnerCond),
+                                        m_Deferred(InnerFalseVal),
+                                        m_Deferred(InnerTrueVal))))))
+    return nullptr;
+
+  if (OuterCond->getType() != InnerCond->getType())
+    return nullptr;
+
+  // c2 ? (c1 ? a : b) : (c1 ? b : a)  ==  (c1 ^ c2) ? b : a
+  Value *Xor = Builder.CreateXor(InnerCond, OuterCond);
+  return SelectInst::Create(Xor, InnerFalseVal, InnerTrueVal);
+}
+
 /// Look for patterns like
 ///   %outer.cond = select i1 %inner.cond, i1 %alt.cond, i1 false
 ///   %inner.sel = select i1 %inner.cond, i8 %inner.sel.t, i8 %inner.sel.f
@@ -3987,6 +4017,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     }
   }
 
+  if (Instruction *I = foldSelectOfSymmetricSelect(SI, Builder))
+    return I;
+
   if (Instruction *I = foldNestedSelects(SI, Builder))
     return I;
 



More information about the llvm-commits mailing list