[llvm] [IR] Add `samesign` flag to icmp instruction (PR #111419)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 7 12:03:06 PDT 2024


https://github.com/elhewaty created https://github.com/llvm/llvm-project/pull/111419

Inspired by https://discourse.llvm.org/t/rfc-signedness-independent-icmps/81423


>From 9671150db840fe3960f12ded101ffce85fad18a2 Mon Sep 17 00:00:00 2001
From: Mohamed Atef <mohamedatef1698 at gmail.com>
Date: Mon, 7 Oct 2024 21:54:04 +0300
Subject: [PATCH] [IR] Add `samesign` flag to icmp instruction

---
 llvm/docs/LangRef.rst                         |  4 +++
 llvm/include/llvm/AsmParser/LLToken.h         |  1 +
 llvm/include/llvm/Bitcode/LLVMBitCodes.h      |  4 +++
 llvm/include/llvm/IR/InstrTypes.h             | 22 ++++++++++++++
 .../Utils/ScalarEvolutionExpander.h           |  1 +
 llvm/lib/AsmParser/LLLexer.cpp                |  1 +
 llvm/lib/AsmParser/LLParser.cpp               | 10 +++++--
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp     |  2 ++
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     |  3 ++
 llvm/lib/IR/AsmWriter.cpp                     |  3 ++
 llvm/lib/IR/Instruction.cpp                   | 12 ++++++++
 .../Utils/ScalarEvolutionExpander.cpp         |  5 ++++
 llvm/test/Assembler/flags.ll                  | 12 ++++++++
 llvm/test/Bitcode/flags.ll                    |  4 +++
 llvm/test/Transforms/InstCombine/freeze.ll    | 11 +++++++
 llvm/test/Transforms/SimplifyCFG/HoistCode.ll | 30 +++++++++++++++++++
 16 files changed, 123 insertions(+), 2 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 3f39d58b322a4f..c5e88b9fb244e3 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -12239,6 +12239,7 @@ Syntax:
 ::
 
       <result> = icmp <cond> <ty> <op1>, <op2>   ; yields i1 or <N x i1>:result
+      <result> = icmp samesign <cond> <ty> <op1>, <op2>   ; yields i1 or <N x i1>:result
 
 Overview:
 """""""""
@@ -12308,6 +12309,9 @@ If the operands are integer vectors, then they are compared element by
 element. The result is an ``i1`` vector with the same number of elements
 as the values being compared. Otherwise, the result is an ``i1``.
 
+If the ``samesign`` keyword is present and the operands are not of the
+same sign, the result is a :ref:`poison value <poisonvalues>`.
+
 Example:
 """"""""
 
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 19029842a572a4..178c911120b4ce 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -114,6 +114,7 @@ enum Kind {
   kw_disjoint,
   kw_inbounds,
   kw_nneg,
+  kw_samesign,
   kw_inrange,
   kw_addrspace,
   kw_section,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index cbd92fd52fc75a..9977f7756c8dfc 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -537,6 +537,10 @@ enum GetElementPtrOptionalFlags {
   GEP_NUW = 2,
 };
 
+/// PossiblySameSignInstOptionalFlags - Flags for serializing
+/// PossiblySameSignInst's SubclassOptionalData contents.
+enum PossiblySameSignInstOptionalFlags { PSSI_SAME_SIGN = 0 };
+
 /// Encoded AtomicOrdering values.
 enum AtomicOrderingCodes {
   ORDERING_NOTATOMIC = 0,
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 4852f64d0977fb..107f296702fddb 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1033,6 +1033,28 @@ class CmpInst : public Instruction {
   }
 };
 
+/// An icmp instruction, which can be marked as "samesign", indicating that the
+/// two operands have the same sign (otherwise the result is poison), so signed
+/// and unsigned predicates are interchangeable, e.g. "slt" can become "ult".
+class PossiblySameSignInst : public CmpInst {
+public:
+  enum { SameSign = (1 << 0) }; // Flag bit within SubclassOptionalData.
+  /// Sets or clears only the samesign bit; other optional-data bits are kept.
+  void setSameSign(bool B) {
+    SubclassOptionalData = (SubclassOptionalData & ~SameSign) | (B * SameSign);
+  }
+  /// Returns true if the samesign bit is set.
+  bool hasSameSign() const { return SubclassOptionalData & SameSign; }
+  /// Only icmp (not fcmp) instructions can carry the samesign flag.
+  static bool classof(const Instruction *I) {
+    return I->getOpcode() == Instruction::ICmp;
+  }
+
+  static bool classof(const Value *V) {
+    return isa<Instruction>(V) && classof(cast<Instruction>(V));
+  }
+};
+
 // FIXME: these are redundant if CmpInst < BinaryOperator
 template <>
 struct OperandTraits<CmpInst> : public FixedNumOperandTraits<CmpInst, 2> {
diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index 0af3efeacd040c..87eb947e4ab304 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -47,6 +47,7 @@ struct PoisonFlags {
   unsigned Exact : 1;
   unsigned Disjoint : 1;
   unsigned NNeg : 1;
+  unsigned SameSign : 1;
   GEPNoWrapFlags GEPNW;
 
   PoisonFlags(const Instruction *I);
diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp
index a3e47da77fe776..cc0c18d98da358 100644
--- a/llvm/lib/AsmParser/LLLexer.cpp
+++ b/llvm/lib/AsmParser/LLLexer.cpp
@@ -571,6 +571,7 @@ lltok::Kind LLLexer::LexIdentifier() {
   KEYWORD(disjoint);
   KEYWORD(inbounds);
   KEYWORD(nneg);
+  KEYWORD(samesign);
   KEYWORD(inrange);
   KEYWORD(addrspace);
   KEYWORD(section);
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index d84521d2e6e10d..d6f3a92b32daea 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -6950,8 +6950,14 @@ int LLParser::parseInstruction(Instruction *&Inst, BasicBlock *BB,
   case lltok::kw_and:
   case lltok::kw_xor:
     return parseLogical(Inst, PFS, KeywordVal);
-  case lltok::kw_icmp:
-    return parseCompare(Inst, PFS, KeywordVal);
+  case lltok::kw_icmp: {
+    bool SameSign = EatIfPresent(lltok::kw_samesign);
+    if (parseCompare(Inst, PFS, KeywordVal))
+      return true;
+    if (SameSign)
+      cast<PossiblySameSignInst>(Inst)->setSameSign(true);
+    return false;
+  }
   case lltok::kw_fcmp: {
     FastMathFlags FMF = EatFastMathFlagsIfPresent();
     int Res = parseCompare(Inst, PFS, KeywordVal);
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 6f997510b03609..6bddde528d81e1 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -5462,6 +5462,8 @@ Error BitcodeReader::parseFunctionBody(Function *F) {
         if (!CmpInst::isIntPredicate(PredVal))
           return error("Invalid icmp predicate");
         I = new ICmpInst(PredVal, LHS, RHS);
+        if (Record[OpNum] & (1 << bitc::PSSI_SAME_SIGN))
+          cast<PossiblySameSignInst>(I)->setSameSign(true);
       }
 
       ResTypeID = getVirtualTypeID(I->getType()->getScalarType());
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index d9086bfebbd2a9..add97e103733dd 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -1714,6 +1714,9 @@ static uint64_t getOptimizationFlags(const Value *V) {
       Flags |= 1 << bitc::GEP_NUSW;
     if (GEP->hasNoUnsignedWrap())
       Flags |= 1 << bitc::GEP_NUW;
+  } else if (const auto *PSSI = dyn_cast<PossiblySameSignInst>(V)) {
+    if (PSSI->hasSameSign())
+      Flags |= 1 << bitc::PSSI_SAME_SIGN;
   }
 
   return Flags;
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index 280e347739cdb6..99dc55d468fc76 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -1433,6 +1433,9 @@ static void WriteOptimizationInfo(raw_ostream &Out, const User *U) {
       Out << " nuw";
     if (TI->hasNoSignedWrap())
       Out << " nsw";
+  } else if (const auto *PSSI = dyn_cast<PossiblySameSignInst>(U)) {
+    if (PSSI->hasSameSign())
+      Out << " samesign";
   }
 }
 
diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp
index b1c2b0200c8269..cf269f9ba56893 100644
--- a/llvm/lib/IR/Instruction.cpp
+++ b/llvm/lib/IR/Instruction.cpp
@@ -441,6 +441,10 @@ void Instruction::dropPoisonGeneratingFlags() {
     cast<TruncInst>(this)->setHasNoUnsignedWrap(false);
     cast<TruncInst>(this)->setHasNoSignedWrap(false);
     break;
+
+  case Instruction::ICmp:
+    cast<PossiblySameSignInst>(this)->setSameSign(false);
+    break;
   }
 
   if (isa<FPMathOperator>(this)) {
@@ -654,6 +658,10 @@ void Instruction::copyIRFlags(const Value *V, bool IncludeWrapFlags) {
   if (auto *NNI = dyn_cast<PossiblyNonNegInst>(V))
     if (isa<PossiblyNonNegInst>(this))
       setNonNeg(NNI->hasNonNeg());
+
+  if (auto *SrcICmp = dyn_cast<PossiblySameSignInst>(V))
+    if (auto *DestICmp = dyn_cast<PossiblySameSignInst>(this))
+      DestICmp->setSameSign(SrcICmp->hasSameSign());
 }
 
 void Instruction::andIRFlags(const Value *V) {
@@ -695,6 +703,10 @@ void Instruction::andIRFlags(const Value *V) {
   if (auto *NNI = dyn_cast<PossiblyNonNegInst>(V))
     if (isa<PossiblyNonNegInst>(this))
       setNonNeg(hasNonNeg() && NNI->hasNonNeg());
+
+  if (auto *SrcICmp = dyn_cast<PossiblySameSignInst>(V))
+    if (auto *DestICmp = dyn_cast<PossiblySameSignInst>(this))
+      DestICmp->setSameSign(DestICmp->hasSameSign() && SrcICmp->hasSameSign());
 }
 
 const char *Instruction::getOpcodeName(unsigned OpCode) {
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 1088547e1f3efe..17878a4a96a125 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -49,6 +49,7 @@ PoisonFlags::PoisonFlags(const Instruction *I) {
   Exact = false;
   Disjoint = false;
   NNeg = false;
+  SameSign = false;
   GEPNW = GEPNoWrapFlags::none();
   if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(I)) {
     NUW = OBO->hasNoUnsignedWrap();
@@ -66,6 +67,8 @@ PoisonFlags::PoisonFlags(const Instruction *I) {
   }
   if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
     GEPNW = GEP->getNoWrapFlags();
+  if (auto *PSSI = dyn_cast<PossiblySameSignInst>(I))
+    SameSign = PSSI->hasSameSign();
 }
 
 void PoisonFlags::apply(Instruction *I) {
@@ -85,6 +88,8 @@ void PoisonFlags::apply(Instruction *I) {
   }
   if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
     GEP->setNoWrapFlags(GEPNW);
+  if (auto *PSSI = dyn_cast<PossiblySameSignInst>(I))
+    PSSI->setSameSign(SameSign);
 }
 
 /// ReuseOrCreateCast - Arrange for there to be a cast of V to Ty at IP,
diff --git a/llvm/test/Assembler/flags.ll b/llvm/test/Assembler/flags.ll
index 84209500d27a5d..acc8874aef4438 100644
--- a/llvm/test/Assembler/flags.ll
+++ b/llvm/test/Assembler/flags.ll
@@ -312,6 +312,18 @@ define <2 x i32> @test_trunc_both_reversed_vector(<2 x i64> %a) {
   ret <2 x i32> %res
 }
 
+define i1 @test_icmp_samesign(i32 %a, i32 %b) {
+  ; CHECK: %res = icmp samesign ult i32 %a, %b
+  %res = icmp samesign ult i32 %a, %b
+  ret i1 %res
+}
+
+define <2 x i1> @test_icmp_samesign2(<2 x i32> %a, <2 x i32> %b) {
+  ; CHECK: %res = icmp samesign ult <2 x i32> %a, %b
+  %res = icmp samesign ult <2 x i32> %a, %b
+  ret <2 x i1> %res
+}
+
 define ptr @gep_nuw(ptr %p, i64 %idx) {
 ; CHECK: %gep = getelementptr nuw i8, ptr %p, i64 %idx
   %gep = getelementptr nuw i8, ptr %p, i64 %idx
diff --git a/llvm/test/Bitcode/flags.ll b/llvm/test/Bitcode/flags.ll
index fd56694ccceb2c..99988c9ba3d3d4 100644
--- a/llvm/test/Bitcode/flags.ll
+++ b/llvm/test/Bitcode/flags.ll
@@ -30,6 +30,8 @@ second:                                           ; preds = %first
   %tsv = trunc nsw <2 x i32> %aa to <2 x i16>
   %tusv = trunc nuw nsw <2 x i32> %aa to <2 x i16>
   %tv = trunc <2 x i32> %aa to <2 x i16>
+  %ii = icmp samesign ult i32 %a, %z
+  %iv = icmp samesign ult <2 x i32> %aa, %aa
   unreachable
 
 first:                                                    ; preds = %entry
@@ -53,5 +55,7 @@ first:                                                    ; preds = %entry
   %ttsv = trunc nsw <2 x i32> %aa to <2 x i16>
   %ttusv = trunc nuw nsw <2 x i32> %aa to <2 x i16>
   %ttv = trunc <2 x i32> %aa to <2 x i16>
+  %icm = icmp samesign ult i32 %a, %zz
+  %icv = icmp samesign ult <2 x i32> %aa, %aa
   br label %second
 }
diff --git a/llvm/test/Transforms/InstCombine/freeze.ll b/llvm/test/Transforms/InstCombine/freeze.ll
index 5fedb1f8575035..23785eba055544 100644
--- a/llvm/test/Transforms/InstCombine/freeze.ll
+++ b/llvm/test/Transforms/InstCombine/freeze.ll
@@ -1182,6 +1182,17 @@ define ptr @propagate_drop_flags_gep_nuw(ptr %p) {
   ret ptr %gep.fr
 }
 
+define i1 @propagate_drop_flags_icmp(i32 %a, i32 %b) {
+; CHECK-LABEL: @propagate_drop_flags_icmp(
+; CHECK-NEXT:    [[A_FR:%.*]] = freeze i32 [[A:%.*]]
+; CHECK-NEXT:    [[RET:%.*]] = icmp ult i32 [[A_FR]], 3
+; CHECK-NEXT:    ret i1 [[RET]]
+;
+  %ret = icmp samesign ult i32 %a, 3
+  %ret.fr = freeze i1 %ret
+  ret i1 %ret.fr
+}
+
 declare i32 @llvm.umax.i32(i32 %a, i32 %b)
 
 define i32 @freeze_call_with_range_attr(i32 %a) {
diff --git a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
index e6a255a4b8f086..fe0b48028a3b62 100644
--- a/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
+++ b/llvm/test/Transforms/SimplifyCFG/HoistCode.ll
@@ -275,3 +275,33 @@ F:
   %gep2 = getelementptr nuw i8, ptr %p, i64 1
   ret ptr %gep2
 }
+
+define i1 @hoist_icmp_flags_preserve(i1 %C, i32 %x, i32 %y) {
+; CHECK-LABEL: @hoist_icmp_flags_preserve(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[Z1:%.*]] = icmp samesign ult i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[Z1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %z1 = icmp samesign ult i32 %x, %y
+  ret i1 %z1
+F:
+  %z2 = icmp samesign ult i32 %x, %y
+  ret i1 %z2
+}
+
+define i1 @hoist_icmp_flags_drop(i1 %C, i32 %x, i32 %y) {
+; CHECK-LABEL: @hoist_icmp_flags_drop(
+; CHECK-NEXT:  common.ret:
+; CHECK-NEXT:    [[Z1:%.*]] = icmp ult i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[Z1]]
+;
+  br i1 %C, label %T, label %F
+T:
+  %z1 = icmp ult i32 %x, %y
+  ret i1 %z1
+F:
+  %z2 = icmp samesign ult i32 %x, %y
+  ret i1 %z2
+}



More information about the llvm-commits mailing list