[llvm] [InstCombine] Copy flags of extractelement for extelt -> icmp combine (PR #86366)

Marc Auberer via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 22 18:42:40 PDT 2024


https://github.com/marcauberer updated https://github.com/llvm/llvm-project/pull/86366

>From 9c8eb1e4da8df9069878d3ac968d7ccbed8f7a50 Mon Sep 17 00:00:00 2001
From: Marc Auberer <marc.auberer at chillibits.com>
Date: Fri, 22 Mar 2024 22:42:53 +0100
Subject: [PATCH] [InstCombine] Copy flags of extractelement for extelt -> icmp
 combine

---
 llvm/include/llvm/IR/InstrTypes.h                  | 11 +++++++++++
 llvm/lib/IR/Instructions.cpp                       | 10 ++++++++++
 .../InstCombine/InstCombineVectorOps.cpp           |  4 +++-
 llvm/test/Transforms/InstCombine/scalarization.ll  | 14 ++++++++++++++
 4 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index e8c2cba8418dc8..d2b33fe9c651a1 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1058,6 +1058,17 @@ class CmpInst : public Instruction {
   static CmpInst *Create(OtherOps Op, Predicate predicate, Value *S1,
                          Value *S2, const Twine &Name, BasicBlock *InsertAtEnd);
 
+  /// Construct a compare instruction, given the opcode, the predicate,
+  /// the two operands and the instruction to copy the flags from. Optionally
+  /// (if InsertBefore is specified) insert the instruction into a BasicBlock
+  /// right before the specified instruction. The specified Instruction is
+  /// allowed to be a dereferenced end iterator. Returns the new CmpInst.
+  static CmpInst *CreateWithCopiedFlags(OtherOps Op, Predicate Pred, Value *S1,
+                                        Value *S2,
+                                        const Instruction *FlagsSource,
+                                        const Twine &Name = "",
+                                        Instruction *InsertBefore = nullptr);
+
   /// Get the opcode casted to the right type
   OtherOps getOpcode() const {
     return static_cast<OtherOps>(Instruction::getOpcode());
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 494d50f89e374c..8ca211e6e2e72f 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -4623,6 +4623,16 @@ CmpInst::Create(OtherOps Op, Predicate predicate, Value *S1, Value *S2,
                       S1, S2, Name);
 }
 
+CmpInst *CmpInst::CreateWithCopiedFlags(OtherOps Op, Predicate Pred, Value *S1,
+                                        Value *S2,
+                                        const Instruction *FlagsSource,
+                                        const Twine &Name,
+                                        Instruction *InsertBefore) {
+  CmpInst *Inst = Create(Op, Pred, S1, S2, Name, InsertBefore);
+  Inst->copyIRFlags(FlagsSource);
+  return Inst;
+}
+
 void CmpInst::swapOperands() {
   if (ICmpInst *IC = dyn_cast<ICmpInst>(this))
     IC->swapOperands();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index c7f4fb17648c87..c2bece88352e02 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -487,7 +487,9 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
     // extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)
     Value *E0 = Builder.CreateExtractElement(X, Index);
     Value *E1 = Builder.CreateExtractElement(Y, Index);
-    return CmpInst::Create(cast<CmpInst>(SrcVec)->getOpcode(), Pred, E0, E1);
+    Instruction *SrcInst = cast<Instruction>(SrcVec);
+    return CmpInst::CreateWithCopiedFlags(cast<CmpInst>(SrcVec)->getOpcode(),
+                                          Pred, E0, E1, SrcInst);
   }
 
   if (auto *I = dyn_cast<Instruction>(SrcVec)) {
diff --git a/llvm/test/Transforms/InstCombine/scalarization.ll b/llvm/test/Transforms/InstCombine/scalarization.ll
index fe6dc526bd50ee..5ab960ece54d9d 100644
--- a/llvm/test/Transforms/InstCombine/scalarization.ll
+++ b/llvm/test/Transforms/InstCombine/scalarization.ll
@@ -341,6 +341,20 @@ define i1 @extractelt_vector_fcmp_constrhs_dynidx(<2 x float> %arg, i32 %idx) {
   ret i1 %ext
 }
 
+define i1 @extractelt_vector_fcmp_copy_flags(<4 x float> %x, <4 x i1> %y) {
+; CHECK-LABEL: @extractelt_vector_fcmp_copy_flags(
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> [[X:%.*]], i64 2
+; CHECK-NEXT:    [[TMP2:%.*]] = fcmp nsz arcp oeq float [[TMP1]], 0.000000e+00
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x i1> [[Y:%.*]], i64 2
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[TMP2]], [[TMP3]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %cmp = fcmp nsz arcp oeq <4 x float> %x, zeroinitializer
+  %and = and <4 x i1> %cmp, %y
+  %r = extractelement <4 x i1> %and, i32 2
+  ret i1 %r
+}
+
 define i1 @extractelt_vector_fcmp_not_cheap_to_scalarize_multi_use(<2 x float> %arg0, <2 x float> %arg1, <2 x float> %arg2, i32 %idx) {
 ;
 ; CHECK-LABEL: @extractelt_vector_fcmp_not_cheap_to_scalarize_multi_use(



More information about the llvm-commits mailing list