[llvm] [IR][DAG] Add support for `nneg` flag with `uitofp` (PR #86141)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 21 13:42:21 PDT 2024


https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/86141

>From c22abc199896988b04c64d0811ac7853edd6cfd5 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Wed, 20 Mar 2024 16:46:24 -0500
Subject: [PATCH 1/2] [IR][DAG] Add support for `nneg` flag with `uitofp`

As noted when #82404 was pushed (canonicalizing `sitofp` -> `uitofp`),
different signedness on fp casts can have dramatic performance
implications on different backends.

So, it makes sense to create a reliable means for the backend to pick its
cast signedness if either is correct.

Further, this allows us to start canonicalizing `sitofp` -> `uitofp`,
which may ease middle-end analysis.
---
 llvm/docs/LangRef.rst                         | 10 ++++++
 llvm/include/llvm/CodeGen/TargetLowering.h    |  6 ++++
 llvm/include/llvm/IR/IRBuilder.h              | 23 +++++++-------
 llvm/include/llvm/IR/InstrTypes.h             | 10 ++++--
 llvm/lib/AsmParser/LLParser.cpp               |  2 +-
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 20 +++++++-----
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 11 ++++++-
 llvm/lib/IR/Instruction.cpp                   |  5 +--
 llvm/lib/IR/Operator.cpp                      |  1 +
 llvm/test/Assembler/flags.ll                  |  7 +++++
 llvm/test/Bitcode/flags.ll                    |  4 +++
 llvm/test/Transforms/InstCombine/freeze.ll    | 11 +++++++
 llvm/test/Transforms/SimplifyCFG/HoistCode.ll | 31 +++++++++++++++++++
 13 files changed, 115 insertions(+), 26 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 8bc1cab01bf0a6..08da14bf86c054 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -11616,6 +11616,10 @@ Overview:
 The '``uitofp``' instruction regards ``value`` as an unsigned integer
 and converts that value to the ``ty2`` type.
 
+The ``nneg`` (non-negative) flag, if present, specifies that the
+operand is non-negative. This property may be used by optimization
+passes to later convert the ``uitofp`` into a ``sitofp``.
+
 Arguments:
 """"""""""
 
@@ -11633,6 +11637,9 @@ integer quantity and converts it to the corresponding floating-point
 value. If the value cannot be exactly represented, it is rounded using
 the default rounding mode.
 
+If the ``nneg`` flag is set, and the ``uitofp`` argument is negative,
+the result is a poison value.
+
 
 Example:
 """"""""
@@ -11642,6 +11649,9 @@ Example:
       %X = uitofp i32 257 to float         ; yields float:257.0
       %Y = uitofp i8 -1 to double          ; yields double:255.0
 
+      %a = uitofp nneg i32 256 to float    ; yields float:256.0
+      %b = uitofp nneg i32 -256 to float   ; yields float poison
+
 '``sitofp .. to``' Instruction
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 59fad88f91b1d1..c53de1d4b6d61e 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3005,6 +3005,12 @@ class TargetLoweringBase {
     return false;
   }
 
+  /// Return true if sitofp from FromTy to ToTy is cheaper than
+  /// uitofp.
+  virtual bool isSIToFPCheaperThanUIToFP(EVT FromTy, EVT ToTy) const {
+    return false;
+  }
+
   /// Return true if this constant should be sign extended when promoting to
   /// a larger type.
   virtual bool signExtendConstant(const ConstantInt *C) const { return false; }
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index c07ffea7115115..93a4da6ec45ed2 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -2009,14 +2009,7 @@ class IRBuilderBase {
 
   Value *CreateZExt(Value *V, Type *DestTy, const Twine &Name = "",
                     bool IsNonNeg = false) {
-    if (V->getType() == DestTy)
-      return V;
-    if (Value *Folded = Folder.FoldCast(Instruction::ZExt, V, DestTy))
-      return Folded;
-    Instruction *I = Insert(new ZExtInst(V, DestTy), Name);
-    if (IsNonNeg)
-      I->setNonNeg();
-    return I;
+    return CreateCast(Instruction::ZExt, V, DestTy, Name, IsNonNeg);
   }
 
   Value *CreateSExt(Value *V, Type *DestTy, const Twine &Name = "") {
@@ -2067,11 +2060,12 @@ class IRBuilderBase {
     return CreateCast(Instruction::FPToSI, V, DestTy, Name);
   }
 
-  Value *CreateUIToFP(Value *V, Type *DestTy, const Twine &Name = ""){
+  Value *CreateUIToFP(Value *V, Type *DestTy, const Twine &Name = "",
+                      bool IsNonNeg = false) {
     if (IsFPConstrained)
       return CreateConstrainedFPCast(Intrinsic::experimental_constrained_uitofp,
                                      V, DestTy, nullptr, Name);
-    return CreateCast(Instruction::UIToFP, V, DestTy, Name);
+    return CreateCast(Instruction::UIToFP, V, DestTy, Name, IsNonNeg);
   }
 
   Value *CreateSIToFP(Value *V, Type *DestTy, const Twine &Name = ""){
@@ -2142,12 +2136,17 @@ class IRBuilderBase {
   }
 
   Value *CreateCast(Instruction::CastOps Op, Value *V, Type *DestTy,
-                    const Twine &Name = "") {
+                    const Twine &Name = "", bool IsNonNeg = false) {
     if (V->getType() == DestTy)
       return V;
     if (Value *Folded = Folder.FoldCast(Op, V, DestTy))
       return Folded;
-    return Insert(CastInst::Create(Op, V, DestTy), Name);
+    Instruction *I = Insert(CastInst::Create(Op, V, DestTy), Name);
+    if (IsNonNeg) {
+      assert(isa<PossiblyNonNegInst>(I) && "Invalid use of IsNonNeg");
+      I->setNonNeg();
+    }
+    return I;
   }
 
   Value *CreatePointerCast(Value *V, Type *DestTy,
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index e8c2cba8418dc8..8e2eff2e65247d 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -933,13 +933,19 @@ class CastInst : public UnaryInstruction {
   }
 };
 
-/// Instruction that can have a nneg flag (only zext).
+/// Instruction that can have a nneg flag (zext/uitofp).
 class PossiblyNonNegInst : public CastInst {
 public:
   enum { NonNeg = (1 << 0) };
 
   static bool classof(const Instruction *I) {
-    return I->getOpcode() == Instruction::ZExt;
+    switch (I->getOpcode()) {
+    case Instruction::ZExt:
+    case Instruction::UIToFP:
+      return true;
+    default:
+      return false;
+    }
   }
 
   static bool classof(const Value *V) {
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index f0be021668afa7..ca3973900ff969 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -6801,6 +6801,7 @@ int LLParser::parseInstruction(Instruction *&Inst, BasicBlock *BB,
   }
 
   // Casts.
+  case lltok::kw_uitofp:
   case lltok::kw_zext: {
     bool NonNeg = EatIfPresent(lltok::kw_nneg);
     bool Res = parseCast(Inst, PFS, KeywordVal);
@@ -6816,7 +6817,6 @@ int LLParser::parseInstruction(Instruction *&Inst, BasicBlock *BB,
   case lltok::kw_fpext:
   case lltok::kw_bitcast:
   case lltok::kw_addrspacecast:
-  case lltok::kw_uitofp:
   case lltok::kw_sitofp:
   case lltok::kw_fptoui:
   case lltok::kw_fptosi:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7009f375df1151..58c32ff05b5c07 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4019,13 +4019,13 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
   }
 
   // smax(a,b) - smin(a,b) --> abds(a,b)
-  if (hasOperation(ISD::ABDS, VT)  && 
+  if (hasOperation(ISD::ABDS, VT)  &&
       sd_match(N0, m_SMax(m_Value(A), m_Value(B))) &&
       sd_match(N1, m_SMin(m_Specific(A), m_Specific(B))))
     return DAG.getNode(ISD::ABDS, DL, VT, A, B);
 
   // umax(a,b) - umin(a,b) --> abdu(a,b)
-  if (hasOperation(ISD::ABDU, VT)  && 
+  if (hasOperation(ISD::ABDU, VT)  &&
       sd_match(N0, m_UMax(m_Value(A), m_Value(B))) &&
       sd_match(N1, m_UMin(m_Specific(A), m_Specific(B))))
     return DAG.getNode(ISD::ABDU, DL, VT, A, B);
@@ -17413,12 +17413,16 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
        TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
     return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
 
-  // If the input is a legal type, and UINT_TO_FP is not legal on this target,
-  // but SINT_TO_FP is legal on this target, try to convert.
-  if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
-      hasOperation(ISD::SINT_TO_FP, OpVT)) {
-    // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
-    if (DAG.SignBitIsZero(N0))
+  SDNodeFlags Flags = N->getFlags();
+  bool NonNeg = Flags.hasNonNeg() || DAG.SignBitIsZero(N0);
+
+  // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
+  if (NonNeg && hasOperation(ISD::SINT_TO_FP, OpVT)) {
+    // If the input is a legal type, and UINT_TO_FP is not legal on this target,
+    // but SINT_TO_FP is legal on this target, convert it.
+    // Or, if the target prefers SINT_TO_FP, convert it.
+    if (!hasOperation(ISD::UINT_TO_FP, OpVT) ||
+        DAG.getTargetLoweringInfo().isSIToFPCheaperThanUIToFP(OpVT, VT))
       return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
   }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 2d63774c75e372..5f8bb8ab3e7049 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3882,7 +3882,16 @@ void SelectionDAGBuilder::visitUIToFP(const User &I) {
   SDValue N = getValue(I.getOperand(0));
   EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(),
                                                         I.getType());
-  setValue(&I, DAG.getNode(ISD::UINT_TO_FP, getCurSDLoc(), DestVT, N));
+  SDNodeFlags Flags;
+  if (auto *PNI = dyn_cast<PossiblyNonNegInst>(&I))
+    Flags.setNonNeg(PNI->hasNonNeg());
+  // With nneg both casts agree; emit only the one the target prefers.
+  if (Flags.hasNonNeg() &&
+      DAG.getTargetLoweringInfo().isSIToFPCheaperThanUIToFP(N.getValueType(),
+                                                            DestVT))
+    setValue(&I, DAG.getNode(ISD::SINT_TO_FP, getCurSDLoc(), DestVT, N));
+  else
+    setValue(&I, DAG.getNode(ISD::UINT_TO_FP, getCurSDLoc(), DestVT, N, Flags));
 }
 
 void SelectionDAGBuilder::visitSIToFP(const User &I) {
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 47a7f2c9de790f..7f11ffacf26501 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -382,7 +382,7 @@ void Instruction::setIsExact(bool b) {
 }
 
 void Instruction::setNonNeg(bool b) {
-  assert(isa<PossiblyNonNegInst>(this) && "Must be zext");
+  assert(isa<PossiblyNonNegInst>(this) && "Must be zext/uitofp");
   SubclassOptionalData = (SubclassOptionalData & ~PossiblyNonNegInst::NonNeg) |
                          (b * PossiblyNonNegInst::NonNeg);
 }
@@ -396,7 +396,7 @@ bool Instruction::hasNoSignedWrap() const {
 }
 
 bool Instruction::hasNonNeg() const {
-  assert(isa<PossiblyNonNegInst>(this) && "Must be zext");
+  assert(isa<PossiblyNonNegInst>(this) && "Must be zext/uitofp");
   return (SubclassOptionalData & PossiblyNonNegInst::NonNeg) != 0;
 }
 
@@ -429,6 +429,7 @@ void Instruction::dropPoisonGeneratingFlags() {
     cast<GetElementPtrInst>(this)->setIsInBounds(false);
     break;
 
+  case Instruction::UIToFP:
   case Instruction::ZExt:
     setNonNeg(false);
     break;
diff --git a/llvm/lib/IR/Operator.cpp b/llvm/lib/IR/Operator.cpp
index b9cd219d94dc8a..6603ac36239096 100644
--- a/llvm/lib/IR/Operator.cpp
+++ b/llvm/lib/IR/Operator.cpp
@@ -39,6 +39,7 @@ bool Operator::hasPoisonGeneratingFlags() const {
     // Note: inrange exists on constexpr only
     return GEP->isInBounds() || GEP->getInRange() != std::nullopt;
   }
+  case Instruction::UIToFP:
   case Instruction::ZExt:
     if (auto *NNI = dyn_cast<PossiblyNonNegInst>(this))
       return NNI->hasNonNeg();
diff --git a/llvm/test/Assembler/flags.ll b/llvm/test/Assembler/flags.ll
index 04bddd02f50c81..c4f1d4c288b8d5 100644
--- a/llvm/test/Assembler/flags.ll
+++ b/llvm/test/Assembler/flags.ll
@@ -256,6 +256,13 @@ define i64 @test_zext(i32 %a) {
   ret i64 %res
 }
 
+define float @test_uitofp(i32 %a) {
+; CHECK: %res = uitofp nneg i32 %a to float
+  %res = uitofp nneg i32 %a to float
+  ret float %res
+}
+
+
 define i64 @test_or(i64 %a, i64 %b) {
 ; CHECK: %res = or disjoint i64 %a, %b
   %res = or disjoint i64 %a, %b
diff --git a/llvm/test/Bitcode/flags.ll b/llvm/test/Bitcode/flags.ll
index e3fc827d865d7e..5d41e441b5ced4 100644
--- a/llvm/test/Bitcode/flags.ll
+++ b/llvm/test/Bitcode/flags.ll
@@ -18,6 +18,8 @@ second:                                           ; preds = %first
   %z = add i32 %a, 0                              ; <i32> [#uses=0]
   %hh = zext nneg i32 %a to i64
   %ll = zext i32 %s to i64
+  %ff = uitofp nneg i32 %a to float
+  %bb = uitofp i32 %s to float
   %jj = or disjoint i32 %a, 0
   %oo = or i32 %a, 0
   unreachable
@@ -30,6 +32,8 @@ first:                                            ; preds = %entry
   %zz = add i32 %a, 0                             ; <i32> [#uses=0]
   %kk = zext nneg i32 %a to i64
   %rr = zext i32 %ss to i64
+  %ww = uitofp nneg i32 %a to float
+  %xx = uitofp i32 %ss to float
   %mm = or disjoint i32 %a, 0
   %nn = or i32 %a, 0
   br label %second
diff --git a/llvm/test/Transforms/InstCombine/freeze.ll b/llvm/test/Transforms/InstCombine/freeze.ll
index da59101d5710cb..2342184f8221e6 100644
--- a/llvm/test/Transforms/InstCombine/freeze.ll
+++ b/llvm/test/Transforms/InstCombine/freeze.ll
@@ -1127,6 +1127,17 @@ define i32 @freeze_zext_nneg(i8 %x) {
   ret i32 %fr
 }
 
+define float @freeze_uitofp_nneg(i8 %x) {
+; CHECK-LABEL: @freeze_uitofp_nneg(
+; CHECK-NEXT:    [[X_FR:%.*]] = freeze i8 [[X:%.*]]
+; CHECK-NEXT:    [[UITOFP:%.*]] = uitofp i8 [[X_FR]] to float
+; CHECK-NEXT:    ret float [[UITOFP]]
+;
+  %uitofp = uitofp nneg i8 %x to float
+  %fr = freeze float %uitofp
+  ret float %fr
+}
+
 define i32 @propagate_drop_flags_or(i32 %arg) {
 ; CHECK-LABEL: @propagate_drop_flags_or(
 ; CHECK-NEXT:    [[ARG_FR:%.*]] = freeze i32 [[ARG:%.*]]
diff --git a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
index a081eddfc45660..89a13cead35e06 100644
--- a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
+++ b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
@@ -125,6 +125,37 @@ F:
   ret i32 %z2
 }
 
+
+define float @hoist_uitofp_flags_preserve(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_uitofp_flags_preserve(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[Z1:%.*]] = uitofp nneg i8 [[X:%.*]] to float
+; CHECK-NEXT:    ret float [[Z1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %z1 = uitofp nneg i8 %x to float
+  ret float %z1
+F:
+  %z2 = uitofp nneg i8 %x to float
+  ret float %z2
+}
+
+define float @hoist_uitofp_flags_drop(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_uitofp_flags_drop(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[Z1:%.*]] = uitofp i8 [[X:%.*]] to float
+; CHECK-NEXT:    ret float [[Z1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %z1 = uitofp nneg i8 %x to float
+  ret float %z1
+F:
+  %z2 = uitofp i8 %x to float
+  ret float %z2
+}
+
 define i32 @hoist_or_flags_preserve(i1 %C, i32 %x, i32 %y) {
 ; CHECK-LABEL: @hoist_or_flags_preserve(
 ; CHECK-NEXT:  common.ret:

>From aa9471492ce030affdf489c6604460c9eec1b4e7 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Thu, 21 Mar 2024 15:41:37 -0500
Subject: [PATCH 2/2] [X86] Set `isSIToFPCheaperThanUIToFP` to true for all
 types

X86 is pretty much always better off or equally well off with `sitofp`.
---
 llvm/lib/Target/X86/X86ISelLowering.h | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 0a1e8ca4427314..0e20c0c252aa82 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1066,6 +1066,11 @@ namespace llvm {
     bool preferSextInRegOfTruncate(EVT TruncVT, EVT VT,
                                    EVT ExtVT) const override;
 
+    /// Always prefer signed cast from int -> fp.
+    bool isSIToFPCheaperThanUIToFP(EVT FromTy, EVT ToTy) const override {
+      return true;
+    }
+
     bool isXAndYEqZeroPreferableToXAndYEqY(ISD::CondCode Cond,
                                            EVT VT) const override;
 



More information about the llvm-commits mailing list