[llvm] 57a8ea8 - [InstCombine] Avoid infinite loop in insert/extract combine

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 14 05:59:00 PDT 2023


Author: Nikita Popov
Date: 2023-06-14T14:58:49+02:00
New Revision: 57a8ea85538503a35d1e04fd8c8ba32aa2ba3f2a

URL: https://github.com/llvm/llvm-project/commit/57a8ea85538503a35d1e04fd8c8ba32aa2ba3f2a
DIFF: https://github.com/llvm/llvm-project/commit/57a8ea85538503a35d1e04fd8c8ba32aa2ba3f2a.diff

LOG: [InstCombine] Avoid infinite loop in insert/extract combine

Fix the infinite loop reported at https://reviews.llvm.org/D151807#4420467.

collectShuffleElements() widens vectors and replaces extracts via
replaceExtractElements(), so that the next call of
collectShuffleElements() can fold. However, another fold may run
first and break the expected sequence again, which can lead to an
infinite loop. To ensure this does not happen, directly rerun the
collectShuffleElements() fold if we have adjusted extracts.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
    llvm/test/Transforms/InstCombine/insert-extract-shuffle-inseltpoison.ll
    llvm/test/Transforms/InstCombine/insert-extract-shuffle.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 3a863f29e3c73..f20ad388fbc00 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -681,7 +681,7 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS,
 /// If we have insertion into a vector that is wider than the vector that we
 /// are extracting from, try to widen the source vector to allow a single
 /// shufflevector to replace one or more insert/extract pairs.
-static void replaceExtractElements(InsertElementInst *InsElt,
+static bool replaceExtractElements(InsertElementInst *InsElt,
                                    ExtractElementInst *ExtElt,
                                    InstCombinerImpl &IC) {
   auto *InsVecType = cast<FixedVectorType>(InsElt->getType());
@@ -692,7 +692,7 @@ static void replaceExtractElements(InsertElementInst *InsElt,
   // The inserted-to vector must be wider than the extracted-from vector.
   if (InsVecType->getElementType() != ExtVecType->getElementType() ||
       NumExtElts >= NumInsElts)
-    return;
+    return false;
 
   // Create a shuffle mask to widen the extended-from vector using poison
   // values. The mask selects all of the values of the original vector followed
@@ -720,7 +720,7 @@ static void replaceExtractElements(InsertElementInst *InsElt,
   // that will delete our widening shuffle. This would trigger another attempt
   // here to create that shuffle, and we spin forever.
   if (InsertionBlock != InsElt->getParent())
-    return;
+    return false;
 
   // TODO: This restriction matches the check in visitInsertElementInst() and
   // prevents an infinite loop caused by not turning the extract/insert pair
@@ -728,7 +728,7 @@ static void replaceExtractElements(InsertElementInst *InsElt,
   // folds for shufflevectors because we're afraid to generate shuffle masks
   // that the backend can't handle.
   if (InsElt->hasOneUse() && isa<InsertElementInst>(InsElt->user_back()))
-    return;
+    return false;
 
   auto *WideVec = new ShuffleVectorInst(ExtVecOp, ExtendMask);
 
@@ -754,6 +754,8 @@ static void replaceExtractElements(InsertElementInst *InsElt,
     // extracts directly, because they may still be used by the calling code.
     IC.addToWorklist(OldExt);
   }
+
+  return true;
 }
 
 /// We are building a shuffle to create V, which is a sequence of insertelement,
@@ -768,7 +770,7 @@ using ShuffleOps = std::pair<Value *, Value *>;
 
 static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask,
                                          Value *PermittedRHS,
-                                         InstCombinerImpl &IC) {
+                                         InstCombinerImpl &IC, bool &Rerun) {
   assert(V->getType()->isVectorTy() && "Invalid shuffle!");
   unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements();
 
@@ -799,13 +801,14 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask,
         // otherwise we'd end up with a shuffle of three inputs.
         if (EI->getOperand(0) == PermittedRHS || PermittedRHS == nullptr) {
           Value *RHS = EI->getOperand(0);
-          ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, IC);
+          ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, IC, Rerun);
           assert(LR.second == nullptr || LR.second == RHS);
 
           if (LR.first->getType() != RHS->getType()) {
             // Although we are giving up for now, see if we can create extracts
             // that match the inserts for another round of combining.
-            replaceExtractElements(IEI, EI, IC);
+            if (replaceExtractElements(IEI, EI, IC))
+              Rerun = true;
 
             // We tried our best, but we can't find anything compatible with RHS
             // further up the chain. Return a trivial shuffle.
@@ -1685,16 +1688,22 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
 
     // Try to form a shuffle from a chain of extract-insert ops.
     if (isShuffleRootCandidate(IE)) {
-      SmallVector<int, 16> Mask;
-      ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr, *this);
-
-      // The proposed shuffle may be trivial, in which case we shouldn't
-      // perform the combine.
-      if (LR.first != &IE && LR.second != &IE) {
-        // We now have a shuffle of LHS, RHS, Mask.
-        if (LR.second == nullptr)
-          LR.second = UndefValue::get(LR.first->getType());
-        return new ShuffleVectorInst(LR.first, LR.second, Mask);
+      bool Rerun = true;
+      while (Rerun) {
+        Rerun = false;
+
+        SmallVector<int, 16> Mask;
+        ShuffleOps LR =
+            collectShuffleElements(&IE, Mask, nullptr, *this, Rerun);
+
+        // The proposed shuffle may be trivial, in which case we shouldn't
+        // perform the combine.
+        if (LR.first != &IE && LR.second != &IE) {
+          // We now have a shuffle of LHS, RHS, Mask.
+          if (LR.second == nullptr)
+            LR.second = UndefValue::get(LR.first->getType());
+          return new ShuffleVectorInst(LR.first, LR.second, Mask);
+        }
       }
     }
   }

diff  --git a/llvm/test/Transforms/InstCombine/insert-extract-shuffle-inseltpoison.ll b/llvm/test/Transforms/InstCombine/insert-extract-shuffle-inseltpoison.ll
index b67487e417f95..3a3097c01850a 100644
--- a/llvm/test/Transforms/InstCombine/insert-extract-shuffle-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/insert-extract-shuffle-inseltpoison.ll
@@ -267,7 +267,7 @@ define <4 x i32> @extractelt_insertion(<2 x i32> %x, i32 %y) {
 ; CHECK-LABEL: @extractelt_insertion(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <2 x i32> [[X:%.*]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[B:%.*]] = shufflevector <4 x i32> <i32 0, i32 0, i32 0, i32 poison>, <4 x i32> [[TMP0]], <4 x i32> <i32 0, i32 1, i32 2, i32 5>
+; CHECK-NEXT:    [[B:%.*]] = shufflevector <4 x i32> <i32 0, i32 poison, i32 poison, i32 poison>, <4 x i32> [[TMP0]], <4 x i32> <i32 0, i32 0, i32 0, i32 5>
 ; CHECK-NEXT:    [[C:%.*]] = add i32 [[Y:%.*]], 3
 ; CHECK-NEXT:    [[D:%.*]] = extractelement <4 x i32> [[TMP0]], i32 [[C]]
 ; CHECK-NEXT:    [[E:%.*]] = icmp eq i32 [[D]], 0

diff  --git a/llvm/test/Transforms/InstCombine/insert-extract-shuffle.ll b/llvm/test/Transforms/InstCombine/insert-extract-shuffle.ll
index 6afd737ba4357..93faebdb63a80 100644
--- a/llvm/test/Transforms/InstCombine/insert-extract-shuffle.ll
+++ b/llvm/test/Transforms/InstCombine/insert-extract-shuffle.ll
@@ -267,7 +267,7 @@ define <4 x i32> @extractelt_insertion(<2 x i32> %x, i32 %y) {
 ; CHECK-LABEL: @extractelt_insertion(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = shufflevector <2 x i32> [[X:%.*]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[B:%.*]] = shufflevector <4 x i32> <i32 0, i32 0, i32 0, i32 poison>, <4 x i32> [[TMP0]], <4 x i32> <i32 0, i32 1, i32 2, i32 5>
+; CHECK-NEXT:    [[B:%.*]] = shufflevector <4 x i32> <i32 0, i32 poison, i32 poison, i32 poison>, <4 x i32> [[TMP0]], <4 x i32> <i32 0, i32 0, i32 0, i32 5>
 ; CHECK-NEXT:    [[C:%.*]] = add i32 [[Y:%.*]], 3
 ; CHECK-NEXT:    [[D:%.*]] = extractelement <4 x i32> [[TMP0]], i32 [[C]]
 ; CHECK-NEXT:    [[E:%.*]] = icmp eq i32 [[D]], 0
@@ -789,3 +789,18 @@ define <4 x float> @splat_constant(<4 x float> %x) {
   %r = fadd <4 x float> %ins3, %splat3
   ret <4 x float> %r
 }
+
+define <4 x i32> @infloop_D151807(<4 x float> %arg) {
+; CHECK-LABEL: @infloop_D151807(
+; CHECK-NEXT:    [[I:%.*]] = shufflevector <4 x float> [[ARG:%.*]], <4 x float> poison, <2 x i32> <i32 2, i32 poison>
+; CHECK-NEXT:    [[I1:%.*]] = bitcast <2 x float> [[I]] to <2 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x i32> [[I1]], <2 x i32> poison, <4 x i32> <i32 0, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT:    [[I4:%.*]] = shufflevector <4 x i32> <i32 0, i32 poison, i32 poison, i32 poison>, <4 x i32> [[TMP1]], <4 x i32> <i32 4, i32 0, i32 0, i32 0>
+; CHECK-NEXT:    ret <4 x i32> [[I4]]
+;
+  %i = shufflevector <4 x float> %arg, <4 x float> poison, <2 x i32> <i32 2, i32 poison>
+  %i1 = bitcast <2 x float> %i to <2 x i32>
+  %i3 = extractelement <2 x i32> %i1, i64 0
+  %i4 = insertelement <4 x i32> zeroinitializer, i32 %i3, i64 0
+  ret <4 x i32> %i4
+}


        


More information about the llvm-commits mailing list