[llvm] ed3f06b - [IR] Add zext nneg flag (#67982)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 30 01:04:09 PDT 2023
Author: Nikita Popov
Date: 2023-10-30T09:04:04+01:00
New Revision: ed3f06b9b393cd51e78e5fbc7a46bce090c1817a
URL: https://github.com/llvm/llvm-project/commit/ed3f06b9b393cd51e78e5fbc7a46bce090c1817a
DIFF: https://github.com/llvm/llvm-project/commit/ed3f06b9b393cd51e78e5fbc7a46bce090c1817a.diff
LOG: [IR] Add zext nneg flag (#67982)
Add an nneg flag to the zext instruction, which specifies that the
argument is non-negative. Otherwise, the result is a poison value.
The primary use-case for the flag is to preserve information when sext
gets replaced with zext due to range-based canonicalization. The nneg
flag allows us to convert the zext back into an sext later. This is
useful for some optimizations (e.g. a signed icmp can fold with sext but
not zext), as well as some targets (e.g. RISCV prefers sext over zext).
Discourse thread: https://discourse.llvm.org/t/rfc-add-zext-nneg-flag/73914
This patch is based on https://reviews.llvm.org/D156444 by
@Panagiotis156, with some implementation simplifications and additional
tests.
---------
Co-authored-by: Panagiotis K <karouzakispan at gmail.com>
Added:
Modified:
llvm/docs/LangRef.rst
llvm/include/llvm/AsmParser/LLToken.h
llvm/include/llvm/Bitcode/LLVMBitCodes.h
llvm/include/llvm/IR/InstrTypes.h
llvm/include/llvm/IR/Instruction.h
llvm/lib/AsmParser/LLLexer.cpp
llvm/lib/AsmParser/LLParser.cpp
llvm/lib/Bitcode/Reader/BitcodeReader.cpp
llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
llvm/lib/IR/AsmWriter.cpp
llvm/lib/IR/Instruction.cpp
llvm/lib/IR/Operator.cpp
llvm/test/Assembler/flags.ll
llvm/test/Bitcode/flags.ll
llvm/test/Transforms/InstCombine/freeze.ll
llvm/test/Transforms/SimplifyCFG/HoistCode.ll
Removed:
################################################################################
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index c97a7ae372bc6eb..3631dff50f30d8b 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -11229,6 +11229,10 @@ Overview:
The '``zext``' instruction zero extends its operand to type ``ty2``.
+The ``nneg`` (non-negative) flag, if present, specifies that the operand is
+non-negative (i.e. its sign bit is zero). This property may be used by
+optimization passes to later convert the ``zext`` into a ``sext``.
+
Arguments:
""""""""""
@@ -11245,6 +11249,9 @@ until it reaches the size of the destination type, ``ty2``.
When zero extending from i1, the result will always be either 0 or 1.
+If the ``nneg`` flag is set, and the ``zext`` argument is negative, the result
+is a poison value.
+
Example:
""""""""
@@ -11254,6 +11261,9 @@ Example:
%Y = zext i1 true to i32 ; yields i32:1
%Z = zext <2 x i16> <i16 8, i16 7> to <2 x i32> ; yields <i32 8, i32 7>
+ %a = zext nneg i8 127 to i16 ; yields i16 127
+ %b = zext nneg i8 -1 to i16 ; yields i16 poison
+
.. _i_sext:
'``sext .. to``' Instruction
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 2d6b8a19401d78d..773a1b84ea5330e 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -110,6 +110,7 @@ enum Kind {
kw_nsw,
kw_exact,
kw_inbounds,
+ kw_nneg,
kw_inrange,
kw_addrspace,
kw_section,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 5d7be5ca936ad37..ee2ccec6e89a4e6 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -505,6 +505,9 @@ enum FastMathMap {
AllowReassoc = (1 << 7)
};
+/// Flags for serializing PossiblyNonNegInst's SubclassOptionalData contents.
+enum PossiblyNonNegInstOptionalFlags { PNNI_NON_NEG = 0 };
+
/// PossiblyExactOperatorOptionalFlags - Flags for serializing
/// PossiblyExactOperator's SubclassOptionalData contents.
enum PossiblyExactOperatorOptionalFlags { PEO_EXACT = 0 };
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 6095b0a1be69cb3..fc5e228168a058b 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -692,6 +692,20 @@ class CastInst : public UnaryInstruction {
}
};
+/// Instruction that can have a nneg flag (only zext).
+class PossiblyNonNegInst : public CastInst {
+public:
+ enum { NonNeg = (1 << 0) };
+
+ static bool classof(const Instruction *I) {
+ return I->getOpcode() == Instruction::ZExt;
+ }
+
+ static bool classof(const Value *V) {
+ return isa<Instruction>(V) && classof(cast<Instruction>(V));
+ }
+};
+
//===----------------------------------------------------------------------===//
// CmpInst Class
//===----------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
index af7aa791cb6da60..5142847fa75fff2 100644
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -410,12 +410,19 @@ class Instruction : public User,
/// which supports this flag. See LangRef.html for the meaning of this flag.
void setIsExact(bool b = true);
+ /// Set or clear the nneg flag on this instruction, which must be a zext
+ /// instruction.
+ void setNonNeg(bool b = true);
+
/// Determine whether the no unsigned wrap flag is set.
bool hasNoUnsignedWrap() const LLVM_READONLY;
/// Determine whether the no signed wrap flag is set.
bool hasNoSignedWrap() const LLVM_READONLY;
+ /// Determine whether the nneg flag is set.
+ bool hasNonNeg() const LLVM_READONLY;
+
/// Return true if this operator has flags which may cause this instruction
/// to evaluate to poison despite having non-poison inputs.
bool hasPoisonGeneratingFlags() const LLVM_READONLY;
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index ae46209b30ede03..284a4c64c6793ef 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -565,6 +565,7 @@ lltok::Kind LLLexer::LexIdentifier() {
KEYWORD(nsw);
KEYWORD(exact);
KEYWORD(inbounds);
+ KEYWORD(nneg);
KEYWORD(inrange);
KEYWORD(addrspace);
KEYWORD(section);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index e104f8b3d1fdba5..42f306a99d5eefa 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -6383,8 +6383,16 @@ int LLParser::parseInstruction(Instruction *&Inst, BasicBlock *BB,
}
// Casts.
+ case lltok::kw_zext: {
+ bool NonNeg = EatIfPresent(lltok::kw_nneg);
+ bool Res = parseCast(Inst, PFS, KeywordVal);
+ if (Res != 0)
+ return Res;
+ if (NonNeg)
+ Inst->setNonNeg();
+ return 0;
+ }
case lltok::kw_trunc:
- case lltok::kw_zext:
case lltok::kw_sext:
case lltok::kw_fptrunc:
case lltok::kw_fpext:
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 28addb9068b242b..fcba7366fdb6112 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -4877,12 +4877,13 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
Value *Op;
unsigned OpTypeID;
if (getValueTypePair(Record, OpNum, NextValueNo, Op, OpTypeID, CurBB) ||
- OpNum+2 != Record.size())
+ OpNum + 2 > Record.size()) // dest type + opcode are required; flags are optional
return error("Invalid record");
- ResTypeID = Record[OpNum];
+ ResTypeID = Record[OpNum++];
Type *ResTy = getTypeByID(ResTypeID);
- int Opc = getDecodedCastOpcode(Record[OpNum + 1]);
+ int Opc = getDecodedCastOpcode(Record[OpNum++]);
+
if (Opc == -1 || !ResTy)
return error("Invalid record");
Instruction *Temp = nullptr;
@@ -4898,6 +4899,9 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
return error("Invalid cast");
I = CastInst::Create(CastOp, Op, ResTy);
}
+ if (OpNum < Record.size() && isa<PossiblyNonNegInst>(I) &&
+ (Record[OpNum] & (1 << bitc::PNNI_NON_NEG)))
+ I->setNonNeg(true); // optional trailing flags field (nneg on zext)
InstructionList.push_back(I);
break;
}
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index c427459508ecfc8..d7ebc76d9bfb0ea 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -123,6 +123,7 @@ enum {
FUNCTION_INST_BINOP_ABBREV,
FUNCTION_INST_BINOP_FLAGS_ABBREV,
FUNCTION_INST_CAST_ABBREV,
+ FUNCTION_INST_CAST_FLAGS_ABBREV,
FUNCTION_INST_RET_VOID_ABBREV,
FUNCTION_INST_RET_VAL_ABBREV,
FUNCTION_INST_UNREACHABLE_ABBREV,
@@ -1551,6 +1552,9 @@ static uint64_t getOptimizationFlags(const Value *V) {
Flags |= bitc::AllowContract;
if (FPMO->hasApproxFunc())
Flags |= bitc::ApproxFunc;
+ } else if (const auto *NNI = dyn_cast<PossiblyNonNegInst>(V)) {
+ if (NNI->hasNonNeg())
+ Flags |= 1 << bitc::PNNI_NON_NEG;
}
return Flags;
@@ -2827,6 +2831,12 @@ void ModuleBitcodeWriter::writeInstruction(const Instruction &I,
AbbrevToUse = FUNCTION_INST_CAST_ABBREV;
Vals.push_back(VE.getTypeID(I.getType()));
Vals.push_back(getEncodedCastOpcode(I.getOpcode()));
+ uint64_t Flags = getOptimizationFlags(&I);
+ if (Flags != 0) {
+ if (AbbrevToUse == FUNCTION_INST_CAST_ABBREV)
+ AbbrevToUse = FUNCTION_INST_CAST_FLAGS_ABBREV;
+ Vals.push_back(Flags);
+ }
} else {
assert(isa<BinaryOperator>(I) && "Unknown instruction!");
Code = bitc::FUNC_CODE_INST_BINOP;
@@ -3648,6 +3658,18 @@ void ModuleBitcodeWriter::writeBlockInfo() {
FUNCTION_INST_CAST_ABBREV)
llvm_unreachable("Unexpected abbrev ordering!");
}
+ { // INST_CAST_FLAGS abbrev for FUNCTION_BLOCK.
+ auto Abbv = std::make_shared<BitCodeAbbrev>();
+ Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_CAST));
+ Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6)); // OpVal
+ Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, // dest ty
+ VE.computeBitsRequiredForTypeIndicies()));
+ Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 4)); // opc
+ Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 8)); // flags
+ if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID, Abbv) !=
+ FUNCTION_INST_CAST_FLAGS_ABBREV)
+ llvm_unreachable("Unexpected abbrev ordering!");
+ }
{ // INST_RET abbrev for FUNCTION_BLOCK.
auto Abbv = std::make_shared<BitCodeAbbrev>();
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index bd8b3e9ad52215e..c738b50a7c721ed 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1348,6 +1348,9 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
} else if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
if (GEP->isInBounds())
Out << " inbounds";
+ } else if (const auto *NNI = dyn_cast<PossiblyNonNegInst>(U)) {
+ if (NNI->hasNonNeg())
+ Out << " nneg";
}
}
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 9b176eb78888e7c..1b3c03348f41a70 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -171,6 +171,12 @@ void Instruction::setIsExact(bool b) {
cast<PossiblyExactOperator>(this)->setIsExact(b);
}
+void Instruction::setNonNeg(bool b) {
+ assert(isa<PossiblyNonNegInst>(this) && "Must be zext");
+ SubclassOptionalData = (SubclassOptionalData & ~PossiblyNonNegInst::NonNeg) |
+ (b * PossiblyNonNegInst::NonNeg);
+}
+
bool Instruction::hasNoUnsignedWrap() const {
return cast<OverflowingBinaryOperator>(this)->hasNoUnsignedWrap();
}
@@ -179,6 +185,11 @@ bool Instruction::hasNoSignedWrap() const {
return cast<OverflowingBinaryOperator>(this)->hasNoSignedWrap();
}
+bool Instruction::hasNonNeg() const {
+ assert(isa<PossiblyNonNegInst>(this) && "Must be zext");
+ return (SubclassOptionalData & PossiblyNonNegInst::NonNeg) != 0;
+}
+
bool Instruction::hasPoisonGeneratingFlags() const {
return cast<Operator>(this)->hasPoisonGeneratingFlags();
}
@@ -203,7 +214,12 @@ void Instruction::dropPoisonGeneratingFlags() {
case Instruction::GetElementPtr:
cast<GetElementPtrInst>(this)->setIsInBounds(false);
break;
+
+ case Instruction::ZExt:
+ setNonNeg(false);
+ break;
}
+
if (isa<FPMathOperator>(this)) {
setHasNoNaNs(false);
setHasNoInfs(false);
@@ -378,6 +394,10 @@ void Instruction::copyIRFlags(const Value *V, bool IncludeWrapFlags) {
if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V))
if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this))
DestGEP->setIsInBounds(SrcGEP->isInBounds() || DestGEP->isInBounds());
+
+ if (auto *NNI = dyn_cast<PossiblyNonNegInst>(V))
+ if (isa<PossiblyNonNegInst>(this))
+ setNonNeg(NNI->hasNonNeg());
}
void Instruction::andIRFlags(const Value *V) {
@@ -403,6 +423,10 @@ void Instruction::andIRFlags(const Value *V) {
if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V))
if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this))
DestGEP->setIsInBounds(SrcGEP->isInBounds() && DestGEP->isInBounds());
+
+ if (auto *NNI = dyn_cast<PossiblyNonNegInst>(V))
+ if (isa<PossiblyNonNegInst>(this))
+ setNonNeg(hasNonNeg() && NNI->hasNonNeg());
}
const char *Instruction::getOpcodeName(unsigned OpCode) {
diff --git a/llvm/lib/IR/Operator.cpp b/llvm/lib/IR/Operator.cpp
index d2a1f2eb49dafed..0c917ad77e15806 100644
--- a/llvm/lib/IR/Operator.cpp
+++ b/llvm/lib/IR/Operator.cpp
@@ -37,6 +37,10 @@ bool Operator::hasPoisonGeneratingFlags() const {
// Note: inrange exists on constexpr only
return GEP->isInBounds() || GEP->getInRangeIndex() != std::nullopt;
}
+ case Instruction::ZExt:
+ if (auto *NNI = dyn_cast<PossiblyNonNegInst>(this))
+ return NNI->hasNonNeg();
+ return false;
default:
if (const auto *FP = dyn_cast<FPMathOperator>(this))
return FP->hasNoNaNs() || FP->hasNoInfs();
diff --git a/llvm/test/Assembler/flags.ll b/llvm/test/Assembler/flags.ll
index 3b54b06b81d4e2b..8331edf52a1699d 100644
--- a/llvm/test/Assembler/flags.ll
+++ b/llvm/test/Assembler/flags.ll
@@ -260,3 +260,9 @@ define i64 @mul_unsigned_ce() {
ret i64 mul nuw (i64 ptrtoint (ptr @addr to i64), i64 91)
}
+define i64 @test_zext(i32 %a) {
+; CHECK: %res = zext nneg i32 %a to i64
+ %res = zext nneg i32 %a to i64
+ ret i64 %res
+}
+
diff --git a/llvm/test/Bitcode/flags.ll b/llvm/test/Bitcode/flags.ll
index 6febaa6b40df863..a6e368b7e76327f 100644
--- a/llvm/test/Bitcode/flags.ll
+++ b/llvm/test/Bitcode/flags.ll
@@ -16,6 +16,8 @@ second: ; preds = %first
%s = add nsw i32 %a, 0 ; <i32> [#uses=0]
%us = add nuw nsw i32 %a, 0 ; <i32> [#uses=0]
%z = add i32 %a, 0 ; <i32> [#uses=0]
+ %hh = zext nneg i32 %a to i64
+ %ll = zext i32 %s to i64
unreachable
first: ; preds = %entry
@@ -24,5 +26,7 @@ first: ; preds = %entry
%ss = add nsw i32 %a, 0 ; <i32> [#uses=0]
%uuss = add nuw nsw i32 %a, 0 ; <i32> [#uses=0]
%zz = add i32 %a, 0 ; <i32> [#uses=0]
+ %kk = zext nneg i32 %a to i64
+ %rr = zext i32 %ss to i64
br label %second
}
diff --git a/llvm/test/Transforms/InstCombine/freeze.ll b/llvm/test/Transforms/InstCombine/freeze.ll
index f8e9a757e2bc938..3fde49d08481278 100644
--- a/llvm/test/Transforms/InstCombine/freeze.ll
+++ b/llvm/test/Transforms/InstCombine/freeze.ll
@@ -1116,6 +1116,17 @@ define i32 @freeze_ctpop(i32 %x) {
ret i32 %fr
}
+define i32 @freeze_zext_nneg(i8 %x) {
+; CHECK-LABEL: @freeze_zext_nneg(
+; CHECK-NEXT: [[X_FR:%.*]] = freeze i8 [[X:%.*]]
+; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X_FR]] to i32
+; CHECK-NEXT: ret i32 [[ZEXT]]
+;
+ %zext = zext nneg i8 %x to i32
+ %fr = freeze i32 %zext
+ ret i32 %fr
+}
+
!0 = !{}
!1 = !{i64 4}
!2 = !{i32 0, i32 100}
diff --git a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
index 4088ecfc818982f..08cf6cd5be80cf7 100644
--- a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
+++ b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
@@ -94,3 +94,33 @@ end:
%cond = phi fast float [ 0.0, %bb0 ], [ %x, %bb1 ], [ %x, %bb2 ]
ret float %cond
}
+
+define i32 @hoist_zext_flags_preserve(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_zext_flags_preserve(
+; CHECK-NEXT: common.ret:
+; CHECK-NEXT: [[Z1:%.*]] = zext nneg i8 [[X:%.*]] to i32
+; CHECK-NEXT: ret i32 [[Z1]]
+;
+ br i1 %C, label %T, label %F
+T:
+ %z1 = zext nneg i8 %x to i32
+ ret i32 %z1
+F:
+ %z2 = zext nneg i8 %x to i32
+ ret i32 %z2
+}
+
+define i32 @hoist_zext_flags_drop(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_zext_flags_drop(
+; CHECK-NEXT: common.ret:
+; CHECK-NEXT: [[Z1:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT: ret i32 [[Z1]]
+;
+ br i1 %C, label %T, label %F
+T:
+ %z1 = zext nneg i8 %x to i32
+ ret i32 %z1
+F:
+ %z2 = zext i8 %x to i32
+ ret i32 %z2
+}
More information about the llvm-commits
mailing list