[llvm] [IR] Add zext nneg flag (PR #67982)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 2 06:00:31 PDT 2023
https://github.com/nikic created https://github.com/llvm/llvm-project/pull/67982
Add an nneg flag to the zext instruction, which specifies that the argument is non-negative. Otherwise, the result is a poison value.
The primary use-case for the flag is to preserve information when sext gets replaced with zext due to range-based canonicalization. The nneg flag allows us to convert the zext back into an sext later. This is useful for some optimizations (e.g. a signed icmp can fold with sext but not zext), as well as some targets (e.g. RISCV prefers sext over zext).
This patch is based on https://reviews.llvm.org/D156444, with some implementation simplifications and additional tests.
>From ff1b36743745225b486c1b064a233006261cb182 Mon Sep 17 00:00:00 2001
From: Panagiotis K <karouzakispan at gmail.com>
Date: Mon, 2 Oct 2023 14:16:38 +0200
Subject: [PATCH] [IR] Add zext nneg flag
Add an nneg flag to the zext instruction, which specifies that the
argument is non-negative. Otherwise, the result is a poison value.
The primary use-case for the flag is to preserve information when
sext gets replaced with zext due to range-based canonicalization.
The nneg flag allows us to convert the zext back into an sext
later. This is useful for some optimizations (e.g. a signed icmp
can fold with sext but not zext), as well as some targets (e.g.
RISCV prefers sext over zext).
Co-authored-by: Nikita Popov <npopov at redhat.com>
---
llvm/docs/LangRef.rst | 6 ++++
llvm/include/llvm/AsmParser/LLToken.h | 1 +
llvm/include/llvm/Bitcode/LLVMBitCodes.h | 3 ++
llvm/include/llvm/IR/InstrTypes.h | 14 +++++++++
llvm/include/llvm/IR/Instruction.h | 7 +++++
llvm/lib/AsmParser/LLLexer.cpp | 1 +
llvm/lib/AsmParser/LLParser.cpp | 11 ++++++-
llvm/lib/Bitcode/Reader/BitcodeReader.cpp | 11 +++++--
llvm/lib/Bitcode/Writer/BitcodeWriter.cpp | 22 ++++++++++++++
llvm/lib/IR/AsmWriter.cpp | 3 ++
llvm/lib/IR/Instruction.cpp | 24 +++++++++++++++
llvm/lib/IR/Operator.cpp | 4 +++
llvm/test/Assembler/flags.ll | 6 ++++
llvm/test/Bitcode/flags.ll | 4 +++
llvm/test/Transforms/InstCombine/freeze.ll | 11 +++++++
llvm/test/Transforms/SimplifyCFG/HoistCode.ll | 30 +++++++++++++++++++
16 files changed, 154 insertions(+), 4 deletions(-)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index c8f74c19bd6b3cf..4c4f3cbbaf0d541 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -11156,6 +11156,9 @@ until it reaches the size of the destination type, ``ty2``.
When zero extending from i1, the result will always be either 0 or 1.
+If the ``nneg`` flag is set, and the ``zext`` argument is negative, the result
+is a poison value.
+
Example:
""""""""
@@ -11165,6 +11168,9 @@ Example:
%Y = zext i1 true to i32 ; yields i32:1
%Z = zext <2 x i16> <i16 8, i16 7> to <2 x i32> ; yields <i32 8, i32 7>
+ %a = zext nneg i8 127 to i16 ; yields i16 127
+ %b = zext nneg i8 -1 to i16 ; yields i16 poison
+
.. _i_sext:
'``sext .. to``' Instruction
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 673dc58ce6451e3..ae566672b97986b 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -110,6 +110,7 @@ enum Kind {
kw_nsw,
kw_exact,
kw_inbounds,
+ kw_nneg,
kw_inrange,
kw_addrspace,
kw_section,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 52e76356a892e45..6f11bb0b30c379a 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -505,6 +505,9 @@ enum FastMathMap {
AllowReassoc = (1 << 7)
};
+/// Flags for serializing NonNegInstruction's SubclassOptionalData contents.
+enum NonNegInstructionOptionalFlags { NNI_NON_NEG = 0 };
+
/// PossiblyExactOperatorOptionalFlags - Flags for serializing
/// PossiblyExactOperator's SubclassOptionalData contents.
enum PossiblyExactOperatorOptionalFlags { PEO_EXACT = 0 };
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 6095b0a1be69cb3..b6da2bf836cb56f 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -692,6 +692,20 @@ class CastInst : public UnaryInstruction {
}
};
+/// Instruction that can have an nneg flag (only zext).
+class NonNegInstruction : public CastInst {
+public:
+ enum { NonNeg = (1 << 0) };
+
+ static bool classof(const Instruction *I) {
+ return I->getOpcode() == Instruction::ZExt;
+ }
+
+ static bool classof(const Value *V) {
+ return isa<Instruction>(V) && classof(cast<Instruction>(V));
+ }
+};
+
//===----------------------------------------------------------------------===//
// CmpInst Class
//===----------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
index 69c3af5b76103f6..781a69e586a9e6b 100644
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -407,12 +407,19 @@ class Instruction : public User,
/// which supports this flag. See LangRef.html for the meaning of this flag.
void setIsExact(bool b = true);
+ /// Set or clear the nneg flag on this instruction, which must be a zext
+ /// instruction.
+ void setNonNeg(bool b = true);
+
/// Determine whether the no unsigned wrap flag is set.
bool hasNoUnsignedWrap() const LLVM_READONLY;
/// Determine whether the no signed wrap flag is set.
bool hasNoSignedWrap() const LLVM_READONLY;
+ /// Determine whether the nneg flag is set.
+ bool hasNonNeg() const LLVM_READONLY;
+
/// Return true if this operator has flags which may cause this instruction
/// to evaluate to poison despite having non-poison inputs.
bool hasPoisonGeneratingFlags() const LLVM_READONLY;
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 466bdebc001f589..02e1a1dce3c01b5 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -565,6 +565,7 @@ lltok::Kind LLLexer::LexIdentifier() {
KEYWORD(nsw);
KEYWORD(exact);
KEYWORD(inbounds);
+ KEYWORD(nneg);
KEYWORD(inrange);
KEYWORD(addrspace);
KEYWORD(section);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 04eabc94cfc6abe..2416814c53c0156 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -6381,8 +6381,17 @@ int LLParser::parseInstruction(Instruction *&Inst, BasicBlock *BB,
}
// Casts.
+ case lltok::kw_zext: {
+ bool NonNeg = EatIfPresent(lltok::kw_nneg);
+ bool Res = parseCast(Inst, PFS, KeywordVal);
+ // Check for a parse error first: Inst may not have been created on failure.
+ if (Res != 0)
+ return Res;
+ if (NonNeg)
+ Inst->setNonNeg();
+ return 0;
+ }
case lltok::kw_trunc:
- case lltok::kw_zext:
case lltok::kw_sext:
case lltok::kw_fptrunc:
case lltok::kw_fpext:
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index e56291859022eec..28ba79e48e7697d 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -4875,12 +4875,13 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
Value *Op;
unsigned OpTypeID;
if (getValueTypePair(Record, OpNum, NextValueNo, Op, OpTypeID, CurBB) ||
- OpNum+2 != Record.size())
+ OpNum + 2 > Record.size()) // two more entries (result type, opcode) are read below
return error("Invalid record");
- ResTypeID = Record[OpNum];
+ ResTypeID = Record[OpNum++];
Type *ResTy = getTypeByID(ResTypeID);
- int Opc = getDecodedCastOpcode(Record[OpNum + 1]);
+ int Opc = getDecodedCastOpcode(Record[OpNum++]);
+
if (Opc == -1 || !ResTy)
return error("Invalid record");
Instruction *Temp = nullptr;
@@ -4892,10 +4893,14 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
}
} else {
auto CastOp = (Instruction::CastOps)Opc;
+
if (!CastInst::castIsValid(CastOp, Op, ResTy))
return error("Invalid cast");
I = CastInst::Create(CastOp, Op, ResTy);
}
+ if (OpNum < Record.size() && isa<NonNegInstruction>(I) &&
+ (Record[OpNum] & (1 << bitc::NNI_NON_NEG)))
+ I->setNonNeg(true);
InstructionList.push_back(I);
break;
}
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index e991d055f33474b..75ddc058c5bcd79 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -123,6 +123,7 @@ enum {
FUNCTION_INST_BINOP_ABBREV,
FUNCTION_INST_BINOP_FLAGS_ABBREV,
FUNCTION_INST_CAST_ABBREV,
+ FUNCTION_INST_CAST_FLAGS_ABBREV,
FUNCTION_INST_RET_VOID_ABBREV,
FUNCTION_INST_RET_VAL_ABBREV,
FUNCTION_INST_UNREACHABLE_ABBREV,
@@ -1549,6 +1550,9 @@ static uint64_t getOptimizationFlags(const Value *V) {
Flags |= bitc::AllowContract;
if (FPMO->hasApproxFunc())
Flags |= bitc::ApproxFunc;
+ } else if (const auto *NNI = dyn_cast<NonNegInstruction>(V)) {
+ if (NNI->hasNonNeg())
+ Flags |= 1 << bitc::NNI_NON_NEG;
}
return Flags;
@@ -2825,6 +2829,12 @@ void ModuleBitcodeWriter::writeInstruction(const Instruction &I,
AbbrevToUse = FUNCTION_INST_CAST_ABBREV;
Vals.push_back(VE.getTypeID(I.getType()));
Vals.push_back(getEncodedCastOpcode(I.getOpcode()));
+ uint64_t Flags = getOptimizationFlags(&I);
+ if (Flags != 0) {
+ if (AbbrevToUse == FUNCTION_INST_CAST_ABBREV)
+ AbbrevToUse = FUNCTION_INST_CAST_FLAGS_ABBREV;
+ Vals.push_back(Flags);
+ }
} else {
assert(isa<BinaryOperator>(I) && "Unknown instruction!");
Code = bitc::FUNC_CODE_INST_BINOP;
@@ -3646,6 +3656,18 @@ void ModuleBitcodeWriter::writeBlockInfo() {
FUNCTION_INST_CAST_ABBREV)
llvm_unreachable("Unexpected abbrev ordering!");
}
+ { // INST_CAST_FLAGS abbrev for FUNCTION_BLOCK.
+ auto Abbv = std::make_shared<BitCodeAbbrev>();
+ Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_CAST));
+ Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6)); // OpVal
+ Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, // dest ty
+ VE.computeBitsRequiredForTypeIndicies()));
+ Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 4)); // opc
+ Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 8)); // flags
+ if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID, Abbv) !=
+ FUNCTION_INST_CAST_FLAGS_ABBREV)
+ llvm_unreachable("Unexpected abbrev ordering!");
+ }
{ // INST_RET abbrev for FUNCTION_BLOCK.
auto Abbv = std::make_shared<BitCodeAbbrev>();
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index e190d82127908db..abcee3f599da633 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1347,6 +1347,9 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
} else if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
if (GEP->isInBounds())
Out << " inbounds";
+ } else if (const auto *NNI = dyn_cast<NonNegInstruction>(U)) {
+ if (NNI->hasNonNeg())
+ Out << " nneg";
}
}
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index b497951a598cc50..e7c8d1090d95ac1 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -172,6 +172,12 @@ void Instruction::setIsExact(bool b) {
cast<PossiblyExactOperator>(this)->setIsExact(b);
}
+void Instruction::setNonNeg(bool b) {
+ assert(isa<NonNegInstruction>(this) && "Must be zext");
+ SubclassOptionalData = (SubclassOptionalData & ~NonNegInstruction::NonNeg) |
+ (b * NonNegInstruction::NonNeg);
+}
+
bool Instruction::hasNoUnsignedWrap() const {
return cast<OverflowingBinaryOperator>(this)->hasNoUnsignedWrap();
}
@@ -180,6 +186,11 @@ bool Instruction::hasNoSignedWrap() const {
return cast<OverflowingBinaryOperator>(this)->hasNoSignedWrap();
}
+bool Instruction::hasNonNeg() const {
+ assert(isa<NonNegInstruction>(this) && "Must be zext");
+ return (SubclassOptionalData & NonNegInstruction::NonNeg) != 0;
+}
+
bool Instruction::hasPoisonGeneratingFlags() const {
return cast<Operator>(this)->hasPoisonGeneratingFlags();
}
@@ -204,7 +215,12 @@ void Instruction::dropPoisonGeneratingFlags() {
case Instruction::GetElementPtr:
cast<GetElementPtrInst>(this)->setIsInBounds(false);
break;
+
+ case Instruction::ZExt:
+ setNonNeg(false);
+ break;
}
+
if (isa<FPMathOperator>(this)) {
setHasNoNaNs(false);
setHasNoInfs(false);
@@ -379,6 +395,10 @@ void Instruction::copyIRFlags(const Value *V, bool IncludeWrapFlags) {
if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V))
if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this))
DestGEP->setIsInBounds(SrcGEP->isInBounds() || DestGEP->isInBounds());
+
+ if (auto *NNI = dyn_cast<NonNegInstruction>(V))
+ if (isa<NonNegInstruction>(this))
+ setNonNeg(NNI->hasNonNeg());
}
void Instruction::andIRFlags(const Value *V) {
@@ -404,6 +424,10 @@ void Instruction::andIRFlags(const Value *V) {
if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V))
if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this))
DestGEP->setIsInBounds(SrcGEP->isInBounds() && DestGEP->isInBounds());
+
+ if (auto *NNI = dyn_cast<NonNegInstruction>(V))
+ if (isa<NonNegInstruction>(this))
+ setNonNeg(hasNonNeg() && NNI->hasNonNeg());
}
const char *Instruction::getOpcodeName(unsigned OpCode) {
diff --git a/llvm/lib/IR/Operator.cpp b/llvm/lib/IR/Operator.cpp
index d2a1f2eb49dafed..a8be4233991d362 100644
--- a/llvm/lib/IR/Operator.cpp
+++ b/llvm/lib/IR/Operator.cpp
@@ -37,6 +37,10 @@ bool Operator::hasPoisonGeneratingFlags() const {
// Note: inrange exists on constexpr only
return GEP->isInBounds() || GEP->getInRangeIndex() != std::nullopt;
}
+ case Instruction::ZExt:
+ if (auto *NNI = dyn_cast<NonNegInstruction>(this))
+ return NNI->hasNonNeg();
+ return false;
default:
if (const auto *FP = dyn_cast<FPMathOperator>(this))
return FP->hasNoNaNs() || FP->hasNoInfs();
diff --git a/llvm/test/Assembler/flags.ll b/llvm/test/Assembler/flags.ll
index 3b54b06b81d4e2b..8331edf52a1699d 100644
--- a/llvm/test/Assembler/flags.ll
+++ b/llvm/test/Assembler/flags.ll
@@ -260,3 +260,9 @@ define i64 @mul_unsigned_ce() {
ret i64 mul nuw (i64 ptrtoint (ptr @addr to i64), i64 91)
}
+define i64 @test_zext(i32 %a) {
+; CHECK: %res = zext nneg i32 %a to i64
+ %res = zext nneg i32 %a to i64
+ ret i64 %res
+}
+
diff --git a/llvm/test/Bitcode/flags.ll b/llvm/test/Bitcode/flags.ll
index 6febaa6b40df863..a6e368b7e76327f 100644
--- a/llvm/test/Bitcode/flags.ll
+++ b/llvm/test/Bitcode/flags.ll
@@ -16,6 +16,8 @@ second: ; preds = %first
%s = add nsw i32 %a, 0 ; <i32> [#uses=0]
%us = add nuw nsw i32 %a, 0 ; <i32> [#uses=0]
%z = add i32 %a, 0 ; <i32> [#uses=0]
+ %hh = zext nneg i32 %a to i64
+ %ll = zext i32 %s to i64
unreachable
first: ; preds = %entry
@@ -24,5 +26,7 @@ first: ; preds = %entry
%ss = add nsw i32 %a, 0 ; <i32> [#uses=0]
%uuss = add nuw nsw i32 %a, 0 ; <i32> [#uses=0]
%zz = add i32 %a, 0 ; <i32> [#uses=0]
+ %kk = zext nneg i32 %a to i64
+ %rr = zext i32 %ss to i64
br label %second
}
diff --git a/llvm/test/Transforms/InstCombine/freeze.ll b/llvm/test/Transforms/InstCombine/freeze.ll
index f8e9a757e2bc938..3fde49d08481278 100644
--- a/llvm/test/Transforms/InstCombine/freeze.ll
+++ b/llvm/test/Transforms/InstCombine/freeze.ll
@@ -1116,6 +1116,17 @@ define i32 @freeze_ctpop(i32 %x) {
ret i32 %fr
}
+define i32 @freeze_zext_nneg(i8 %x) {
+; CHECK-LABEL: @freeze_zext_nneg(
+; CHECK-NEXT: [[X_FR:%.*]] = freeze i8 [[X:%.*]]
+; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X_FR]] to i32
+; CHECK-NEXT: ret i32 [[ZEXT]]
+;
+ %zext = zext nneg i8 %x to i32
+ %fr = freeze i32 %zext
+ ret i32 %fr
+}
+
!0 = !{}
!1 = !{i64 4}
!2 = !{i32 0, i32 100}
diff --git a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
index 4088ecfc818982f..08cf6cd5be80cf7 100644
--- a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
+++ b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
@@ -94,3 +94,33 @@ end:
%cond = phi fast float [ 0.0, %bb0 ], [ %x, %bb1 ], [ %x, %bb2 ]
ret float %cond
}
+
+define i32 @hoist_zext_flags_preserve(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_zext_flags_preserve(
+; CHECK-NEXT: common.ret:
+; CHECK-NEXT: [[Z1:%.*]] = zext nneg i8 [[X:%.*]] to i32
+; CHECK-NEXT: ret i32 [[Z1]]
+;
+ br i1 %C, label %T, label %F
+T:
+ %z1 = zext nneg i8 %x to i32
+ ret i32 %z1
+F:
+ %z2 = zext nneg i8 %x to i32
+ ret i32 %z2
+}
+
+define i32 @hoist_zext_flags_drop(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_zext_flags_drop(
+; CHECK-NEXT: common.ret:
+; CHECK-NEXT: [[Z1:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT: ret i32 [[Z1]]
+;
+ br i1 %C, label %T, label %F
+T:
+ %z1 = zext nneg i8 %x to i32
+ ret i32 %z1
+F:
+ %z2 = zext i8 %x to i32
+ ret i32 %z2
+}
More information about the llvm-commits
mailing list