[llvm] d9962c4 - [IR] Add disjoint flag for Or instructions. (#72583)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 24 08:49:24 PST 2023


Author: Craig Topper
Date: 2023-11-24T08:49:19-08:00
New Revision: d9962c400f970d17396e84c1a55cdbea29a7c893

URL: https://github.com/llvm/llvm-project/commit/d9962c400f970d17396e84c1a55cdbea29a7c893
DIFF: https://github.com/llvm/llvm-project/commit/d9962c400f970d17396e84c1a55cdbea29a7c893.diff

LOG: [IR] Add disjoint flag for Or instructions. (#72583)

This flag indicates that every bit is known to be zero in at least one
of the inputs. This allows the Or to be treated as an Add since there is
no possibility of a carry from any bit.

If the flag is present and this property does not hold, the result is
poison.

This makes it easier to reverse the InstCombine transform that turns Add
into Or.

This is inspired by a comment here
https://github.com/llvm/llvm-project/pull/71955#discussion_r1391614578

Discourse thread
https://discourse.llvm.org/t/rfc-add-or-disjoint-flag/75036

Added: 
    

Modified: 
    llvm/docs/LangRef.rst
    llvm/include/llvm/AsmParser/LLToken.h
    llvm/include/llvm/Bitcode/LLVMBitCodes.h
    llvm/include/llvm/IR/InstrTypes.h
    llvm/lib/AsmParser/LLLexer.cpp
    llvm/lib/AsmParser/LLParser.cpp
    llvm/lib/Bitcode/Reader/BitcodeReader.cpp
    llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
    llvm/lib/IR/AsmWriter.cpp
    llvm/lib/IR/Instruction.cpp
    llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
    llvm/test/Assembler/flags.ll
    llvm/test/Bitcode/compatibility.ll
    llvm/test/Bitcode/flags.ll
    llvm/test/Transforms/InstCombine/freeze.ll
    llvm/test/Transforms/InstCombine/or.ll
    llvm/test/Transforms/SimplifyCFG/HoistCode.ll

Removed: 
    


################################################################################
diff  --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index bc1eab1e0b7a07f..e448c5ed5c5d947 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -9981,6 +9981,7 @@ Syntax:
 ::
 
       <result> = or <ty> <op1>, <op2>   ; yields ty:result
+      <result> = or disjoint <ty> <op1>, <op2>   ; yields ty:result
 
 Overview:
 """""""""
@@ -10012,6 +10013,12 @@ The truth table used for the '``or``' instruction is:
 |   1 |   1 |   1 |
 +-----+-----+-----+
 
+``disjoint`` means that for each bit, that bit is zero in at least one of the
+inputs. This allows the Or to be treated as an Add since no carry can occur from
+any bit. If the disjoint keyword is present, the result value of the ``or`` is a
+:ref:`poison value <poisonvalues>` if both inputs have a one in the same bit
+position. For vectors, only the element containing the bit is poison.
+
 Example:
 """"""""
 

diff  --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 1eb3986e65e9154..0683291faae72c4 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -109,6 +109,7 @@ enum Kind {
   kw_nuw,
   kw_nsw,
   kw_exact,
+  kw_disjoint,
   kw_inbounds,
   kw_nneg,
   kw_inrange,

diff  --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 9fa70c0671ef340..99a41fa107d0811 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -512,6 +512,10 @@ enum PossiblyNonNegInstOptionalFlags { PNNI_NON_NEG = 0 };
 /// PossiblyExactOperator's SubclassOptionalData contents.
 enum PossiblyExactOperatorOptionalFlags { PEO_EXACT = 0 };
 
+/// PossiblyDisjointInstOptionalFlags - Flags for serializing
+/// PossiblyDisjointInst's SubclassOptionalData contents.
+enum PossiblyDisjointInstOptionalFlags { PDI_DISJOINT = 0 };
+
 /// Encoded AtomicOrdering values.
 enum AtomicOrderingCodes {
   ORDERING_NOTATOMIC = 0,

diff  --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index fc5e228168a058b..ddae3e4f43f48dd 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -415,6 +415,29 @@ struct OperandTraits<BinaryOperator> :
 
 DEFINE_TRANSPARENT_OPERAND_ACCESSORS(BinaryOperator, Value)
 
+/// An or instruction, which can be marked as "disjoint", indicating that the
+/// inputs don't have a 1 in the same bit position, meaning this instruction
+/// can also be treated as an add.
+class PossiblyDisjointInst : public BinaryOperator {
+public:
+  enum { IsDisjoint = (1 << 0) };
+
+  void setIsDisjoint(bool B) {
+    SubclassOptionalData =
+        (SubclassOptionalData & ~IsDisjoint) | (B * IsDisjoint);
+  }
+
+  bool isDisjoint() const { return SubclassOptionalData & IsDisjoint; }
+
+  static bool classof(const Instruction *I) {
+    return I->getOpcode() == Instruction::Or;
+  }
+
+  static bool classof(const Value *V) {
+    return isa<Instruction>(V) && classof(cast<Instruction>(V));
+  }
+};
+
 //===----------------------------------------------------------------------===//
 //                               CastInst Class
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index fb111042f674d10..09a205c445dbede 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -564,6 +564,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(nuw);
   KEYWORD(nsw);
   KEYWORD(exact);
+  KEYWORD(disjoint);
   KEYWORD(inbounds);
   KEYWORD(nneg);
   KEYWORD(inrange);

diff  --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 2d2c1bf65311c01..d236b6cfa9000cb 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -6370,8 +6370,15 @@ int LLParser::parseInstruction(Instruction *&Inst, BasicBlock *BB,
   case lltok::kw_srem:
     return parseArithmetic(Inst, PFS, KeywordVal,
                            /*IsFP*/ false);
+  case lltok::kw_or: {
+    bool Disjoint = EatIfPresent(lltok::kw_disjoint);
+    if (parseLogical(Inst, PFS, KeywordVal))
+      return true;
+    if (Disjoint)
+      cast<PossiblyDisjointInst>(Inst)->setIsDisjoint(true);
+    return false;
+  }
   case lltok::kw_and:
-  case lltok::kw_or:
   case lltok::kw_xor:
     return parseLogical(Inst, PFS, KeywordVal);
   case lltok::kw_icmp:

diff  --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 788906a3d56c8b9..e4c3770946b3abc 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -4866,12 +4866,14 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
                    Opc == Instruction::AShr) {
           if (Record[OpNum] & (1 << bitc::PEO_EXACT))
             cast<BinaryOperator>(I)->setIsExact(true);
+        } else if (Opc == Instruction::Or) {
+          if (Record[OpNum] & (1 << bitc::PDI_DISJOINT))
+            cast<PossiblyDisjointInst>(I)->setIsDisjoint(true);
         } else if (isa<FPMathOperator>(I)) {
           FastMathFlags FMF = getDecodedFastMathFlags(Record[OpNum]);
           if (FMF.any())
             I->setFastMathFlags(FMF);
         }
-
       }
       break;
     }

diff  --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 9c21cc69179e555..8239775d04865ab 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -1540,6 +1540,9 @@ static uint64_t getOptimizationFlags(const Value *V) {
   } else if (const auto *PEO = dyn_cast<PossiblyExactOperator>(V)) {
     if (PEO->isExact())
       Flags |= 1 << bitc::PEO_EXACT;
+  } else if (const auto *PDI = dyn_cast<PossiblyDisjointInst>(V)) {
+    if (PDI->isDisjoint())
+      Flags |= 1 << bitc::PDI_DISJOINT;
   } else if (const auto *FPMO = dyn_cast<FPMathOperator>(V)) {
     if (FPMO->hasAllowReassoc())
       Flags |= bitc::AllowReassoc;

diff  --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index c37adb423397059..601b636f306adeb 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1355,6 +1355,10 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
                dyn_cast<PossiblyExactOperator>(U)) {
     if (Div->isExact())
       Out << " exact";
+  } else if (const PossiblyDisjointInst *PDI =
+                 dyn_cast<PossiblyDisjointInst>(U)) {
+    if (PDI->isDisjoint())
+      Out << " disjoint";
   } else if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
     if (GEP->isInBounds())
       Out << " inbounds";

diff  --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index 7449692f05d7bf9..97797e99371d600 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -357,6 +357,10 @@ void Instruction::dropPoisonGeneratingFlags() {
     cast<PossiblyExactOperator>(this)->setIsExact(false);
     break;
 
+  case Instruction::Or:
+    cast<PossiblyDisjointInst>(this)->setIsDisjoint(false);
+    break;
+
   case Instruction::GetElementPtr:
     cast<GetElementPtrInst>(this)->setIsInBounds(false);
     break;
@@ -532,6 +536,10 @@ void Instruction::copyIRFlags(const Value *V, bool IncludeWrapFlags) {
     if (isa<PossiblyExactOperator>(this))
       setIsExact(PE->isExact());
 
+  if (auto *SrcPD = dyn_cast<PossiblyDisjointInst>(V))
+    if (auto *DestPD = dyn_cast<PossiblyDisjointInst>(this))
+      DestPD->setIsDisjoint(SrcPD->isDisjoint());
+
   // Copy the fast-math flags.
   if (auto *FP = dyn_cast<FPMathOperator>(V))
     if (isa<FPMathOperator>(this))
@@ -558,6 +566,10 @@ void Instruction::andIRFlags(const Value *V) {
     if (isa<PossiblyExactOperator>(this))
       setIsExact(isExact() && PE->isExact());
 
+  if (auto *SrcPD = dyn_cast<PossiblyDisjointInst>(V))
+    if (auto *DestPD = dyn_cast<PossiblyDisjointInst>(this))
+      DestPD->setIsDisjoint(DestPD->isDisjoint() && SrcPD->isDisjoint());
+
   if (auto *FP = dyn_cast<FPMathOperator>(V)) {
     if (isa<FPMathOperator>(this)) {
       FastMathFlags FM = getFastMathFlags();

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 69b95c0b35c68e1..fa076098d63cde5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -231,8 +231,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     // If either the LHS or the RHS are One, the result is One.
     if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
         SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown,
-                             Depth + 1))
+                             Depth + 1)) {
+      // Disjoint flag may no longer hold.
+      I->dropPoisonGeneratingFlags();
       return I;
+    }
     assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
     assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
 

diff  --git a/llvm/test/Assembler/flags.ll b/llvm/test/Assembler/flags.ll
index 6ab5e1bfb9c4f46..04bddd02f50c814 100644
--- a/llvm/test/Assembler/flags.ll
+++ b/llvm/test/Assembler/flags.ll
@@ -256,3 +256,8 @@ define i64 @test_zext(i32 %a) {
   ret i64 %res
 }
 
+define i64 @test_or(i64 %a, i64 %b) {
+; CHECK: %res = or disjoint i64 %a, %b
+  %res = or disjoint i64 %a, %b
+  ret i64 %res
+}

diff  --git a/llvm/test/Bitcode/compatibility.ll b/llvm/test/Bitcode/compatibility.ll
index 8170f18879aafac..483024a250da027 100644
--- a/llvm/test/Bitcode/compatibility.ll
+++ b/llvm/test/Bitcode/compatibility.ll
@@ -1359,6 +1359,10 @@ define void @instructions.bitwise_binops(i8 %op1, i8 %op2) {
   xor i8 %op1, %op2
   ; CHECK: xor i8 %op1, %op2
 
+  ; disjoint
+  or disjoint i8 %op1, %op2
+  ; CHECK: or disjoint i8 %op1, %op2
+
   ret void
 }
 

diff  --git a/llvm/test/Bitcode/flags.ll b/llvm/test/Bitcode/flags.ll
index a6e368b7e76327f..e3fc827d865d7e2 100644
--- a/llvm/test/Bitcode/flags.ll
+++ b/llvm/test/Bitcode/flags.ll
@@ -18,6 +18,8 @@ second:                                           ; preds = %first
   %z = add i32 %a, 0                              ; <i32> [#uses=0]
   %hh = zext nneg i32 %a to i64
   %ll = zext i32 %s to i64
+  %jj = or disjoint i32 %a, 0
+  %oo = or i32 %a, 0
   unreachable
 
 first:                                            ; preds = %entry
@@ -28,5 +30,7 @@ first:                                            ; preds = %entry
   %zz = add i32 %a, 0                             ; <i32> [#uses=0]
   %kk = zext nneg i32 %a to i64
   %rr = zext i32 %ss to i64
+  %mm = or disjoint i32 %a, 0
+  %nn = or i32 %a, 0
   br label %second
 }

diff  --git a/llvm/test/Transforms/InstCombine/freeze.ll b/llvm/test/Transforms/InstCombine/freeze.ll
index dd9272b4b35f193..da59101d5710cb5 100644
--- a/llvm/test/Transforms/InstCombine/freeze.ll
+++ b/llvm/test/Transforms/InstCombine/freeze.ll
@@ -1127,6 +1127,17 @@ define i32 @freeze_zext_nneg(i8 %x) {
   ret i32 %fr
 }
 
+define i32 @propagate_drop_flags_or(i32 %arg) {
+; CHECK-LABEL: @propagate_drop_flags_or(
+; CHECK-NEXT:    [[ARG_FR:%.*]] = freeze i32 [[ARG:%.*]]
+; CHECK-NEXT:    [[V1:%.*]] = or i32 [[ARG_FR]], 2
+; CHECK-NEXT:    ret i32 [[V1]]
+;
+  %v1 = or disjoint i32 %arg, 2
+  %v1.fr = freeze i32 %v1
+  ret i32 %v1.fr
+}
+
 !0 = !{}
 !1 = !{i64 4}
 !2 = !{i32 0, i32 100}

diff  --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index fd53783a06f9deb..642a0282e5f7c08 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -1576,3 +1576,14 @@ define <4 x i1> @and_or_not_or_logical_vec(<4 x i32> %ap, <4 x i32> %bp) {
   %Z = or <4 x i1> %X, %Y
   ret <4 x i1> %Z
 }
+
+; Make sure SimplifyDemandedBits drops the disjoint flag.
+define i8 @drop_disjoint(i8 %x) {
+; CHECK-LABEL: @drop_disjoint(
+; CHECK-NEXT:    [[B:%.*]] = or i8 [[X:%.*]], 1
+; CHECK-NEXT:    ret i8 [[B]]
+;
+  %a = and i8 %x, -2
+  %b = or disjoint i8 %a, 1
+  ret i8 %b
+}

diff  --git a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
index 08cf6cd5be80cf7..a081eddfc45660c 100644
--- a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
+++ b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
@@ -124,3 +124,33 @@ F:
   %z2 = zext i8 %x to i32
   ret i32 %z2
 }
+
+define i32 @hoist_or_flags_preserve(i1 %C, i32 %x, i32 %y) {
+; CHECK-LABEL: @hoist_or_flags_preserve(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[Z1:%.*]] = or disjoint i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[Z1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %z1 = or disjoint i32 %x, %y
+  ret i32 %z1
+F:
+  %z2 = or disjoint i32 %x, %y
+  ret i32 %z2
+}
+
+define i32 @hoist_or_flags_drop(i1 %C, i32 %x, i32 %y) {
+; CHECK-LABEL: @hoist_or_flags_drop(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[Z1:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[Z1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %z1 = or i32 %x, %y
+  ret i32 %z1
+F:
+  %z2 = or disjoint i32 %x, %y
+  ret i32 %z2
+}


        


More information about the llvm-commits mailing list