[llvm] ed3f06b - [IR] Add zext nneg flag (#67982)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 30 01:04:09 PDT 2023


Author: Nikita Popov
Date: 2023-10-30T09:04:04+01:00
New Revision: ed3f06b9b393cd51e78e5fbc7a46bce090c1817a

URL: https://github.com/llvm/llvm-project/commit/ed3f06b9b393cd51e78e5fbc7a46bce090c1817a
DIFF: https://github.com/llvm/llvm-project/commit/ed3f06b9b393cd51e78e5fbc7a46bce090c1817a.diff

LOG: [IR] Add zext nneg flag (#67982)

Add an nneg flag to the zext instruction, which specifies that the
argument is non-negative. If the flag is set and the argument is
negative, the result is a poison value.

The primary use-case for the flag is to preserve information when sext
gets replaced with zext due to range-based canonicalization. The nneg
flag allows us to convert the zext back into an sext later. This is
useful for some optimizations (e.g. a signed icmp can fold with sext but
not zext), as well as some targets (e.g. RISCV prefers sext over zext).

Discourse thread: https://discourse.llvm.org/t/rfc-add-zext-nneg-flag/73914

This patch is based on https://reviews.llvm.org/D156444 by
@Panagiotis156, with some implementation simplifications and additional
tests.

---------

Co-authored-by: Panagiotis K <karouzakispan at gmail.com>

Added: 
    

Modified: 
    llvm/docs/LangRef.rst
    llvm/include/llvm/AsmParser/LLToken.h
    llvm/include/llvm/Bitcode/LLVMBitCodes.h
    llvm/include/llvm/IR/InstrTypes.h
    llvm/include/llvm/IR/Instruction.h
    llvm/lib/AsmParser/LLLexer.cpp
    llvm/lib/AsmParser/LLParser.cpp
    llvm/lib/Bitcode/Reader/BitcodeReader.cpp
    llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
    llvm/lib/IR/AsmWriter.cpp
    llvm/lib/IR/Instruction.cpp
    llvm/lib/IR/Operator.cpp
    llvm/test/Assembler/flags.ll
    llvm/test/Bitcode/flags.ll
    llvm/test/Transforms/InstCombine/freeze.ll
    llvm/test/Transforms/SimplifyCFG/HoistCode.ll

Removed: 
    


################################################################################
diff  --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index c97a7ae372bc6eb..3631dff50f30d8b 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -11229,6 +11229,10 @@ Overview:
 
 The '``zext``' instruction zero extends its operand to type ``ty2``.
 
+The ``nneg`` (non-negative) flag, if present, specifies that the operand is
+non-negative. This property may be used by optimization passes to later
+convert the ``zext`` into a ``sext``.
+
 Arguments:
 """"""""""
 
@@ -11245,6 +11249,9 @@ until it reaches the size of the destination type, ``ty2``.
 
 When zero extending from i1, the result will always be either 0 or 1.
 
+If the ``nneg`` flag is set, and the ``zext`` argument is negative, the result
+is a poison value.
+
 Example:
 """"""""
 
@@ -11254,6 +11261,9 @@ Example:
       %Y = zext i1 true to i32              ; yields i32:1
       %Z = zext <2 x i16> <i16 8, i16 7> to <2 x i32> ; yields <i32 8, i32 7>
 
+      %a = zext nneg i8 127 to i16 ; yields i16 127
+      %b = zext nneg i8 -1 to i16  ; yields i16 poison
+
 .. _i_sext:
 
 '``sext .. to``' Instruction

diff  --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 2d6b8a19401d78d..773a1b84ea5330e 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -110,6 +110,7 @@ enum Kind {
   kw_nsw,
   kw_exact,
   kw_inbounds,
+  kw_nneg,
   kw_inrange,
   kw_addrspace,
   kw_section,

diff  --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 5d7be5ca936ad37..ee2ccec6e89a4e6 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -505,6 +505,9 @@ enum FastMathMap {
   AllowReassoc    = (1 << 7)
 };
 
+/// Flags for serializing PossiblyNonNegInst's SubclassOptionalData contents.
+enum PossiblyNonNegInstOptionalFlags { PNNI_NON_NEG = 0 };
+
 /// PossiblyExactOperatorOptionalFlags - Flags for serializing
 /// PossiblyExactOperator's SubclassOptionalData contents.
 enum PossiblyExactOperatorOptionalFlags { PEO_EXACT = 0 };

diff  --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 6095b0a1be69cb3..fc5e228168a058b 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -692,6 +692,20 @@ class CastInst : public UnaryInstruction {
   }
 };
 
+/// Instruction that can have an nneg flag (only zext).
+class PossiblyNonNegInst : public CastInst {
+public:
+  enum { NonNeg = (1 << 0) };
+
+  static bool classof(const Instruction *I) {
+    return I->getOpcode() == Instruction::ZExt;
+  }
+
+  static bool classof(const Value *V) {
+    return isa<Instruction>(V) && classof(cast<Instruction>(V));
+  }
+};
+
 //===----------------------------------------------------------------------===//
 //                               CmpInst Class
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
index af7aa791cb6da60..5142847fa75fff2 100644
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -410,12 +410,19 @@ class Instruction : public User,
   /// which supports this flag. See LangRef.html for the meaning of this flag.
   void setIsExact(bool b = true);
 
+  /// Set or clear the nneg flag on this instruction, which must be a zext
+  /// instruction.
+  void setNonNeg(bool b = true);
+
   /// Determine whether the no unsigned wrap flag is set.
   bool hasNoUnsignedWrap() const LLVM_READONLY;
 
   /// Determine whether the no signed wrap flag is set.
   bool hasNoSignedWrap() const LLVM_READONLY;
 
+  /// Determine whether the nneg flag is set.
+  bool hasNonNeg() const LLVM_READONLY;
+
   /// Return true if this operator has flags which may cause this instruction
   /// to evaluate to poison despite having non-poison inputs.
   bool hasPoisonGeneratingFlags() const LLVM_READONLY;

diff  --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index ae46209b30ede03..284a4c64c6793ef 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -565,6 +565,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(nsw);
   KEYWORD(exact);
   KEYWORD(inbounds);
+  KEYWORD(nneg);
   KEYWORD(inrange);
   KEYWORD(addrspace);
   KEYWORD(section);

diff  --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index e104f8b3d1fdba5..42f306a99d5eefa 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -6383,8 +6383,16 @@ int LLParser::parseInstruction(Instruction *&Inst, BasicBlock *BB,
   }
 
   // Casts.
+  case lltok::kw_zext: {
+    bool NonNeg = EatIfPresent(lltok::kw_nneg);
+    bool Res = parseCast(Inst, PFS, KeywordVal);
+    if (Res != 0)
+      return Res;
+    if (NonNeg)
+      Inst->setNonNeg();
+    return 0;
+  }
   case lltok::kw_trunc:
-  case lltok::kw_zext:
   case lltok::kw_sext:
   case lltok::kw_fptrunc:
   case lltok::kw_fpext:

diff  --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 28addb9068b242b..fcba7366fdb6112 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -4877,12 +4877,13 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
       Value *Op;
       unsigned OpTypeID;
       if (getValueTypePair(Record, OpNum, NextValueNo, Op, OpTypeID, CurBB) ||
-          OpNum+2 != Record.size())
+          OpNum + 1 > Record.size())
         return error("Invalid record");
 
-      ResTypeID = Record[OpNum];
+      ResTypeID = Record[OpNum++];
       Type *ResTy = getTypeByID(ResTypeID);
-      int Opc = getDecodedCastOpcode(Record[OpNum + 1]);
+      int Opc = getDecodedCastOpcode(Record[OpNum++]);
+
       if (Opc == -1 || !ResTy)
         return error("Invalid record");
       Instruction *Temp = nullptr;
@@ -4898,6 +4899,9 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
           return error("Invalid cast");
         I = CastInst::Create(CastOp, Op, ResTy);
       }
+      if (OpNum < Record.size() && isa<PossiblyNonNegInst>(I) &&
+          (Record[OpNum] & (1 << bitc::PNNI_NON_NEG)))
+        I->setNonNeg(true);
       InstructionList.push_back(I);
       break;
     }

diff  --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index c427459508ecfc8..d7ebc76d9bfb0ea 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -123,6 +123,7 @@ enum {
   FUNCTION_INST_BINOP_ABBREV,
   FUNCTION_INST_BINOP_FLAGS_ABBREV,
   FUNCTION_INST_CAST_ABBREV,
+  FUNCTION_INST_CAST_FLAGS_ABBREV,
   FUNCTION_INST_RET_VOID_ABBREV,
   FUNCTION_INST_RET_VAL_ABBREV,
   FUNCTION_INST_UNREACHABLE_ABBREV,
@@ -1551,6 +1552,9 @@ static uint64_t getOptimizationFlags(const Value *V) {
       Flags |= bitc::AllowContract;
     if (FPMO->hasApproxFunc())
       Flags |= bitc::ApproxFunc;
+  } else if (const auto *NNI = dyn_cast<PossiblyNonNegInst>(V)) {
+    if (NNI->hasNonNeg())
+      Flags |= 1 << bitc::PNNI_NON_NEG;
   }
 
   return Flags;
@@ -2827,6 +2831,12 @@ void ModuleBitcodeWriter::writeInstruction(const Instruction &I,
         AbbrevToUse = FUNCTION_INST_CAST_ABBREV;
       Vals.push_back(VE.getTypeID(I.getType()));
       Vals.push_back(getEncodedCastOpcode(I.getOpcode()));
+      uint64_t Flags = getOptimizationFlags(&I);
+      if (Flags != 0) {
+        if (AbbrevToUse == FUNCTION_INST_CAST_ABBREV)
+          AbbrevToUse = FUNCTION_INST_CAST_FLAGS_ABBREV;
+        Vals.push_back(Flags);
+      }
     } else {
       assert(isa<BinaryOperator>(I) && "Unknown instruction!");
       Code = bitc::FUNC_CODE_INST_BINOP;
@@ -3648,6 +3658,18 @@ void ModuleBitcodeWriter::writeBlockInfo() {
         FUNCTION_INST_CAST_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
+  { // INST_CAST_FLAGS abbrev for FUNCTION_BLOCK.
+    auto Abbv = std::make_shared<BitCodeAbbrev>();
+    Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_CAST));
+    Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6)); // OpVal
+    Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed,    // dest ty
+                              VE.computeBitsRequiredForTypeIndicies()));
+    Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 4)); // opc
+    Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 8)); // flags
+    if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID, Abbv) !=
+        FUNCTION_INST_CAST_FLAGS_ABBREV)
+      llvm_unreachable("Unexpected abbrev ordering!");
+  }
 
   { // INST_RET abbrev for FUNCTION_BLOCK.
     auto Abbv = std::make_shared<BitCodeAbbrev>();

diff  --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index bd8b3e9ad52215e..c738b50a7c721ed 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1348,6 +1348,9 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
   } else if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
     if (GEP->isInBounds())
       Out << " inbounds";
+  } else if (const auto *NNI = dyn_cast<PossiblyNonNegInst>(U)) {
+    if (NNI->hasNonNeg())
+      Out << " nneg";
   }
 }
 

diff  --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 9b176eb78888e7c..1b3c03348f41a70 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -171,6 +171,12 @@ void Instruction::setIsExact(bool b) {
   cast<PossiblyExactOperator>(this)->setIsExact(b);
 }
 
+void Instruction::setNonNeg(bool b) {
+  assert(isa<PossiblyNonNegInst>(this) && "Must be zext");
+  SubclassOptionalData = (SubclassOptionalData & ~PossiblyNonNegInst::NonNeg) |
+                         (b * PossiblyNonNegInst::NonNeg);
+}
+
 bool Instruction::hasNoUnsignedWrap() const {
   return cast<OverflowingBinaryOperator>(this)->hasNoUnsignedWrap();
 }
@@ -179,6 +185,11 @@ bool Instruction::hasNoSignedWrap() const {
   return cast<OverflowingBinaryOperator>(this)->hasNoSignedWrap();
 }
 
+bool Instruction::hasNonNeg() const {
+  assert(isa<PossiblyNonNegInst>(this) && "Must be zext");
+  return (SubclassOptionalData & PossiblyNonNegInst::NonNeg) != 0;
+}
+
 bool Instruction::hasPoisonGeneratingFlags() const {
   return cast<Operator>(this)->hasPoisonGeneratingFlags();
 }
@@ -203,7 +214,12 @@ void Instruction::dropPoisonGeneratingFlags() {
   case Instruction::GetElementPtr:
     cast<GetElementPtrInst>(this)->setIsInBounds(false);
     break;
+
+  case Instruction::ZExt:
+    setNonNeg(false);
+    break;
   }
+
   if (isa<FPMathOperator>(this)) {
     setHasNoNaNs(false);
     setHasNoInfs(false);
@@ -378,6 +394,10 @@ void Instruction::copyIRFlags(const Value *V, bool IncludeWrapFlags) {
   if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V))
     if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this))
       DestGEP->setIsInBounds(SrcGEP->isInBounds() || DestGEP->isInBounds());
+
+  if (auto *NNI = dyn_cast<PossiblyNonNegInst>(V))
+    if (isa<PossiblyNonNegInst>(this))
+      setNonNeg(NNI->hasNonNeg());
 }
 
 void Instruction::andIRFlags(const Value *V) {
@@ -403,6 +423,10 @@ void Instruction::andIRFlags(const Value *V) {
   if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V))
     if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this))
       DestGEP->setIsInBounds(SrcGEP->isInBounds() && DestGEP->isInBounds());
+
+  if (auto *NNI = dyn_cast<PossiblyNonNegInst>(V))
+    if (isa<PossiblyNonNegInst>(this))
+      setNonNeg(hasNonNeg() && NNI->hasNonNeg());
 }
 
 const char *Instruction::getOpcodeName(unsigned OpCode) {

diff  --git a/llvm/lib/IR/Operator.cpp b/llvm/lib/IR/Operator.cpp
index d2a1f2eb49dafed..0c917ad77e15806 100644
--- a/llvm/lib/IR/Operator.cpp
+++ b/llvm/lib/IR/Operator.cpp
@@ -37,6 +37,10 @@ bool Operator::hasPoisonGeneratingFlags() const {
     // Note: inrange exists on constexpr only
     return GEP->isInBounds() || GEP->getInRangeIndex() != std::nullopt;
   }
+  case Instruction::ZExt:
+    if (auto *NNI = dyn_cast<PossiblyNonNegInst>(this))
+      return NNI->hasNonNeg();
+    return false;
   default:
     if (const auto *FP = dyn_cast<FPMathOperator>(this))
       return FP->hasNoNaNs() || FP->hasNoInfs();

diff  --git a/llvm/test/Assembler/flags.ll b/llvm/test/Assembler/flags.ll
index 3b54b06b81d4e2b..8331edf52a1699d 100644
--- a/llvm/test/Assembler/flags.ll
+++ b/llvm/test/Assembler/flags.ll
@@ -260,3 +260,9 @@ define i64 @mul_unsigned_ce() {
 	ret i64 mul nuw (i64 ptrtoint (ptr @addr to i64), i64 91)
 }
 
+define i64 @test_zext(i32 %a) {
+; CHECK: %res = zext nneg i32 %a to i64
+  %res = zext nneg i32 %a to i64
+  ret i64 %res
+}
+

diff  --git a/llvm/test/Bitcode/flags.ll b/llvm/test/Bitcode/flags.ll
index 6febaa6b40df863..a6e368b7e76327f 100644
--- a/llvm/test/Bitcode/flags.ll
+++ b/llvm/test/Bitcode/flags.ll
@@ -16,6 +16,8 @@ second:                                           ; preds = %first
   %s = add nsw i32 %a, 0                          ; <i32> [#uses=0]
   %us = add nuw nsw i32 %a, 0                     ; <i32> [#uses=0]
   %z = add i32 %a, 0                              ; <i32> [#uses=0]
+  %hh = zext nneg i32 %a to i64
+  %ll = zext i32 %s to i64
   unreachable
 
 first:                                            ; preds = %entry
@@ -24,5 +26,7 @@ first:                                            ; preds = %entry
   %ss = add nsw i32 %a, 0                         ; <i32> [#uses=0]
   %uuss = add nuw nsw i32 %a, 0                   ; <i32> [#uses=0]
   %zz = add i32 %a, 0                             ; <i32> [#uses=0]
+  %kk = zext nneg i32 %a to i64
+  %rr = zext i32 %ss to i64
   br label %second
 }

diff  --git a/llvm/test/Transforms/InstCombine/freeze.ll b/llvm/test/Transforms/InstCombine/freeze.ll
index f8e9a757e2bc938..3fde49d08481278 100644
--- a/llvm/test/Transforms/InstCombine/freeze.ll
+++ b/llvm/test/Transforms/InstCombine/freeze.ll
@@ -1116,6 +1116,17 @@ define i32 @freeze_ctpop(i32 %x) {
   ret i32 %fr
 }
 
+define i32 @freeze_zext_nneg(i8 %x) {
+; CHECK-LABEL: @freeze_zext_nneg(
+; CHECK-NEXT:    [[X_FR:%.*]] = freeze i8 [[X:%.*]]
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[X_FR]] to i32
+; CHECK-NEXT:    ret i32 [[ZEXT]]
+;
+  %zext = zext nneg i8 %x to i32
+  %fr = freeze i32 %zext
+  ret i32 %fr
+}
+
 !0 = !{}
 !1 = !{i64 4}
 !2 = !{i32 0, i32 100}

diff  --git a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
index 4088ecfc818982f..08cf6cd5be80cf7 100644
--- a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
+++ b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
@@ -94,3 +94,33 @@ end:
   %cond = phi fast float [ 0.0, %bb0 ], [ %x, %bb1 ], [ %x, %bb2 ]
   ret float %cond
 }
+
+define i32 @hoist_zext_flags_preserve(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_zext_flags_preserve(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[Z1:%.*]] = zext nneg i8 [[X:%.*]] to i32
+; CHECK-NEXT:    ret i32 [[Z1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %z1 = zext nneg i8 %x to i32
+  ret i32 %z1
+F:
+  %z2 = zext nneg i8 %x to i32
+  ret i32 %z2
+}
+
+define i32 @hoist_zext_flags_drop(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_zext_flags_drop(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[Z1:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    ret i32 [[Z1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %z1 = zext nneg i8 %x to i32
+  ret i32 %z1
+F:
+  %z2 = zext i8 %x to i32
+  ret i32 %z2
+}


        


More information about the llvm-commits mailing list