[llvm] [IR][DAG] Add support for `nneg` flag with `uitofp` (PR #86141)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 21 13:16:54 PDT 2024
https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/86141
>From 156943d609123a12fc507f018f3798a090b8d944 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Wed, 20 Mar 2024 16:46:24 -0500
Subject: [PATCH] [IR][DAG] Add support for `nneg` flag with `uitofp`
As noted when #82404 was pushed (canonicalizing `sitofp` -> `uitofp`),
different signedness on fp casts can have dramatic performance
implications on different backends.
So, it makes sense to create a reliable means for the backend to pick its
cast signedness when either is correct.
Further, this allows us to start canonicalizing `sitofp` -> `uitofp`,
which may ease middle-end analysis.
---
llvm/docs/LangRef.rst | 10 ++++++
llvm/include/llvm/CodeGen/TargetLowering.h | 6 ++++
llvm/include/llvm/IR/IRBuilder.h | 23 +++++++-------
llvm/include/llvm/IR/InstrTypes.h | 10 ++++--
llvm/lib/AsmParser/LLParser.cpp | 2 +-
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 20 +++++++-----
.../SelectionDAG/SelectionDAGBuilder.cpp | 11 ++++++-
llvm/lib/IR/Instruction.cpp | 5 +--
llvm/lib/IR/Operator.cpp | 1 +
llvm/test/Assembler/flags.ll | 7 +++++
llvm/test/Bitcode/flags.ll | 4 +++
llvm/test/Transforms/InstCombine/freeze.ll | 11 +++++++
llvm/test/Transforms/SimplifyCFG/HoistCode.ll | 31 +++++++++++++++++++
13 files changed, 115 insertions(+), 26 deletions(-)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 8bc1cab01bf0a6..08da14bf86c054 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -11616,6 +11616,10 @@ Overview:
The '``uitofp``' instruction regards ``value`` as an unsigned integer
and converts that value to the ``ty2`` type.
+The ``nneg`` (non-negative) flag, if present, specifies that the
+operand is non-negative. This property may be used by optimization
+passes to later convert the ``uitofp`` into a ``sitofp``.
+
Arguments:
""""""""""
@@ -11633,6 +11637,9 @@ integer quantity and converts it to the corresponding floating-point
value. If the value cannot be exactly represented, it is rounded using
the default rounding mode.
+If the ``nneg`` flag is set, and the ``uitofp`` argument is negative,
+the result is a poison value.
+
Example:
""""""""
@@ -11642,6 +11649,9 @@ Example:
%X = uitofp i32 257 to float ; yields float:257.0
%Y = uitofp i8 -1 to double ; yields double:255.0
+ %a = uitofp nneg i32 256 to float ; yields float:256.0
+ %b = uitofp nneg i32 -256 to float ; yields float poison
+
'``sitofp .. to``' Instruction
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 59fad88f91b1d1..c53de1d4b6d61e 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3005,6 +3005,12 @@ class TargetLoweringBase {
return false;
}
+ /// Return true if sitofp from FromTy to ToTy is cheaper than
+ /// uitofp.
+ virtual bool isSIToFPCheaperThanUIToFP(EVT FromTy, EVT ToTy) const {
+ return false;
+ }
+
/// Return true if this constant should be sign extended when promoting to
/// a larger type.
virtual bool signExtendConstant(const ConstantInt *C) const { return false; }
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index c07ffea7115115..93a4da6ec45ed2 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -2009,14 +2009,7 @@ class IRBuilderBase {
Value *CreateZExt(Value *V, Type *DestTy, const Twine &Name = "",
bool IsNonNeg = false) {
- if (V->getType() == DestTy)
- return V;
- if (Value *Folded = Folder.FoldCast(Instruction::ZExt, V, DestTy))
- return Folded;
- Instruction *I = Insert(new ZExtInst(V, DestTy), Name);
- if (IsNonNeg)
- I->setNonNeg();
- return I;
+ return CreateCast(Instruction::ZExt, V, DestTy, Name, IsNonNeg);
}
Value *CreateSExt(Value *V, Type *DestTy, const Twine &Name = "") {
@@ -2067,11 +2060,12 @@ class IRBuilderBase {
return CreateCast(Instruction::FPToSI, V, DestTy, Name);
}
- Value *CreateUIToFP(Value *V, Type *DestTy, const Twine &Name = ""){
+ Value *CreateUIToFP(Value *V, Type *DestTy, const Twine &Name = "",
+ bool IsNonNeg = false) {
if (IsFPConstrained)
return CreateConstrainedFPCast(Intrinsic::experimental_constrained_uitofp,
V, DestTy, nullptr, Name);
- return CreateCast(Instruction::UIToFP, V, DestTy, Name);
+ return CreateCast(Instruction::UIToFP, V, DestTy, Name, IsNonNeg);
}
Value *CreateSIToFP(Value *V, Type *DestTy, const Twine &Name = ""){
@@ -2142,12 +2136,17 @@ class IRBuilderBase {
}
Value *CreateCast(Instruction::CastOps Op, Value *V, Type *DestTy,
- const Twine &Name = "") {
+ const Twine &Name = "", bool IsNonNeg = false) {
if (V->getType() == DestTy)
return V;
if (Value *Folded = Folder.FoldCast(Op, V, DestTy))
return Folded;
- return Insert(CastInst::Create(Op, V, DestTy), Name);
+ Instruction *I = Insert(CastInst::Create(Op, V, DestTy), Name);
+ if (IsNonNeg) {
+ assert(isa<PossiblyNonNegInst>(I) && "Invalid use of IsNonNeg");
+ I->setNonNeg();
+ }
+ return I;
}
Value *CreatePointerCast(Value *V, Type *DestTy,
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index e8c2cba8418dc8..8e2eff2e65247d 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -933,13 +933,19 @@ class CastInst : public UnaryInstruction {
}
};
-/// Instruction that can have a nneg flag (only zext).
+/// Instruction that can have a nneg flag (zext/uitofp).
class PossiblyNonNegInst : public CastInst {
public:
enum { NonNeg = (1 << 0) };
static bool classof(const Instruction *I) {
- return I->getOpcode() == Instruction::ZExt;
+ switch (I->getOpcode()) {
+ case Instruction::ZExt:
+ case Instruction::UIToFP:
+ return true;
+ default:
+ return false;
+ }
}
static bool classof(const Value *V) {
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index f0be021668afa7..ca3973900ff969 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -6801,6 +6801,7 @@ int LLParser::parseInstruction(Instruction *&Inst, BasicBlock *BB,
}
// Casts.
+ case lltok::kw_uitofp:
case lltok::kw_zext: {
bool NonNeg = EatIfPresent(lltok::kw_nneg);
bool Res = parseCast(Inst, PFS, KeywordVal);
@@ -6816,7 +6817,6 @@ int LLParser::parseInstruction(Instruction *&Inst, BasicBlock *BB,
case lltok::kw_fpext:
case lltok::kw_bitcast:
case lltok::kw_addrspacecast:
- case lltok::kw_uitofp:
case lltok::kw_sitofp:
case lltok::kw_fptoui:
case lltok::kw_fptosi:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7009f375df1151..7b5055f59d86b9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4019,13 +4019,13 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
}
// smax(a,b) - smin(a,b) --> abds(a,b)
- if (hasOperation(ISD::ABDS, VT) &&
+ if (hasOperation(ISD::ABDS, VT) &&
sd_match(N0, m_SMax(m_Value(A), m_Value(B))) &&
sd_match(N1, m_SMin(m_Specific(A), m_Specific(B))))
return DAG.getNode(ISD::ABDS, DL, VT, A, B);
// umax(a,b) - umin(a,b) --> abdu(a,b)
- if (hasOperation(ISD::ABDU, VT) &&
+ if (hasOperation(ISD::ABDU, VT) &&
sd_match(N0, m_UMax(m_Value(A), m_Value(B))) &&
sd_match(N1, m_UMin(m_Specific(A), m_Specific(B))))
return DAG.getNode(ISD::ABDU, DL, VT, A, B);
@@ -17413,12 +17413,16 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) {
TLI.isOperationLegalOrCustom(ISD::ConstantFP, VT)))
return DAG.getNode(ISD::UINT_TO_FP, SDLoc(N), VT, N0);
- // If the input is a legal type, and UINT_TO_FP is not legal on this target,
- // but SINT_TO_FP is legal on this target, try to convert.
- if (!hasOperation(ISD::UINT_TO_FP, OpVT) &&
- hasOperation(ISD::SINT_TO_FP, OpVT)) {
- // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
- if (DAG.SignBitIsZero(N0))
+ SDNodeFlags Flags = N->getFlags();
+ bool NonNeg = Flags.hasNonNeg() || DAG.SignBitIsZero(N0);
+
+ // If the sign bit is known to be zero, we can change this to SINT_TO_FP.
+ if (NonNeg && hasOperation(ISD::SINT_TO_FP, OpVT)) {
+ // If the input is a legal type, and UINT_TO_FP is not legal on this target,
+ // but SINT_TO_FP is legal on this target, convert it.
+ // Or, if the target prefers SINT_TO_FP, convert it.
+ if (!hasOperation(ISD::UINT_TO_FP, OpVT) ||
+ DAG.getTargetLoweringInfo().isSIToFPCheaperThanUIToFP(VT, OpVT))
return DAG.getNode(ISD::SINT_TO_FP, SDLoc(N), VT, N0);
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 2d63774c75e372..5f8bb8ab3e7049 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3882,7 +3882,16 @@ void SelectionDAGBuilder::visitUIToFP(const User &I) {
SDValue N = getValue(I.getOperand(0));
EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(),
I.getType());
- setValue(&I, DAG.getNode(ISD::UINT_TO_FP, getCurSDLoc(), DestVT, N));
+ SDNodeFlags Flags;
+ if (auto *PNI = dyn_cast<PossiblyNonNegInst>(&I))
+ Flags.setNonNeg(PNI->hasNonNeg());
+
+ if (Flags.hasNonNeg() &&
+ DAG.getTargetLoweringInfo().isSIToFPCheaperThanUIToFP(N.getValueType(),
+ DestVT))
+ setValue(&I, DAG.getNode(ISD::SINT_TO_FP, getCurSDLoc(), DestVT, N));
+
+ setValue(&I, DAG.getNode(ISD::UINT_TO_FP, getCurSDLoc(), DestVT, N, Flags));
}
void SelectionDAGBuilder::visitSIToFP(const User &I) {
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 47a7f2c9de790f..7f11ffacf26501 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -382,7 +382,7 @@ void Instruction::setIsExact(bool b) {
}
void Instruction::setNonNeg(bool b) {
- assert(isa<PossiblyNonNegInst>(this) && "Must be zext");
+ assert(isa<PossiblyNonNegInst>(this) && "Must be zext/uitofp");
SubclassOptionalData = (SubclassOptionalData & ~PossiblyNonNegInst::NonNeg) |
(b * PossiblyNonNegInst::NonNeg);
}
@@ -396,7 +396,7 @@ bool Instruction::hasNoSignedWrap() const {
}
bool Instruction::hasNonNeg() const {
- assert(isa<PossiblyNonNegInst>(this) && "Must be zext");
+ assert(isa<PossiblyNonNegInst>(this) && "Must be zext/uitofp");
return (SubclassOptionalData & PossiblyNonNegInst::NonNeg) != 0;
}
@@ -429,6 +429,7 @@ void Instruction::dropPoisonGeneratingFlags() {
cast<GetElementPtrInst>(this)->setIsInBounds(false);
break;
+ case Instruction::UIToFP:
case Instruction::ZExt:
setNonNeg(false);
break;
diff --git a/llvm/lib/IR/Operator.cpp b/llvm/lib/IR/Operator.cpp
index b9cd219d94dc8a..6603ac36239096 100644
--- a/llvm/lib/IR/Operator.cpp
+++ b/llvm/lib/IR/Operator.cpp
@@ -39,6 +39,7 @@ bool Operator::hasPoisonGeneratingFlags() const {
// Note: inrange exists on constexpr only
return GEP->isInBounds() || GEP->getInRange() != std::nullopt;
}
+ case Instruction::UIToFP:
case Instruction::ZExt:
if (auto *NNI = dyn_cast<PossiblyNonNegInst>(this))
return NNI->hasNonNeg();
diff --git a/llvm/test/Assembler/flags.ll b/llvm/test/Assembler/flags.ll
index 04bddd02f50c81..c4f1d4c288b8d5 100644
--- a/llvm/test/Assembler/flags.ll
+++ b/llvm/test/Assembler/flags.ll
@@ -256,6 +256,13 @@ define i64 @test_zext(i32 %a) {
ret i64 %res
}
+define float @test_uitofp(i32 %a) {
+; CHECK: %res = uitofp nneg i32 %a to float
+ %res = uitofp nneg i32 %a to float
+ ret float %res
+}
+
+
define i64 @test_or(i64 %a, i64 %b) {
; CHECK: %res = or disjoint i64 %a, %b
%res = or disjoint i64 %a, %b
diff --git a/llvm/test/Bitcode/flags.ll b/llvm/test/Bitcode/flags.ll
index e3fc827d865d7e..5d41e441b5ced4 100644
--- a/llvm/test/Bitcode/flags.ll
+++ b/llvm/test/Bitcode/flags.ll
@@ -18,6 +18,8 @@ second: ; preds = %first
%z = add i32 %a, 0 ; <i32> [#uses=0]
%hh = zext nneg i32 %a to i64
%ll = zext i32 %s to i64
+ %ff = uitofp nneg i32 %a to float
+ %bb = uitofp i32 %s to float
%jj = or disjoint i32 %a, 0
%oo = or i32 %a, 0
unreachable
@@ -30,6 +32,8 @@ first: ; preds = %entry
%zz = add i32 %a, 0 ; <i32> [#uses=0]
%kk = zext nneg i32 %a to i64
%rr = zext i32 %ss to i64
+ %ww = uitofp nneg i32 %a to float
+ %xx = uitofp i32 %ss to float
%mm = or disjoint i32 %a, 0
%nn = or i32 %a, 0
br label %second
diff --git a/llvm/test/Transforms/InstCombine/freeze.ll b/llvm/test/Transforms/InstCombine/freeze.ll
index da59101d5710cb..2342184f8221e6 100644
--- a/llvm/test/Transforms/InstCombine/freeze.ll
+++ b/llvm/test/Transforms/InstCombine/freeze.ll
@@ -1127,6 +1127,17 @@ define i32 @freeze_zext_nneg(i8 %x) {
ret i32 %fr
}
+define float @freeze_uitofp_nneg(i8 %x) {
+; CHECK-LABEL: @freeze_uitofp_nneg(
+; CHECK-NEXT: [[X_FR:%.*]] = freeze i8 [[X:%.*]]
+; CHECK-NEXT: [[UITOFP:%.*]] = uitofp i8 [[X_FR]] to float
+; CHECK-NEXT: ret float [[UITOFP]]
+;
+ %uitofp = uitofp nneg i8 %x to float
+ %fr = freeze float %uitofp
+ ret float %fr
+}
+
define i32 @propagate_drop_flags_or(i32 %arg) {
; CHECK-LABEL: @propagate_drop_flags_or(
; CHECK-NEXT: [[ARG_FR:%.*]] = freeze i32 [[ARG:%.*]]
diff --git a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
index a081eddfc45660..89a13cead35e06 100644
--- a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
+++ b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
@@ -125,6 +125,37 @@ F:
ret i32 %z2
}
+
+define float @hoist_uitofp_flags_preserve(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_uitofp_flags_preserve(
+; CHECK-NEXT: common.ret:
+; CHECK-NEXT: [[Z1:%.*]] = uitofp nneg i8 [[X:%.*]] to float
+; CHECK-NEXT: ret float [[Z1]]
+;
+ br i1 %C, label %T, label %F
+T:
+ %z1 = uitofp nneg i8 %x to float
+ ret float %z1
+F:
+ %z2 = uitofp nneg i8 %x to float
+ ret float %z2
+}
+
+define float @hoist_uitofp_flags_drop(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_uitofp_flags_drop(
+; CHECK-NEXT: common.ret:
+; CHECK-NEXT: [[Z1:%.*]] = uitofp i8 [[X:%.*]] to float
+; CHECK-NEXT: ret float [[Z1]]
+;
+ br i1 %C, label %T, label %F
+T:
+ %z1 = uitofp nneg i8 %x to float
+ ret float %z1
+F:
+ %z2 = uitofp i8 %x to float
+ ret float %z2
+}
+
define i32 @hoist_or_flags_preserve(i1 %C, i32 %x, i32 %y) {
; CHECK-LABEL: @hoist_or_flags_preserve(
; CHECK-NEXT: common.ret:
More information about the llvm-commits
mailing list