[llvm] [InstCombine] Transform (fcmp + fadd + sel) into (fcmp + sel + fadd) (PR #106492)
Rajat Bajpai via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 28 22:30:05 PDT 2024
https://github.com/rajatbajpai created https://github.com/llvm/llvm-project/pull/106492
Transform `fcmp + fadd + sel` into `fcmp + sel + fadd` which enables the possibility of lowering `fcmp + sel` into `fmax/fmin`.
>From c88ba465ac07288156efabc3ff0acba2920ab610 Mon Sep 17 00:00:00 2001
From: rbajpai <rbajpai at nvidia.com>
Date: Wed, 21 Aug 2024 14:48:08 +0530
Subject: [PATCH] [InstCombine] Transform (fcmp + fadd + sel) into (fcmp + sel
+ fadd)
Transform `fcmp + fadd + sel` into `fcmp + sel + fadd` which enables
the possibility of lowering `fcmp + sel` into `fmax/fmin`.
---
.../InstCombine/InstCombineSelect.cpp | 45 ++++
.../InstCombine/fcmp-fadd-select.ll | 245 ++++++++++++++++++
2 files changed, 290 insertions(+)
create mode 100644 llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index fcd11126073bf1..17f1b3a1ec24ae 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3668,6 +3668,47 @@ static bool hasAffectedValue(Value *V, SmallPtrSetImpl<Value *> &Affected,
return false;
}
+/// Sink a select below an fadd when the other select arm is the fadd's
+/// constant operand and the condition compares the fadd's variable operand
+/// against zero:
+///
+///   select (fcmp ogt/olt X, 0), (fadd X, C), C
+///     => fadd (select (fcmp ogt/olt X, 0), X, 0), C
+///   select (fcmp ogt/olt X, 0), C, (fadd X, C)
+///     => fadd (select (fcmp ogt/olt X, 0), 0, X), C
+///
+/// This exposes an fcmp + select pair that may later be turned into
+/// fmax/fmin. Returns the new fadd, or nullptr if \p SI does not match.
+static Value *foldSelectAddConstant(SelectInst &SI,
+                                    InstCombiner::BuilderTy &Builder) {
+  Value *Cmp;
+  Instruction *FAdd;
+  ConstantFP *C;
+
+  // OneUse check for `Cmp` is necessary because it makes sure that other
+  // InstCombine folds don't undo this transformation and cause an infinite
+  // loop.
+  bool FAddIsTrueArm = true;
+  if (!match(&SI, m_Select(m_OneUse(m_Value(Cmp)),
+                           m_OneUse(m_Instruction(FAdd)), m_ConstantFP(C)))) {
+    if (!match(&SI, m_Select(m_OneUse(m_Value(Cmp)), m_ConstantFP(C),
+                             m_OneUse(m_Instruction(FAdd)))))
+      return nullptr;
+    FAddIsTrueArm = false;
+  }
+
+  // The condition must be `fcmp ogt/olt X, +/-0.0`.
+  Value *X;
+  CmpInst::Predicate Pred;
+  if (!match(Cmp, m_FCmp(Pred, m_Value(X), m_AnyZeroFP())))
+    return nullptr;
+
+  if (Pred != CmpInst::FCMP_OGT && Pred != CmpInst::FCMP_OLT)
+    return nullptr;
+
+  // The fadd must add the select's constant arm to the compared value.
+  if (!match(FAdd, m_FAdd(m_Specific(X), m_Specific(C))))
+    return nullptr;
+
+  // The new select carries the union of the fadd's and the original select's
+  // fast-math flags.
+  // NOTE(review): a union (rather than an intersection) of the two flag sets
+  // is kept from the original patch; confirm it is sound for every flag.
+  FastMathFlags FMF = FAdd->getFastMathFlags();
+  FMF |= SI.getFastMathFlags();
+
+  // Keep each arm on the side it had in the original select: X takes the
+  // place of the fadd and zero takes the place of C. Building
+  // `select Cmp, X, 0` unconditionally would invert the result when the fadd
+  // is the false arm.
+  Value *Zero = ConstantFP::getZero(C->getType());
+  Value *NewSelect = Builder.CreateSelect(Cmp, FAddIsTrueArm ? X : Zero,
+                                          FAddIsTrueArm ? Zero : X,
+                                          SI.getName() + ".new", &SI);
+  if (auto *NewSelectInst = dyn_cast<Instruction>(NewSelect))
+    NewSelectInst->setFastMathFlags(FMF);
+
+  return Builder.CreateFAddFMF(NewSelect, C, FAdd, FAdd->getName() + ".new");
+}
+
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
@@ -4067,6 +4108,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Value *V = foldRoundUpIntegerWithPow2Alignment(SI, Builder))
return replaceInstUsesWith(SI, V);
+ if (Value *V = foldSelectAddConstant(SI, Builder)) {
+ return replaceInstUsesWith(SI, V);
+ }
+
// select(mask, mload(,,mask,0), 0) -> mload(,,mask,0)
// Load inst is intentionally not checked for hasOneUse()
if (match(FalseVal, m_Zero()) &&
diff --git a/llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll b/llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll
new file mode 100644
index 00000000000000..fced2d961b2415
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fcmp-fadd-select.ll
@@ -0,0 +1,245 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; Check for the fcmp + sel pattern which is later lowered into fmax
+; ogt compare of %in against 0.0; the fadd (constant on the RHS) is the true
+; arm of the select and the same constant is the false arm. The CHECK lines
+; expect the select to be moved in front of the fadd.
+define float @test_fmax_pos1(float %in) {
+; CHECK-LABEL: define float @test_fmax_pos1(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp ogt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1]], float [[IN]], float 0.000000e+00
+; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd float [[SEL_NEW]], 1.000000e+00
+; CHECK-NEXT: ret float [[ADD_NEW]]
+;
+ %cmp1 = fcmp ogt float %in, 0.000000e+00
+ %add = fadd float %in, 1.000000e+00
+ %sel = select i1 %cmp1, float %add, float 1.000000e+00
+ ret float %sel
+}
+
+; ogt compare of %in against 0.0 with the select arms swapped: the constant is
+; the true arm and the fadd is the false arm.
+; NOTE(review): the input IR returns 1.0 when %in > 0.0, but the CHECK lines
+; return %in + 1.0 in that case (select picks [[IN]] on true). The expected
+; select arms look inverted relative to the input; confirm before relying on
+; these assertions.
+define float @test_fmax_pos2(float %in) {
+; CHECK-LABEL: define float @test_fmax_pos2(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp ogt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1]], float [[IN]], float 0.000000e+00
+; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd float [[SEL_NEW]], 1.000000e+00
+; CHECK-NEXT: ret float [[ADD_NEW]]
+;
+ %cmp1 = fcmp ogt float %in, 0.000000e+00
+ %add = fadd float %in, 1.000000e+00
+ %sel = select i1 %cmp1, float 1.000000e+00, float %add
+ ret float %sel
+}
+
+; ogt compare of %in against 0.0; the fadd is written with the constant as its
+; first operand and is the true arm of the select. The CHECK lines expect the
+; select to be moved in front of the fadd.
+define float @test_fmax_pos3(float %in) {
+; CHECK-LABEL: define float @test_fmax_pos3(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp ogt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1]], float [[IN]], float 0.000000e+00
+; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd float [[SEL_NEW]], 1.000000e+00
+; CHECK-NEXT: ret float [[ADD_NEW]]
+;
+ %cmp1 = fcmp ogt float %in, 0.000000e+00
+ %add = fadd float 1.000000e+00, %in
+ %sel = select i1 %cmp1, float %add, float 1.000000e+00
+ ret float %sel
+}
+
+; ogt compare of %in against 0.0; the fadd is written with the constant as its
+; first operand, and the select arms are swapped (constant on true, fadd on
+; false).
+; NOTE(review): the input IR returns 1.0 when %in > 0.0, but the CHECK lines
+; return %in + 1.0 in that case (select picks [[IN]] on true). The expected
+; select arms look inverted relative to the input; confirm before relying on
+; these assertions.
+define float @test_fmax_pos4(float %in) {
+; CHECK-LABEL: define float @test_fmax_pos4(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp ogt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1]], float [[IN]], float 0.000000e+00
+; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd float [[SEL_NEW]], 1.000000e+00
+; CHECK-NEXT: ret float [[ADD_NEW]]
+;
+ %cmp1 = fcmp ogt float %in, 0.000000e+00
+ %add = fadd float 1.000000e+00, %in
+ %sel = select i1 %cmp1, float 1.000000e+00, float %add
+ ret float %sel
+}
+
+; ogt compare of %in against 0.0 with the constant 2.0: the fadd has the
+; constant as its first operand, and the select arms are swapped (constant on
+; true, fadd on false).
+; NOTE(review): the input IR returns 2.0 when %in > 0.0, but the CHECK lines
+; return %in + 2.0 in that case (select picks [[IN]] on true). The expected
+; select arms look inverted relative to the input; confirm before relying on
+; these assertions.
+define float @test_fmax_pos5(float %in) {
+; CHECK-LABEL: define float @test_fmax_pos5(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp ogt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1]], float [[IN]], float 0.000000e+00
+; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd float [[SEL_NEW]], 2.000000e+00
+; CHECK-NEXT: ret float [[ADD_NEW]]
+;
+ %cmp1 = fcmp ogt float %in, 0.000000e+00
+ %add = fadd float 2.000000e+00, %in
+ %sel = select i1 %cmp1, float 2.000000e+00, float %add
+ ret float %sel
+}
+
+
+; Check for the fcmp + sel pattern which is later lowered into fmin
+; olt compare of %in against 0.0; the fadd (constant on the RHS) is the true
+; arm of the select and the same constant is the false arm. The CHECK lines
+; expect the select to be moved in front of the fadd.
+define float @test_fmin_pos1(float %in) {
+; CHECK-LABEL: define float @test_fmin_pos1(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp olt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1]], float [[IN]], float 0.000000e+00
+; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd float [[SEL_NEW]], 1.000000e+00
+; CHECK-NEXT: ret float [[ADD_NEW]]
+;
+ %cmp1 = fcmp olt float %in, 0.000000e+00
+ %add = fadd float %in, 1.000000e+00
+ %sel = select i1 %cmp1, float %add, float 1.000000e+00
+ ret float %sel
+}
+
+; olt compare of %in against 0.0 with the select arms swapped: the constant is
+; the true arm and the fadd is the false arm.
+; NOTE(review): the input IR returns 1.0 when %in < 0.0, but the CHECK lines
+; return %in + 1.0 in that case (select picks [[IN]] on true). The expected
+; select arms look inverted relative to the input; confirm before relying on
+; these assertions.
+define float @test_fmin_pos2(float %in) {
+; CHECK-LABEL: define float @test_fmin_pos2(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp olt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1]], float [[IN]], float 0.000000e+00
+; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd float [[SEL_NEW]], 1.000000e+00
+; CHECK-NEXT: ret float [[ADD_NEW]]
+;
+ %cmp1 = fcmp olt float %in, 0.000000e+00
+ %add = fadd float %in, 1.000000e+00
+ %sel = select i1 %cmp1, float 1.000000e+00, float %add
+ ret float %sel
+}
+
+; olt compare of %in against 0.0; the fadd is written with the constant as its
+; first operand and is the true arm of the select. The CHECK lines expect the
+; select to be moved in front of the fadd.
+define float @test_fmin_pos3(float %in) {
+; CHECK-LABEL: define float @test_fmin_pos3(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp olt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1]], float [[IN]], float 0.000000e+00
+; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd float [[SEL_NEW]], 1.000000e+00
+; CHECK-NEXT: ret float [[ADD_NEW]]
+;
+ %cmp1 = fcmp olt float %in, 0.000000e+00
+ %add = fadd float 1.000000e+00, %in
+ %sel = select i1 %cmp1, float %add, float 1.000000e+00
+ ret float %sel
+}
+
+; olt compare of %in against 0.0; the fadd is written with the constant as its
+; first operand, and the select arms are swapped (constant on true, fadd on
+; false).
+; NOTE(review): the input IR returns 1.0 when %in < 0.0, but the CHECK lines
+; return %in + 1.0 in that case (select picks [[IN]] on true). The expected
+; select arms look inverted relative to the input; confirm before relying on
+; these assertions.
+define float @test_fmin_pos4(float %in) {
+; CHECK-LABEL: define float @test_fmin_pos4(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp olt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1]], float [[IN]], float 0.000000e+00
+; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd float [[SEL_NEW]], 1.000000e+00
+; CHECK-NEXT: ret float [[ADD_NEW]]
+;
+ %cmp1 = fcmp olt float %in, 0.000000e+00
+ %add = fadd float 1.000000e+00, %in
+ %sel = select i1 %cmp1, float 1.000000e+00, float %add
+ ret float %sel
+}
+
+; olt compare of %in against 0.0 with the constant 2.0: the fadd has the
+; constant as its first operand, and the select arms are swapped (constant on
+; true, fadd on false).
+; NOTE(review): the input IR returns 2.0 when %in < 0.0, but the CHECK lines
+; return %in + 2.0 in that case (select picks [[IN]] on true). The expected
+; select arms look inverted relative to the input; confirm before relying on
+; these assertions.
+define float @test_fmin_pos5(float %in) {
+; CHECK-LABEL: define float @test_fmin_pos5(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp olt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[SEL_NEW:%.*]] = select i1 [[CMP1]], float [[IN]], float 0.000000e+00
+; CHECK-NEXT: [[ADD_NEW:%.*]] = fadd float [[SEL_NEW]], 2.000000e+00
+; CHECK-NEXT: ret float [[ADD_NEW]]
+;
+ %cmp1 = fcmp olt float %in, 0.000000e+00
+ %add = fadd float 2.000000e+00, %in
+ %sel = select i1 %cmp1, float 2.000000e+00, float %add
+ ret float %sel
+}
+
+
+; Check for fmax scenarios that shouldn't be transformed.
+; Negative test: the compare tests %in2 while the fadd uses %in, so the
+; compared value is not the fadd operand. The CHECK lines expect the IR to be
+; left unchanged.
+define float @test_fmax_neg1(float %in, float %in2) {
+; CHECK-LABEL: define float @test_fmax_neg1(
+; CHECK-SAME: float [[IN:%.*]], float [[IN2:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp ogt float [[IN2]], 0.000000e+00
+; CHECK-NEXT: [[ADD:%.*]] = fadd float [[IN]], 1.000000e+00
+; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP1]], float [[ADD]], float 1.000000e+00
+; CHECK-NEXT: ret float [[SEL]]
+;
+ %cmp1 = fcmp ogt float %in2, 0.000000e+00
+ %add = fadd float %in, 1.000000e+00
+ %sel = select i1 %cmp1, float %add, float 1.000000e+00
+ ret float %sel
+}
+
+; Negative test: %in is compared against 1.0 rather than 0.0. The CHECK lines
+; expect the IR to be left unchanged.
+define float @test_fmax_neg2(float %in) {
+; CHECK-LABEL: define float @test_fmax_neg2(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp ogt float [[IN]], 1.000000e+00
+; CHECK-NEXT: [[ADD:%.*]] = fadd float [[IN]], 1.000000e+00
+; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP1]], float [[ADD]], float 1.000000e+00
+; CHECK-NEXT: ret float [[SEL]]
+;
+ %cmp1 = fcmp ogt float %in, 1.000000e+00
+ %add = fadd float %in, 1.000000e+00
+ %sel = select i1 %cmp1, float %add, float 1.000000e+00
+ ret float %sel
+}
+
+; Negative test: %cmp1 is the condition of two selects, so the compare has more
+; than one use. The CHECK lines expect both selects and both fadds to be left
+; as written.
+define float @test_fmax_neg3(float %in) {
+; CHECK-LABEL: define float @test_fmax_neg3(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp ogt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[ADD:%.*]] = fadd float [[IN]], 1.000000e+00
+; CHECK-NEXT: [[ADD_2:%.*]] = fadd float [[IN]], 1.000000e+00
+; CHECK-NEXT: [[SEL_1:%.*]] = select i1 [[CMP1]], float [[ADD]], float 1.000000e+00
+; CHECK-NEXT: [[SEL_2:%.*]] = select i1 [[CMP1]], float 2.000000e+00, float [[ADD_2]]
+; CHECK-NEXT: [[RES:%.*]] = fadd float [[SEL_1]], [[SEL_2]]
+; CHECK-NEXT: ret float [[RES]]
+;
+ %cmp1 = fcmp ogt float %in, 0.000000e+00
+ %add = fadd float %in, 1.000000e+00
+ %add.2 = fadd float %in, 1.000000e+00
+ %sel.1 = select i1 %cmp1, float %add, float 1.000000e+00
+ %sel.2 = select i1 %cmp1, float 2.000000e+00, float %add.2
+ %res = fadd float %sel.1, %sel.2
+ ret float %res
+}
+
+
+; Check for fmin scenarios that shouldn't be transformed.
+; Negative test: the compare tests %in2 while the fadd uses %in, so the
+; compared value is not the fadd operand. The CHECK lines expect the IR to be
+; left unchanged.
+define float @test_fmin_neg1(float %in, float %in2) {
+; CHECK-LABEL: define float @test_fmin_neg1(
+; CHECK-SAME: float [[IN:%.*]], float [[IN2:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp olt float [[IN2]], 0.000000e+00
+; CHECK-NEXT: [[ADD:%.*]] = fadd float [[IN]], 1.000000e+00
+; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP1]], float [[ADD]], float 1.000000e+00
+; CHECK-NEXT: ret float [[SEL]]
+;
+ %cmp1 = fcmp olt float %in2, 0.000000e+00
+ %add = fadd float %in, 1.000000e+00
+ %sel = select i1 %cmp1, float %add, float 1.000000e+00
+ ret float %sel
+}
+
+; Negative test: %in is compared against 1.0 rather than 0.0. The CHECK lines
+; expect the IR to be left unchanged.
+define float @test_fmin_neg2(float %in) {
+; CHECK-LABEL: define float @test_fmin_neg2(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp olt float [[IN]], 1.000000e+00
+; CHECK-NEXT: [[ADD:%.*]] = fadd float [[IN]], 1.000000e+00
+; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP1]], float [[ADD]], float 1.000000e+00
+; CHECK-NEXT: ret float [[SEL]]
+;
+ %cmp1 = fcmp olt float %in, 1.000000e+00
+ %add = fadd float %in, 1.000000e+00
+ %sel = select i1 %cmp1, float %add, float 1.000000e+00
+ ret float %sel
+}
+
+; Negative test: %cmp1 is the condition of two selects, so the compare has more
+; than one use. The CHECK lines expect both selects and both fadds to be left
+; as written.
+define float @test_fmin_neg3(float %in) {
+; CHECK-LABEL: define float @test_fmin_neg3(
+; CHECK-SAME: float [[IN:%.*]]) {
+; CHECK-NEXT: [[CMP1:%.*]] = fcmp olt float [[IN]], 0.000000e+00
+; CHECK-NEXT: [[ADD:%.*]] = fadd float [[IN]], 1.000000e+00
+; CHECK-NEXT: [[ADD_2:%.*]] = fadd float [[IN]], 1.000000e+00
+; CHECK-NEXT: [[SEL_1:%.*]] = select i1 [[CMP1]], float [[ADD]], float 1.000000e+00
+; CHECK-NEXT: [[SEL_2:%.*]] = select i1 [[CMP1]], float 2.000000e+00, float [[ADD_2]]
+; CHECK-NEXT: [[RES:%.*]] = fadd float [[SEL_1]], [[SEL_2]]
+; CHECK-NEXT: ret float [[RES]]
+;
+ %cmp1 = fcmp olt float %in, 0.000000e+00
+ %add = fadd float %in, 1.000000e+00
+ %add.2 = fadd float %in, 1.000000e+00
+ %sel.1 = select i1 %cmp1, float %add, float 1.000000e+00
+ %sel.2 = select i1 %cmp1, float 2.000000e+00, float %add.2
+ %res = fadd float %sel.1, %sel.2
+ ret float %res
+}
More information about the llvm-commits
mailing list