[llvm] [InstCombine] InstCombine should fold frexp of select to select of frexp (PR #121227)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 30 08:51:46 PST 2025


https://github.com/vortex73 updated https://github.com/llvm/llvm-project/pull/121227

>From 7e22e93fcda75d4f03624a68f00685d68bb6d32b Mon Sep 17 00:00:00 2001
From: Narayan Sreekumar <nsreekumar6 at gmail.com>
Date: Fri, 3 Jan 2025 01:33:41 +0530
Subject: [PATCH 1/4] [InstCombine] Pre-Commit Tests

---
 .../Transforms/InstCombine/select_frexp.ll    | 129 ++++++++++++++++++
 1 file changed, 129 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/select_frexp.ll

Review notes (placed after the diffstat, outside the commit message):
  - Pre-commit tests only: every CHECK block below still shows the select
    feeding llvm.frexp, i.e. the baseline before any fold exists.
  - The CHECK lines are autogenerated (see the NOTE at the top of the test);
    regenerate them with utils/update_test_checks.py instead of editing by
    hand when the fold lands.
  - Intended negative tests: test_select_frexp_vec_nonsplat (non-splat
    vector constant, commented "should not fold"), test_select_frexp_no_const
    (neither select arm is a constant) and test_select_frexp_extract_exp
    (extracts element 1, the exponent, rather than the mantissa).
  - test_select_frexp_vec_poison states no expectation in its comment, and no
    later patch in this series updates its CHECK lines, so it ends the series
    unfolded. Presumably this is because the constant matcher rejects a splat
    with poison lanes -- confirm, and consider labelling it as a negative test.

diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll
new file mode 100644
index 000000000000000..b3f05f4db42dd19
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/select_frexp.ll
@@ -0,0 +1,129 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+declare { float, i32 } @llvm.frexp.f32.i32(float)
+declare void @use(float)
+
+; Basic test case - constant in true position
+define float @test_select_frexp_basic(float %x, i1 %cond) {
+; CHECK-LABEL: define float @test_select_frexp_basic(
+; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]]
+; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]])
+; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0
+; CHECK-NEXT:    ret float [[FREXP_0]]
+;
+  %sel = select i1 %cond, float 1.000000e+00, float %x
+  %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel)
+  %frexp.0 = extractvalue { float, i32 } %frexp, 0
+  ret float %frexp.0
+}
+
+; Test with constant in false position
+define float @test_select_frexp_const_false(float %x, i1 %cond) {
+; CHECK-LABEL: define float @test_select_frexp_const_false(
+; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], float [[X]], float 1.000000e+00
+; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]])
+; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0
+; CHECK-NEXT:    ret float [[FREXP_0]]
+;
+  %sel = select i1 %cond, float %x, float 1.000000e+00
+  %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel)
+  %frexp.0 = extractvalue { float, i32 } %frexp, 0
+  ret float %frexp.0
+}
+
+; Multi-use test
+define float @test_select_frexp_multi_use(float %x, i1 %cond) {
+; CHECK-LABEL: define float @test_select_frexp_multi_use(
+; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]]
+; CHECK-NEXT:    call void @use(float [[SEL]])
+; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]])
+; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0
+; CHECK-NEXT:    ret float [[FREXP_0]]
+;
+  %sel = select i1 %cond, float 1.000000e+00, float %x
+  call void @use(float %sel)
+  %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel)
+  %frexp.0 = extractvalue { float, i32 } %frexp, 0
+  ret float %frexp.0
+}
+
+; Vector test - splat constant
+define <2 x float> @test_select_frexp_vec_splat(<2 x float> %x, <2 x i1> %cond) {
+; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_splat(
+; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> splat (float 1.000000e+00), <2 x float> [[X]]
+; CHECK-NEXT:    [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]])
+; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0
+; CHECK-NEXT:    ret <2 x float> [[FREXP_0]]
+;
+  %sel = select <2 x i1> %cond, <2 x float> <float 1.000000e+00, float 1.000000e+00>, <2 x float> %x
+  %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel)
+  %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0
+  ret <2 x float> %frexp.0
+}
+
+; Vector test with poison
+define <2 x float> @test_select_frexp_vec_poison(<2 x float> %x, <2 x i1> %cond) {
+; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_poison(
+; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> <float 1.000000e+00, float poison>, <2 x float> [[X]]
+; CHECK-NEXT:    [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]])
+; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0
+; CHECK-NEXT:    ret <2 x float> [[FREXP_0]]
+;
+  %sel = select <2 x i1> %cond, <2 x float> <float 1.000000e+00, float poison>, <2 x float> %x
+  %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel)
+  %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0
+  ret <2 x float> %frexp.0
+}
+
+; Vector test - non-splat (should not fold)
+define <2 x float> @test_select_frexp_vec_nonsplat(<2 x float> %x, <2 x i1> %cond) {
+; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_nonsplat(
+; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> <float 1.000000e+00, float 2.000000e+00>, <2 x float> [[X]]
+; CHECK-NEXT:    [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]])
+; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0
+; CHECK-NEXT:    ret <2 x float> [[FREXP_0]]
+;
+  %sel = select <2 x i1> %cond, <2 x float> <float 1.000000e+00, float 2.000000e+00>, <2 x float> %x
+  %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel)
+  %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0
+  ret <2 x float> %frexp.0
+}
+
+; Negative test - both operands non-constant
+define float @test_select_frexp_no_const(float %x, float %y, i1 %cond) {
+; CHECK-LABEL: define float @test_select_frexp_no_const(
+; CHECK-SAME: float [[X:%.*]], float [[Y:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], float [[X]], float [[Y]]
+; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]])
+; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0
+; CHECK-NEXT:    ret float [[FREXP_0]]
+;
+  %sel = select i1 %cond, float %x, float %y
+  %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel)
+  %frexp.0 = extractvalue { float, i32 } %frexp, 0
+  ret float %frexp.0
+}
+
+; Negative test - extracting exp instead of mantissa
+define i32 @test_select_frexp_extract_exp(float %x, i1 %cond) {
+; CHECK-LABEL: define i32 @test_select_frexp_extract_exp(
+; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]]
+; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]])
+; CHECK-NEXT:    [[FREXP_1:%.*]] = extractvalue { float, i32 } [[FREXP]], 1
+; CHECK-NEXT:    ret i32 [[FREXP_1]]
+;
+  %sel = select i1 %cond, float 1.000000e+00, float %x
+  %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel)
+  %frexp.1 = extractvalue { float, i32 } %frexp, 1
+  ret i32 %frexp.1
+}
+
+declare { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float>)

>From b738a1dcfc577b27002b53ec7008f0e26753cdb6 Mon Sep 17 00:00:00 2001
From: Narayan Sreekumar <nsreekumar6 at gmail.com>
Date: Fri, 3 Jan 2025 18:03:45 +0530
Subject: [PATCH 2/4] [InstCombine] InstCombine should fold frexp of select to
 select of frexp

---
 .../InstCombine/InstructionCombining.cpp      | 67 ++++++++++++++++++-
 .../Transforms/InstCombine/select_frexp.ll    | 17 ++---
 2 files changed, 75 insertions(+), 9 deletions(-)

Review notes (placed after the diffstat, outside the commit message):
  - New helper foldFrexpOfSelect: when element 0 (the mantissa) of
    llvm.frexp(select %c, C, %x) is extracted and one select arm is a
    ConstantFP C (true arm checked first), it emits a new frexp call on the
    other arm and a select named "select.frexp" between the constant-folded
    mantissa of C and the mantissa of the new call.
  - The mantissa of C is computed only when C is finite and non-zero; for
    zero, infinity or NaN the isFiniteNonZero() guard leaves C unchanged and
    C itself is used as the mantissa.
  - No one-use check in this revision: test_select_frexp_multi_use is updated
    to expect the fold while the original select stays alive for @use.
    PATCH 4/4 later adds hasOneUse() guards and reverts this expectation.
  - Only the three scalar tests get new CHECK lines here; the vector tests
    keep their unfolded expectations in this revision.
  - The helper never reads its EV parameter; the caller sets the insert point
    before calling it. The caller also tests isa<ConstantFP> on both arms,
    which the helper repeats.
  - Nit: the helper's parameter is named SelectInst, which hides the
    SelectInst class name inside the function body; consider renaming it
    (e.g. Sel).

diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index a64c188575e6c37..54919ea2f7386ce 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -4069,6 +4069,52 @@ InstCombinerImpl::foldExtractOfOverflowIntrinsic(ExtractValueInst &EV) {
   return nullptr;
 }
 
+static Value *foldFrexpOfSelect(ExtractValueInst &EV, CallInst *FrexpCall,
+                                SelectInst *SelectInst,
+                                InstCombiner::BuilderTy &Builder) {
+  // Helper to fold frexp of select to select of frexp.
+  Value *Cond = SelectInst->getCondition();
+  Value *TrueVal = SelectInst->getTrueValue();
+  Value *FalseVal = SelectInst->getFalseValue();
+  ConstantFP *ConstOp = nullptr;
+  Value *VarOp = nullptr;
+  bool ConstIsTrue = false;
+
+  if (auto *TrueConst = dyn_cast<ConstantFP>(TrueVal)) {
+    ConstOp = TrueConst;
+    VarOp = FalseVal;
+    ConstIsTrue = true;
+  } else if (auto *FalseConst = dyn_cast<ConstantFP>(FalseVal)) {
+    ConstOp = FalseConst;
+    VarOp = TrueVal;
+    ConstIsTrue = false;
+  }
+
+  if (!ConstOp || !VarOp)
+    return nullptr;
+
+  CallInst *NewFrexp =
+      Builder.CreateCall(FrexpCall->getCalledFunction(), {VarOp}, "frexp");
+
+  Value *NewEV = Builder.CreateExtractValue(NewFrexp, 0, "mantissa");
+
+  APFloat ConstVal = ConstOp->getValueAPF();
+  int Exp = 0;
+  APFloat Mantissa = ConstVal;
+
+  if (ConstVal.isFiniteNonZero()) {
+    Mantissa = frexp(ConstVal, Exp, APFloat::rmNearestTiesToEven);
+  }
+
+  Constant *ConstantMantissa = ConstantFP::get(ConstOp->getType(), Mantissa);
+
+  Value *NewSel = Builder.CreateSelect(
+      Cond, ConstIsTrue ? ConstantMantissa : NewEV,
+      ConstIsTrue ? NewEV : ConstantMantissa, "select.frexp");
+
+  return NewSel;
+}
+
 Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
   Value *Agg = EV.getAggregateOperand();
 
@@ -4078,7 +4124,26 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
   if (Value *V = simplifyExtractValueInst(Agg, EV.getIndices(),
                                           SQ.getWithInstruction(&EV)))
     return replaceInstUsesWith(EV, V);
-
+  if (EV.getNumIndices() == 1 && EV.getIndices()[0] == 0) {
+    if (auto *FrexpCall = dyn_cast<CallInst>(Agg)) {
+      if (Function *F = FrexpCall->getCalledFunction()) {
+        if (F->getIntrinsicID() == Intrinsic::frexp) {
+          if (auto *SelInst =
+                  dyn_cast<SelectInst>(FrexpCall->getArgOperand(0))) {
+            if (isa<ConstantFP>(SelInst->getTrueValue()) ||
+                isa<ConstantFP>(SelInst->getFalseValue())) {
+              Builder.SetInsertPoint(&EV);
+
+              if (Value *Result =
+                      foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) {
+                return replaceInstUsesWith(EV, Result);
+              }
+            }
+          }
+        }
+      }
+    }
+  }
   if (InsertValueInst *IV = dyn_cast<InsertValueInst>(Agg)) {
     // We're extracting from an insertvalue instruction, compare the indices
     const unsigned *exti, *exte, *insi, *inse;
diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll
index b3f05f4db42dd19..652d4de27b7591f 100644
--- a/llvm/test/Transforms/InstCombine/select_frexp.ll
+++ b/llvm/test/Transforms/InstCombine/select_frexp.ll
@@ -8,10 +8,10 @@ declare void @use(float)
 define float @test_select_frexp_basic(float %x, i1 %cond) {
 ; CHECK-LABEL: define float @test_select_frexp_basic(
 ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]]
-; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]])
+; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]])
 ; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0
-; CHECK-NEXT:    ret float [[FREXP_0]]
+; CHECK-NEXT:    [[SELECT_FREXP:%.*]] = select i1 [[COND]], float 5.000000e-01, float [[FREXP_0]]
+; CHECK-NEXT:    ret float [[SELECT_FREXP]]
 ;
   %sel = select i1 %cond, float 1.000000e+00, float %x
   %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel)
@@ -23,10 +23,10 @@ define float @test_select_frexp_basic(float %x, i1 %cond) {
 define float @test_select_frexp_const_false(float %x, i1 %cond) {
 ; CHECK-LABEL: define float @test_select_frexp_const_false(
 ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], float [[X]], float 1.000000e+00
-; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]])
+; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]])
 ; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0
-; CHECK-NEXT:    ret float [[FREXP_0]]
+; CHECK-NEXT:    [[SELECT_FREXP:%.*]] = select i1 [[COND]], float [[FREXP_0]], float 5.000000e-01
+; CHECK-NEXT:    ret float [[SELECT_FREXP]]
 ;
   %sel = select i1 %cond, float %x, float 1.000000e+00
   %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel)
@@ -40,9 +40,10 @@ define float @test_select_frexp_multi_use(float %x, i1 %cond) {
 ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]]
 ; CHECK-NEXT:    call void @use(float [[SEL]])
-; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]])
+; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]])
 ; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0
-; CHECK-NEXT:    ret float [[FREXP_0]]
+; CHECK-NEXT:    [[SELECT_FREXP:%.*]] = select i1 [[COND]], float 5.000000e-01, float [[FREXP_0]]
+; CHECK-NEXT:    ret float [[SELECT_FREXP]]
 ;
   %sel = select i1 %cond, float 1.000000e+00, float %x
   call void @use(float %sel)

>From 1e656c8957a5e3a608c694060c38d94327105821 Mon Sep 17 00:00:00 2001
From: Narayan Sreekumar <nsreekumar6 at gmail.com>
Date: Mon, 6 Jan 2025 23:08:37 +0530
Subject: [PATCH 3/4] [InstCombine] Refactor and Preserve fast math flags

---
 .../InstCombine/InstructionCombining.cpp      | 58 +++++++++----------
 .../Transforms/InstCombine/select_frexp.ll    | 37 +++++++++++-
 2 files changed, 62 insertions(+), 33 deletions(-)

Review notes (placed after the diffstat, outside the commit message):
  - Constant detection switches from dyn_cast<ConstantFP> to m_APFloat, and
    test_select_frexp_vec_splat is updated to expect the fold, so a splat
    vector constant now qualifies.
  - The isFiniteNonZero() guard is dropped and frexp() is applied to every
    matched constant. This assumes APFloat's frexp handles zero, infinity
    and NaN inputs sensibly -- confirm.
  - Builder.SetInsertPoint(&EV) moves from the caller into the helper. The
    caller now checks the intrinsic ID on an IntrinsicInst directly and no
    longer pre-checks the select arms.
  - IR flags of the original frexp call are copied to the new call, and the
    select's fast-math flags are copied to the new select. The two new
    fast-math tests expect nnan ninf nsz on the new select.
  - The vector mantissa is built element by element using
    getElementCount().getFixedValue(); presumably that asserts for scalable
    vector types. PATCH 4/4 replaces it with ConstantFP::get and adds
    <vscale x 2 x float> tests.
  - Nit: the blank line between foldFrexpOfSelect and visitExtractValueInst
    is deleted by this patch.

diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 54919ea2f7386ce..bc07c7c047efbec 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -33,6 +33,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombineInternal.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
@@ -4069,52 +4070,57 @@ InstCombinerImpl::foldExtractOfOverflowIntrinsic(ExtractValueInst &EV) {
   return nullptr;
 }
 
-static Value *foldFrexpOfSelect(ExtractValueInst &EV, CallInst *FrexpCall,
+static Value *foldFrexpOfSelect(ExtractValueInst &EV, IntrinsicInst *FrexpCall,
                                 SelectInst *SelectInst,
                                 InstCombiner::BuilderTy &Builder) {
   // Helper to fold frexp of select to select of frexp.
   Value *Cond = SelectInst->getCondition();
   Value *TrueVal = SelectInst->getTrueValue();
   Value *FalseVal = SelectInst->getFalseValue();
-  ConstantFP *ConstOp = nullptr;
+
+  const APFloat *ConstVal = nullptr;
   Value *VarOp = nullptr;
   bool ConstIsTrue = false;
 
-  if (auto *TrueConst = dyn_cast<ConstantFP>(TrueVal)) {
-    ConstOp = TrueConst;
+  if (match(TrueVal, m_APFloat(ConstVal))) {
     VarOp = FalseVal;
     ConstIsTrue = true;
-  } else if (auto *FalseConst = dyn_cast<ConstantFP>(FalseVal)) {
-    ConstOp = FalseConst;
+  } else if (match(FalseVal, m_APFloat(ConstVal))) {
     VarOp = TrueVal;
     ConstIsTrue = false;
+  } else {
+    return nullptr;
   }
 
-  if (!ConstOp || !VarOp)
-    return nullptr;
+  Builder.SetInsertPoint(&EV);
 
   CallInst *NewFrexp =
       Builder.CreateCall(FrexpCall->getCalledFunction(), {VarOp}, "frexp");
+  NewFrexp->copyIRFlags(FrexpCall);
 
   Value *NewEV = Builder.CreateExtractValue(NewFrexp, 0, "mantissa");
 
-  APFloat ConstVal = ConstOp->getValueAPF();
-  int Exp = 0;
-  APFloat Mantissa = ConstVal;
+  int Exp;
+  APFloat Mantissa = frexp(*ConstVal, Exp, APFloat::rmNearestTiesToEven);
 
-  if (ConstVal.isFiniteNonZero()) {
-    Mantissa = frexp(ConstVal, Exp, APFloat::rmNearestTiesToEven);
+  Constant *ConstantMantissa;
+  if (auto *VecTy = dyn_cast<VectorType>(TrueVal->getType())) {
+    SmallVector<Constant *, 4> Elems(
+        VecTy->getElementCount().getFixedValue(),
+        ConstantFP::get(VecTy->getElementType(), Mantissa));
+    ConstantMantissa = ConstantVector::get(Elems);
+  } else {
+    ConstantMantissa = ConstantFP::get(TrueVal->getType(), Mantissa);
   }
 
-  Constant *ConstantMantissa = ConstantFP::get(ConstOp->getType(), Mantissa);
-
   Value *NewSel = Builder.CreateSelect(
       Cond, ConstIsTrue ? ConstantMantissa : NewEV,
       ConstIsTrue ? NewEV : ConstantMantissa, "select.frexp");
+  if (auto *NewSelInst = dyn_cast<Instruction>(NewSel))
+    NewSelInst->copyFastMathFlags(SelectInst);
 
   return NewSel;
 }
-
 Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
   Value *Agg = EV.getAggregateOperand();
 
@@ -4125,20 +4131,12 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
                                           SQ.getWithInstruction(&EV)))
     return replaceInstUsesWith(EV, V);
   if (EV.getNumIndices() == 1 && EV.getIndices()[0] == 0) {
-    if (auto *FrexpCall = dyn_cast<CallInst>(Agg)) {
-      if (Function *F = FrexpCall->getCalledFunction()) {
-        if (F->getIntrinsicID() == Intrinsic::frexp) {
-          if (auto *SelInst =
-                  dyn_cast<SelectInst>(FrexpCall->getArgOperand(0))) {
-            if (isa<ConstantFP>(SelInst->getTrueValue()) ||
-                isa<ConstantFP>(SelInst->getFalseValue())) {
-              Builder.SetInsertPoint(&EV);
-
-              if (Value *Result =
-                      foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) {
-                return replaceInstUsesWith(EV, Result);
-              }
-            }
+    if (auto *FrexpCall = dyn_cast<IntrinsicInst>(Agg)) {
+      if (FrexpCall->getIntrinsicID() == Intrinsic::frexp) {
+        if (auto *SelInst = dyn_cast<SelectInst>(FrexpCall->getArgOperand(0))) {
+          if (Value *Result =
+                  foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) {
+            return replaceInstUsesWith(EV, Result);
           }
         }
       }
diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll
index 652d4de27b7591f..d729e7c7005142f 100644
--- a/llvm/test/Transforms/InstCombine/select_frexp.ll
+++ b/llvm/test/Transforms/InstCombine/select_frexp.ll
@@ -56,10 +56,10 @@ define float @test_select_frexp_multi_use(float %x, i1 %cond) {
 define <2 x float> @test_select_frexp_vec_splat(<2 x float> %x, <2 x i1> %cond) {
 ; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_splat(
 ; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[COND]], <2 x float> splat (float 1.000000e+00), <2 x float> [[X]]
-; CHECK-NEXT:    [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[SEL]])
+; CHECK-NEXT:    [[FREXP:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[X]])
 ; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP]], 0
-; CHECK-NEXT:    ret <2 x float> [[FREXP_0]]
+; CHECK-NEXT:    [[SELECT_FREXP:%.*]] = select <2 x i1> [[COND]], <2 x float> splat (float 5.000000e-01), <2 x float> [[FREXP_0]]
+; CHECK-NEXT:    ret <2 x float> [[SELECT_FREXP]]
 ;
   %sel = select <2 x i1> %cond, <2 x float> <float 1.000000e+00, float 1.000000e+00>, <2 x float> %x
   %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel)
@@ -127,4 +127,35 @@ define i32 @test_select_frexp_extract_exp(float %x, i1 %cond) {
   ret i32 %frexp.1
 }
 
+; Test with fast math flags
+define float @test_select_frexp_fast_math_select(float %x, i1 %cond) {
+; CHECK-LABEL: define float @test_select_frexp_fast_math_select(
+; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:    [[FREXP1:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]])
+; CHECK-NEXT:    [[MANTISSA:%.*]] = extractvalue { float, i32 } [[FREXP1]], 0
+; CHECK-NEXT:    [[SELECT_FREXP:%.*]] = select nnan ninf nsz i1 [[COND]], float 5.000000e-01, float [[MANTISSA]]
+; CHECK-NEXT:    ret float [[SELECT_FREXP]]
+;
+  %sel = select nnan ninf nsz i1 %cond, float 1.000000e+00, float %x
+  %frexp = call { float, i32 } @llvm.frexp.f32.i32(float %sel)
+  %frexp.0 = extractvalue { float, i32 } %frexp, 0
+  ret float %frexp.0
+}
+
+
+; Test vector case with fast math flags
+define <2 x float> @test_select_frexp_vec_fast_math(<2 x float> %x, <2 x i1> %cond) {
+; CHECK-LABEL: define <2 x float> @test_select_frexp_vec_fast_math(
+; CHECK-SAME: <2 x float> [[X:%.*]], <2 x i1> [[COND:%.*]]) {
+; CHECK-NEXT:    [[FREXP1:%.*]] = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> [[X]])
+; CHECK-NEXT:    [[MANTISSA:%.*]] = extractvalue { <2 x float>, <2 x i32> } [[FREXP1]], 0
+; CHECK-NEXT:    [[SELECT_FREXP:%.*]] = select nnan ninf nsz <2 x i1> [[COND]], <2 x float> splat (float 5.000000e-01), <2 x float> [[MANTISSA]]
+; CHECK-NEXT:    ret <2 x float> [[SELECT_FREXP]]
+;
+  %sel = select nnan ninf nsz <2 x i1> %cond, <2 x float> <float 1.000000e+00, float 1.000000e+00>, <2 x float> %x
+  %frexp = call { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float> %sel)
+  %frexp.0 = extractvalue { <2 x float>, <2 x i32> } %frexp, 0
+  ret <2 x float> %frexp.0
+}
+
 declare { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float>)

>From ad6c5019f66b621c96c0ed4158d0bb4875ebd3c5 Mon Sep 17 00:00:00 2001
From: Narayan Sreekumar <nsreekumar6 at gmail.com>
Date: Thu, 30 Jan 2025 22:05:47 +0530
Subject: [PATCH 4/4] [InstCombine] Refactor PatternMatch and add scalable
 Vector tests

---
 .../InstCombine/InstructionCombining.cpp      | 40 +++++++------------
 .../Transforms/InstCombine/select_frexp.ll    | 36 +++++++++++++++--
 2 files changed, 48 insertions(+), 28 deletions(-)

Review notes (placed after the diffstat, outside the commit message):
  - Adds hasOneUse() guards on both the select and the frexp call.
    test_select_frexp_multi_use (select also passed to @use) is therefore
    expected to stay unfolded again.
  - The mantissa constant is now built with ConstantFP::get on the type of
    the select's true operand, replacing the per-element vector loop. New
    <vscale x 2 x float> tests cover the constant in either select arm.
  - CreateSelectFMF(..., SelectInst, ...) replaces the manual
    copyFastMathFlags() call on the new select.
  - In visitExtractValueInst, Cond, TrueVal and FalseVal are bound by the
    matcher but never read afterwards, and the call and select are then
    re-derived with cast<>. Binding the select directly, or passing m_Value()
    placeholders, would avoid the unused locals.
  - Still open after this patch: test_select_frexp_vec_poison keeps its
    unfolded CHECK lines, and the helper parameter is still named SelectInst,
    which hides the class name inside the function.

diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index bc07c7c047efbec..5621511570b5819 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -4074,6 +4074,9 @@ static Value *foldFrexpOfSelect(ExtractValueInst &EV, IntrinsicInst *FrexpCall,
                                 SelectInst *SelectInst,
                                 InstCombiner::BuilderTy &Builder) {
   // Helper to fold frexp of select to select of frexp.
+
+  if (!SelectInst->hasOneUse() || !FrexpCall->hasOneUse())
+    return nullptr;
   Value *Cond = SelectInst->getCondition();
   Value *TrueVal = SelectInst->getTrueValue();
   Value *FalseVal = SelectInst->getFalseValue();
@@ -4103,22 +4106,11 @@ static Value *foldFrexpOfSelect(ExtractValueInst &EV, IntrinsicInst *FrexpCall,
   int Exp;
   APFloat Mantissa = frexp(*ConstVal, Exp, APFloat::rmNearestTiesToEven);
 
-  Constant *ConstantMantissa;
-  if (auto *VecTy = dyn_cast<VectorType>(TrueVal->getType())) {
-    SmallVector<Constant *, 4> Elems(
-        VecTy->getElementCount().getFixedValue(),
-        ConstantFP::get(VecTy->getElementType(), Mantissa));
-    ConstantMantissa = ConstantVector::get(Elems);
-  } else {
-    ConstantMantissa = ConstantFP::get(TrueVal->getType(), Mantissa);
-  }
+  Constant *ConstantMantissa = ConstantFP::get(TrueVal->getType(), Mantissa);
 
-  Value *NewSel = Builder.CreateSelect(
+  Value *NewSel = Builder.CreateSelectFMF(
       Cond, ConstIsTrue ? ConstantMantissa : NewEV,
-      ConstIsTrue ? NewEV : ConstantMantissa, "select.frexp");
-  if (auto *NewSelInst = dyn_cast<Instruction>(NewSel))
-    NewSelInst->copyFastMathFlags(SelectInst);
-
+      ConstIsTrue ? NewEV : ConstantMantissa, SelectInst, "select.frexp");
   return NewSel;
 }
 Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
@@ -4130,17 +4122,15 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
   if (Value *V = simplifyExtractValueInst(Agg, EV.getIndices(),
                                           SQ.getWithInstruction(&EV)))
     return replaceInstUsesWith(EV, V);
-  if (EV.getNumIndices() == 1 && EV.getIndices()[0] == 0) {
-    if (auto *FrexpCall = dyn_cast<IntrinsicInst>(Agg)) {
-      if (FrexpCall->getIntrinsicID() == Intrinsic::frexp) {
-        if (auto *SelInst = dyn_cast<SelectInst>(FrexpCall->getArgOperand(0))) {
-          if (Value *Result =
-                  foldFrexpOfSelect(EV, FrexpCall, SelInst, Builder)) {
-            return replaceInstUsesWith(EV, Result);
-          }
-        }
-      }
-    }
+
+  Value *Cond, *TrueVal, *FalseVal;
+  if (match(&EV, m_ExtractValue<0>(m_Intrinsic<Intrinsic::frexp>(m_Select(
+                     m_Value(Cond), m_Value(TrueVal), m_Value(FalseVal)))))) {
+    auto *SelInst =
+        cast<SelectInst>(cast<IntrinsicInst>(Agg)->getArgOperand(0));
+    if (Value *Result =
+            foldFrexpOfSelect(EV, cast<IntrinsicInst>(Agg), SelInst, Builder))
+      return replaceInstUsesWith(EV, Result);
   }
   if (InsertValueInst *IV = dyn_cast<InsertValueInst>(Agg)) {
     // We're extracting from an insertvalue instruction, compare the indices
diff --git a/llvm/test/Transforms/InstCombine/select_frexp.ll b/llvm/test/Transforms/InstCombine/select_frexp.ll
index d729e7c7005142f..d025aedda7170de 100644
--- a/llvm/test/Transforms/InstCombine/select_frexp.ll
+++ b/llvm/test/Transforms/InstCombine/select_frexp.ll
@@ -40,10 +40,9 @@ define float @test_select_frexp_multi_use(float %x, i1 %cond) {
 ; CHECK-SAME: float [[X:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], float 1.000000e+00, float [[X]]
 ; CHECK-NEXT:    call void @use(float [[SEL]])
-; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[X]])
+; CHECK-NEXT:    [[FREXP:%.*]] = call { float, i32 } @llvm.frexp.f32.i32(float [[SEL]])
 ; CHECK-NEXT:    [[FREXP_0:%.*]] = extractvalue { float, i32 } [[FREXP]], 0
-; CHECK-NEXT:    [[SELECT_FREXP:%.*]] = select i1 [[COND]], float 5.000000e-01, float [[FREXP_0]]
-; CHECK-NEXT:    ret float [[SELECT_FREXP]]
+; CHECK-NEXT:    ret float [[FREXP_0]]
 ;
   %sel = select i1 %cond, float 1.000000e+00, float %x
   call void @use(float %sel)
@@ -158,4 +157,35 @@ define <2 x float> @test_select_frexp_vec_fast_math(<2 x float> %x, <2 x i1> %co
   ret <2 x float> %frexp.0
 }
 
+; Test with scalable vectors with constant at True Position
+define <vscale x 2 x float> @test_select_frexp_scalable_vec0(<vscale x 2 x float> %x, <vscale x 2 x i1> %cond) {
+; CHECK-LABEL: define <vscale x 2 x float> @test_select_frexp_scalable_vec0(
+; CHECK-SAME: <vscale x 2 x float> [[X:%.*]], <vscale x 2 x i1> [[COND:%.*]]) {
+; CHECK-NEXT:    [[FREXP1:%.*]] = call { <vscale x 2 x float>, <vscale x 2 x i32> } @llvm.frexp.nxv2f32.nxv2i32(<vscale x 2 x float> [[X]])
+; CHECK-NEXT:    [[MANTISSA:%.*]] = extractvalue { <vscale x 2 x float>, <vscale x 2 x i32> } [[FREXP1]], 0
+; CHECK-NEXT:    [[SELECT_FREXP:%.*]] = select <vscale x 2 x i1> [[COND]], <vscale x 2 x float> splat (float 5.000000e-01), <vscale x 2 x float> [[MANTISSA]]
+; CHECK-NEXT:    ret <vscale x 2 x float> [[SELECT_FREXP]]
+;
+  %sel = select <vscale x 2 x i1> %cond, <vscale x 2 x float> splat (float 1.000000e+00), <vscale x 2 x float> %x
+  %frexp = call { <vscale x 2 x float>, <vscale x 2 x i32> } @llvm.frexp.nxv2f32.nxv2i32(<vscale x 2 x float> %sel)
+  %frexp.0 = extractvalue { <vscale x 2 x float>, <vscale x 2 x i32> } %frexp, 0
+  ret <vscale x 2 x float> %frexp.0
+}
+
+; Test with scalable vectors with constant at False Position
+define <vscale x 2 x float> @test_select_frexp_scalable_vec1(<vscale x 2 x float> %x, <vscale x 2 x i1> %cond) {
+; CHECK-LABEL: define <vscale x 2 x float> @test_select_frexp_scalable_vec1(
+; CHECK-SAME: <vscale x 2 x float> [[X:%.*]], <vscale x 2 x i1> [[COND:%.*]]) {
+; CHECK-NEXT:    [[FREXP1:%.*]] = call { <vscale x 2 x float>, <vscale x 2 x i32> } @llvm.frexp.nxv2f32.nxv2i32(<vscale x 2 x float> [[X]])
+; CHECK-NEXT:    [[MANTISSA:%.*]] = extractvalue { <vscale x 2 x float>, <vscale x 2 x i32> } [[FREXP1]], 0
+; CHECK-NEXT:    [[SELECT_FREXP:%.*]] = select <vscale x 2 x i1> [[COND]], <vscale x 2 x float> [[MANTISSA]], <vscale x 2 x float> splat (float 5.000000e-01)
+; CHECK-NEXT:    ret <vscale x 2 x float> [[SELECT_FREXP]]
+;
+  %sel = select <vscale x 2 x i1> %cond, <vscale x 2 x float> %x, <vscale x 2 x float> splat (float 1.000000e+00)
+  %frexp = call { <vscale x 2 x float>, <vscale x 2 x i32> } @llvm.frexp.nxv2f32.nxv2i32(<vscale x 2 x float> %sel)
+  %frexp.0 = extractvalue { <vscale x 2 x float>, <vscale x 2 x i32> } %frexp, 0
+  ret <vscale x 2 x float> %frexp.0
+}
+
 declare { <2 x float>, <2 x i32> } @llvm.frexp.v2f32.v2i32(<2 x float>)
+declare { <vscale x 2 x float>, <vscale x 2 x i32> } @llvm.frexp.nxv2f32.nxv2i32(<vscale x 2 x float>)



More information about the llvm-commits mailing list