[llvm] [InstCombine] Fold copysign of selects from sign comparison to sign operand (PR #85627)
Krishna Narayanan via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 18 03:33:38 PDT 2024
https://github.com/Krishna-13-cyber created https://github.com/llvm/llvm-project/pull/85627
This is still under development, and I would appreciate review and guidance on how to approach the problem.
This PR aims to resolve issue #64884.
Blocker:
I pattern-matched what I believe is the most likely form of the input IR, but the instructions are not folded and the tests fail.
Any pointers on where I am going wrong or what I could improve?
>From 96403f5361e5331a35fe8d158859e01bf208a22f Mon Sep 17 00:00:00 2001
From: Krishna-13-cyber <krishnanarayanan132002 at gmail.com>
Date: Mon, 18 Mar 2024 14:12:15 +0530
Subject: [PATCH] Add support for folding sign comparison to sign operand
---
.../InstCombine/InstCombineSelect.cpp | 45 +++++++++++++++++++
llvm/test/Transforms/InstCombine/fcmp.ll | 16 +++++++
2 files changed, 61 insertions(+)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ee76a6294428b3..49e10646eb5d15 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2790,6 +2790,47 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
return ChangedFMF ? &SI : nullptr;
}
+// Fold the sign operand of a copysign whose sign comes from an fcmp:
+//   %cmp = fcmp olt float %x, 0.0
+//   %and = and i1 %cmp, %c
+//   %sel = select i1 %and, float NegC, float PosC
+//   %r   = call float @llvm.copysign.f32(float %mag, float %sel)
+// -->
+//   %sel = select i1 %c, float %x, float PosC
+// where NegC has its sign bit set and PosC does not. Only the sign bit is
+// preserved, so %sel's sole user must be copysign's sign operand, and %x
+// must never be NaN or -0.0 (the fcmp is false for those, yet their sign
+// bit may be set).
+static Instruction *foldSelectWithFCmp(SelectInst &SI, InstCombinerImpl &IC) {
+  if (!SI.hasOneUse() ||
+      !match(SI.user_back(), m_Intrinsic<Intrinsic::copysign>(
+                                 m_Value(), m_Specific(&SI))))
+    return nullptr;
+
+  Value *X, *C;
+  CmpInst::Predicate Pred;
+  const APFloat *NegC, *PosC;
+  // m_c_And: the fcmp may be either operand of the 'and'.
+  if (!match(&SI, m_Select(m_c_And(m_FCmp(Pred, m_Value(X), m_AnyZeroFP()),
+                                   m_Value(C)),
+                           m_APFloat(NegC), m_APFloat(PosC))))
+    return nullptr;
+  // 'ult' equals 'olt' here because NaN values of %x are rejected below.
+  if (Pred != FCmpInst::FCMP_OLT && Pred != FCmpInst::FCMP_ULT)
+    return nullptr;
+  // NegC must carry the sign bit and PosC must not; %x must also have the
+  // select's type (a scalar i1 condition can select between vectors).
+  if (!NegC->isNegative() || PosC->isNegative() ||
+      X->getType() != SI.getType())
+    return nullptr;
+
+  KnownFPClass Known = IC.computeKnownFPClass(X, fcNan | fcNegZero, &SI);
+  if (!Known.isKnownNeverNaN() || !Known.isKnownNeverNegZero())
+    return nullptr;
+  // Reuse the false-arm constant, so that arm's sign bit is unchanged.
+  return SelectInst::Create(C, X, SI.getFalseValue());
+}
+
// Match the following IR pattern:
// %x.lowbits = and i8 %x, %lowbitmask
// %x.lowbits.are.zero = icmp eq i8 %x.lowbits, 0
@@ -3508,6 +3549,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Instruction *Fabs = foldSelectWithFCmpToFabs(SI, *this))
return Fabs;
+ // Fold selecting to ffold.
+ if (Instruction *Ffold = foldSelectWithFCmp(SI, *this))
+ return Ffold;
+
// See if we are selecting two values based on a comparison of the two values.
if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal))
if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index 159c84d0dd8aa9..574fb1f9417742 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -1284,3 +1284,19 @@ define <1 x i1> @bitcast_1vec_eq0(i32 %x) {
%cmp = fcmp oeq <1 x float> %f, zeroinitializer
ret <1 x i1> %cmp
}
+
+; The fcmp/and/select must stay: when %1 is -0.0 (or a NaN with its sign bit
+; set) the fcmp is false, so the sign operand is +1.0, whereas rewriting it
+; to 'select i1 %0, float %1, float 1.0' would hand copysign %1's set sign bit.
+define float @copysign_conditional(i1 noundef zeroext %0, float %1, float %2) {
+; CHECK-LABEL: define float @copysign_conditional(
+; CHECK-SAME: i1 noundef zeroext [[TMP0:%.*]], float [[TMP1:%.*]], float [[TMP2:%.*]]) {
+; CHECK-NEXT: [[TMP4:%.*]] = fcmp olt float [[TMP1]], 0.000000e+00
+; CHECK-NEXT: [[TMP5:%.*]] = and i1 [[TMP4]], [[TMP0]]
+; CHECK-NEXT: [[TMP6:%.*]] = select i1 [[TMP5]], float -1.000000e+00, float 1.000000e+00
+; CHECK-NEXT: [[TMP7:%.*]] = tail call float @llvm.copysign.f32(float [[TMP2]], float [[TMP6]])
+; CHECK-NEXT: ret float [[TMP7]]
+;
+ %4 = fcmp olt float %1, 0.000000e+00
+ %5 = and i1 %4, %0
+ %6 = select i1 %5, float -1.000000e+00, float 1.000000e+00
+ %7 = tail call float @llvm.copysign.f32(float %2, float %6)
+ ret float %7
+}
\ No newline at end of file
More information about the llvm-commits
mailing list