[llvm] [IR] Add zext nneg flag (PR #67982)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 2 06:00:31 PDT 2023


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/67982

Add an nneg flag to the zext instruction, which specifies that the argument is non-negative. If the flag is set and the argument is negative, the result is a poison value.

The primary use-case for the flag is to preserve information when sext gets replaced with zext due to range-based canonicalization. The nneg flag allows us to convert the zext back into an sext later. This is useful for some optimizations (e.g. a signed icmp can fold with sext but not zext), as well as some targets (e.g. RISCV prefers sext over zext).

This patch is based on https://reviews.llvm.org/D156444, with some implementation simplifications and additional tests.

>From ff1b36743745225b486c1b064a233006261cb182 Mon Sep 17 00:00:00 2001
From: Panagiotis K <karouzakispan at gmail.com>
Date: Mon, 2 Oct 2023 14:16:38 +0200
Subject: [PATCH] [IR] Add zext nneg flag

Add an nneg flag to the zext instruction, which specifies that the
argument is non-negative. If it is negative, the result is poison.

The primary use-case for the flag is to preserve information when
sext gets replaced with zext due to range-based canonicalization.
The nneg flag allows us to convert the zext back into an sext
later. This is useful for some optimizations (e.g. a signed icmp
can fold with sext but not zext), as well as some targets (e.g.
RISCV prefers sext over zext).

Co-authored-by: Nikita Popov <npopov at redhat.com>
---
 llvm/docs/LangRef.rst                         |  6 ++++
 llvm/include/llvm/AsmParser/LLToken.h         |  1 +
 llvm/include/llvm/Bitcode/LLVMBitCodes.h      |  3 ++
 llvm/include/llvm/IR/InstrTypes.h             | 14 +++++++++
 llvm/include/llvm/IR/Instruction.h            |  7 +++++
 llvm/lib/AsmParser/LLLexer.cpp                |  1 +
 llvm/lib/AsmParser/LLParser.cpp               | 11 ++++++-
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp     | 11 +++++--
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     | 22 ++++++++++++++
 llvm/lib/IR/AsmWriter.cpp                     |  3 ++
 llvm/lib/IR/Instruction.cpp                   | 24 +++++++++++++++
 llvm/lib/IR/Operator.cpp                      |  4 +++
 llvm/test/Assembler/flags.ll                  |  6 ++++
 llvm/test/Bitcode/flags.ll                    |  4 +++
 llvm/test/Transforms/InstCombine/freeze.ll    | 11 +++++++
 llvm/test/Transforms/SimplifyCFG/HoistCode.ll | 30 +++++++++++++++++++
 16 files changed, 154 insertions(+), 4 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index c8f74c19bd6b3cf..4c4f3cbbaf0d541 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -11156,6 +11156,9 @@ until it reaches the size of the destination type, ``ty2``.
 
 When zero extending from i1, the result will always be either 0 or 1.
 
+If the ``nneg`` flag is set, and the ``zext`` argument is negative, the result
+is a poison value.
+
 Example:
 """"""""
 
@@ -11165,6 +11168,9 @@ Example:
       %Y = zext i1 true to i32              ; yields i32:1
       %Z = zext <2 x i16> <i16 8, i16 7> to <2 x i32> ; yields <i32 8, i32 7>
 
+      %a = zext nneg i8 127 to i16 ; yields i16 127
+      %b = zext nneg i8 -1 to i16  ; yields i16 poison
+
 .. _i_sext:
 
 '``sext .. to``' Instruction
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 673dc58ce6451e3..ae566672b97986b 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -110,6 +110,7 @@ enum Kind {
   kw_nsw,
   kw_exact,
   kw_inbounds,
+  kw_nneg,
   kw_inrange,
   kw_addrspace,
   kw_section,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 52e76356a892e45..6f11bb0b30c379a 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -505,6 +505,9 @@ enum FastMathMap {
   AllowReassoc    = (1 << 7)
 };
 
+/// Flags for serializing NonNegInstruction's SubclassOptionalData contents.
+enum NonNegInstructionOptionalFlags { NNI_NON_NEG = 0 };
+
 /// PossiblyExactOperatorOptionalFlags - Flags for serializing
 /// PossiblyExactOperator's SubclassOptionalData contents.
 enum PossiblyExactOperatorOptionalFlags { PEO_EXACT = 0 };
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 6095b0a1be69cb3..b6da2bf836cb56f 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -692,6 +692,20 @@ class CastInst : public UnaryInstruction {
   }
 };
 
+/// Instruction that can have an nneg flag (currently only zext).
+class NonNegInstruction : public CastInst {
+public:
+  enum { NonNeg = (1 << 0) };
+
+  static bool classof(const Instruction *I) {
+    return I->getOpcode() == Instruction::ZExt;
+  }
+
+  static bool classof(const Value *V) {
+    return isa<Instruction>(V) && classof(cast<Instruction>(V));
+  }
+};
+
 //===----------------------------------------------------------------------===//
 //                               CmpInst Class
 //===----------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
index 69c3af5b76103f6..781a69e586a9e6b 100644
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -407,12 +407,19 @@ class Instruction : public User,
   /// which supports this flag. See LangRef.html for the meaning of this flag.
   void setIsExact(bool b = true);
 
+  /// Set or clear the nneg flag on this instruction, which must be a zext
+  /// instruction.
+  void setNonNeg(bool b = true);
+
   /// Determine whether the no unsigned wrap flag is set.
   bool hasNoUnsignedWrap() const LLVM_READONLY;
 
   /// Determine whether the no signed wrap flag is set.
   bool hasNoSignedWrap() const LLVM_READONLY;
 
+  /// Determine whether the nneg flag is set.
+  bool hasNonNeg() const LLVM_READONLY;
+
   /// Return true if this operator has flags which may cause this instruction
   /// to evaluate to poison despite having non-poison inputs.
   bool hasPoisonGeneratingFlags() const LLVM_READONLY;
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index 466bdebc001f589..02e1a1dce3c01b5 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -565,6 +565,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(nsw);
   KEYWORD(exact);
   KEYWORD(inbounds);
+  KEYWORD(nneg);
   KEYWORD(inrange);
   KEYWORD(addrspace);
   KEYWORD(section);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 04eabc94cfc6abe..2416814c53c0156 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -6381,8 +6381,17 @@ int LLParser::parseInstruction(Instruction *&Inst, BasicBlock *BB,
   }
 
   // Casts.
+  case lltok::kw_zext: {
+    bool NonNeg = EatIfPresent(lltok::kw_nneg);
+    bool Res = parseCast(Inst, PFS, KeywordVal);
+    // Only touch Inst once parseCast has reported success.
+    if (Res != 0)
+      return Res;
+    if (NonNeg)
+      Inst->setNonNeg();
+    return 0;
+  }
   case lltok::kw_trunc:
-  case lltok::kw_zext:
   case lltok::kw_sext:
   case lltok::kw_fptrunc:
   case lltok::kw_fpext:
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index e56291859022eec..28ba79e48e7697d 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -4875,12 +4875,13 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
       Value *Op;
       unsigned OpTypeID;
       if (getValueTypePair(Record, OpNum, NextValueNo, Op, OpTypeID, CurBB) ||
-          OpNum+2 != Record.size())
+          OpNum + 2 > Record.size()) // dest ty and cast opc must follow
         return error("Invalid record");
 
-      ResTypeID = Record[OpNum];
+      ResTypeID = Record[OpNum++];
       Type *ResTy = getTypeByID(ResTypeID);
-      int Opc = getDecodedCastOpcode(Record[OpNum + 1]);
+      int Opc = getDecodedCastOpcode(Record[OpNum++]);
+
       if (Opc == -1 || !ResTy)
         return error("Invalid record");
       Instruction *Temp = nullptr;
@@ -4892,10 +4893,14 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
         }
       } else {
         auto CastOp = (Instruction::CastOps)Opc;
+
         if (!CastInst::castIsValid(CastOp, Op, ResTy))
           return error("Invalid cast");
         I = CastInst::Create(CastOp, Op, ResTy);
       }
+      if (OpNum < Record.size() && isa<NonNegInstruction>(I) &&
+          (Record[OpNum] & (1 << bitc::NNI_NON_NEG)))
+        I->setNonNeg(true);
       InstructionList.push_back(I);
       break;
     }
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index e991d055f33474b..75ddc058c5bcd79 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -123,6 +123,7 @@ enum {
   FUNCTION_INST_BINOP_ABBREV,
   FUNCTION_INST_BINOP_FLAGS_ABBREV,
   FUNCTION_INST_CAST_ABBREV,
+  FUNCTION_INST_CAST_FLAGS_ABBREV,
   FUNCTION_INST_RET_VOID_ABBREV,
   FUNCTION_INST_RET_VAL_ABBREV,
   FUNCTION_INST_UNREACHABLE_ABBREV,
@@ -1549,6 +1550,9 @@ static uint64_t getOptimizationFlags(const Value *V) {
       Flags |= bitc::AllowContract;
     if (FPMO->hasApproxFunc())
       Flags |= bitc::ApproxFunc;
+  } else if (const auto *NNI = dyn_cast<NonNegInstruction>(V)) {
+    if (NNI->hasNonNeg())
+      Flags |= 1 << bitc::NNI_NON_NEG;
   }
 
   return Flags;
@@ -2825,6 +2829,12 @@ void ModuleBitcodeWriter::writeInstruction(const Instruction &I,
         AbbrevToUse = FUNCTION_INST_CAST_ABBREV;
       Vals.push_back(VE.getTypeID(I.getType()));
       Vals.push_back(getEncodedCastOpcode(I.getOpcode()));
+      uint64_t Flags = getOptimizationFlags(&I);
+      if (Flags != 0) {
+        if (AbbrevToUse == FUNCTION_INST_CAST_ABBREV)
+          AbbrevToUse = FUNCTION_INST_CAST_FLAGS_ABBREV;
+        Vals.push_back(Flags);
+      }
     } else {
       assert(isa<BinaryOperator>(I) && "Unknown instruction!");
       Code = bitc::FUNC_CODE_INST_BINOP;
@@ -3646,6 +3656,18 @@ void ModuleBitcodeWriter::writeBlockInfo() {
         FUNCTION_INST_CAST_ABBREV)
       llvm_unreachable("Unexpected abbrev ordering!");
   }
+  { // INST_CAST_FLAGS abbrev for FUNCTION_BLOCK.
+    auto Abbv = std::make_shared<BitCodeAbbrev>();
+    Abbv->Add(BitCodeAbbrevOp(bitc::FUNC_CODE_INST_CAST));
+    Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::VBR, 6)); // OpVal
+    Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed,    // dest ty
+                              VE.computeBitsRequiredForTypeIndicies()));
+    Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 4)); // opc
+    Abbv->Add(BitCodeAbbrevOp(BitCodeAbbrevOp::Fixed, 8)); // flags
+    if (Stream.EmitBlockInfoAbbrev(bitc::FUNCTION_BLOCK_ID, Abbv) !=
+        FUNCTION_INST_CAST_FLAGS_ABBREV)
+      llvm_unreachable("Unexpected abbrev ordering!");
+  }
 
   { // INST_RET abbrev for FUNCTION_BLOCK.
     auto Abbv = std::make_shared<BitCodeAbbrev>();
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index e190d82127908db..abcee3f599da633 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1347,6 +1347,9 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
   } else if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
     if (GEP->isInBounds())
       Out << " inbounds";
+  } else if (const auto *NNI = dyn_cast<NonNegInstruction>(U)) {
+    if (NNI->hasNonNeg())
+      Out << " nneg";
   }
 }
 
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index b497951a598cc50..e7c8d1090d95ac1 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -172,6 +172,12 @@ void Instruction::setIsExact(bool b) {
   cast<PossiblyExactOperator>(this)->setIsExact(b);
 }
 
+void Instruction::setNonNeg(bool b) {
+  assert(isa<NonNegInstruction>(this) && "Must be zext");
+  SubclassOptionalData = (SubclassOptionalData & ~NonNegInstruction::NonNeg) |
+                         (b * NonNegInstruction::NonNeg);
+}
+
 bool Instruction::hasNoUnsignedWrap() const {
   return cast<OverflowingBinaryOperator>(this)->hasNoUnsignedWrap();
 }
@@ -180,6 +186,11 @@ bool Instruction::hasNoSignedWrap() const {
   return cast<OverflowingBinaryOperator>(this)->hasNoSignedWrap();
 }
 
+bool Instruction::hasNonNeg() const {
+  assert(isa<NonNegInstruction>(this) && "Must be zext");
+  return (SubclassOptionalData & NonNegInstruction::NonNeg) != 0;
+}
+
 bool Instruction::hasPoisonGeneratingFlags() const {
   return cast<Operator>(this)->hasPoisonGeneratingFlags();
 }
@@ -204,7 +215,12 @@ void Instruction::dropPoisonGeneratingFlags() {
   case Instruction::GetElementPtr:
     cast<GetElementPtrInst>(this)->setIsInBounds(false);
     break;
+
+  case Instruction::ZExt:
+    setNonNeg(false);
+    break;
   }
+
   if (isa<FPMathOperator>(this)) {
     setHasNoNaNs(false);
     setHasNoInfs(false);
@@ -379,6 +395,10 @@ void Instruction::copyIRFlags(const Value *V, bool IncludeWrapFlags) {
   if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V))
     if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this))
       DestGEP->setIsInBounds(SrcGEP->isInBounds() || DestGEP->isInBounds());
+
+  if (auto *NNI = dyn_cast<NonNegInstruction>(V))
+    if (isa<NonNegInstruction>(this))
+      setNonNeg(NNI->hasNonNeg());
 }
 
 void Instruction::andIRFlags(const Value *V) {
@@ -404,6 +424,10 @@ void Instruction::andIRFlags(const Value *V) {
   if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(V))
     if (auto *DestGEP = dyn_cast<GetElementPtrInst>(this))
       DestGEP->setIsInBounds(SrcGEP->isInBounds() && DestGEP->isInBounds());
+
+  if (auto *NNI = dyn_cast<NonNegInstruction>(V))
+    if (isa<NonNegInstruction>(this))
+      setNonNeg(hasNonNeg() && NNI->hasNonNeg());
 }
 
 const char *Instruction::getOpcodeName(unsigned OpCode) {
diff --git a/llvm/lib/IR/Operator.cpp b/llvm/lib/IR/Operator.cpp
index d2a1f2eb49dafed..a8be4233991d362 100644
--- a/llvm/lib/IR/Operator.cpp
+++ b/llvm/lib/IR/Operator.cpp
@@ -37,6 +37,10 @@ bool Operator::hasPoisonGeneratingFlags() const {
     // Note: inrange exists on constexpr only
     return GEP->isInBounds() || GEP->getInRangeIndex() != std::nullopt;
   }
+  case Instruction::ZExt:
+    if (auto *NNI = dyn_cast<NonNegInstruction>(this))
+      return NNI->hasNonNeg();
+    return false;
   default:
     if (const auto *FP = dyn_cast<FPMathOperator>(this))
       return FP->hasNoNaNs() || FP->hasNoInfs();
diff --git a/llvm/test/Assembler/flags.ll b/llvm/test/Assembler/flags.ll
index 3b54b06b81d4e2b..8331edf52a1699d 100644
--- a/llvm/test/Assembler/flags.ll
+++ b/llvm/test/Assembler/flags.ll
@@ -260,3 +260,9 @@ define i64 @mul_unsigned_ce() {
 	ret i64 mul nuw (i64 ptrtoint (ptr @addr to i64), i64 91)
 }
 
+define i64 @test_zext(i32 %a) {
+; CHECK: %res = zext nneg i32 %a to i64
+  %res = zext nneg i32 %a to i64
+  ret i64 %res
+}
+
diff --git a/llvm/test/Bitcode/flags.ll b/llvm/test/Bitcode/flags.ll
index 6febaa6b40df863..a6e368b7e76327f 100644
--- a/llvm/test/Bitcode/flags.ll
+++ b/llvm/test/Bitcode/flags.ll
@@ -16,6 +16,8 @@ second:                                           ; preds = %first
   %s = add nsw i32 %a, 0                          ; <i32> [#uses=0]
   %us = add nuw nsw i32 %a, 0                     ; <i32> [#uses=0]
   %z = add i32 %a, 0                              ; <i32> [#uses=0]
+  %hh = zext nneg i32 %a to i64
+  %ll = zext i32 %s to i64
   unreachable
 
 first:                                            ; preds = %entry
@@ -24,5 +26,7 @@ first:                                            ; preds = %entry
   %ss = add nsw i32 %a, 0                         ; <i32> [#uses=0]
   %uuss = add nuw nsw i32 %a, 0                   ; <i32> [#uses=0]
   %zz = add i32 %a, 0                             ; <i32> [#uses=0]
+  %kk = zext nneg i32 %a to i64
+  %rr = zext i32 %ss to i64
   br label %second
 }
diff --git a/llvm/test/Transforms/InstCombine/freeze.ll b/llvm/test/Transforms/InstCombine/freeze.ll
index f8e9a757e2bc938..3fde49d08481278 100644
--- a/llvm/test/Transforms/InstCombine/freeze.ll
+++ b/llvm/test/Transforms/InstCombine/freeze.ll
@@ -1116,6 +1116,17 @@ define i32 @freeze_ctpop(i32 %x) {
   ret i32 %fr
 }
 
+define i32 @freeze_zext_nneg(i8 %x) {
+; CHECK-LABEL: @freeze_zext_nneg(
+; CHECK-NEXT:    [[X_FR:%.*]] = freeze i8 [[X:%.*]]
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[X_FR]] to i32
+; CHECK-NEXT:    ret i32 [[ZEXT]]
+;
+  %zext = zext nneg i8 %x to i32
+  %fr = freeze i32 %zext
+  ret i32 %fr
+}
+
 !0 = !{}
 !1 = !{i64 4}
 !2 = !{i32 0, i32 100}
diff --git a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
index 4088ecfc818982f..08cf6cd5be80cf7 100644
--- a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
+++ b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
@@ -94,3 +94,33 @@ end:
   %cond = phi fast float [ 0.0, %bb0 ], [ %x, %bb1 ], [ %x, %bb2 ]
   ret float %cond
 }
+
+define i32 @hoist_zext_flags_preserve(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_zext_flags_preserve(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[Z1:%.*]] = zext nneg i8 [[X:%.*]] to i32
+; CHECK-NEXT:    ret i32 [[Z1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %z1 = zext nneg i8 %x to i32
+  ret i32 %z1
+F:
+  %z2 = zext nneg i8 %x to i32
+  ret i32 %z2
+}
+
+define i32 @hoist_zext_flags_drop(i1 %C, i8 %x) {
+; CHECK-LABEL: @hoist_zext_flags_drop(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[Z1:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    ret i32 [[Z1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %z1 = zext nneg i8 %x to i32
+  ret i32 %z1
+F:
+  %z2 = zext i8 %x to i32
+  ret i32 %z2
+}



More information about the llvm-commits mailing list