[llvm] 63081dc - LoadStoreVectorizer: Match nested adds to prove vectorization is safe

Volkan Keles via llvm-commits llvm-commits at lists.llvm.org
Mon May 18 12:13:17 PDT 2020


Author: Volkan Keles
Date: 2020-05-18T12:13:01-07:00
New Revision: 63081dc6f642a6be61b3ef213f5c6e257f35671c

URL: https://github.com/llvm/llvm-project/commit/63081dc6f642a6be61b3ef213f5c6e257f35671c
DIFF: https://github.com/llvm/llvm-project/commit/63081dc6f642a6be61b3ef213f5c6e257f35671c.diff

LOG: LoadStoreVectorizer: Match nested adds to prove vectorization is safe

If both OpA and OpB are adds with NSW/NUW and with the same LHS operand,
we can guarantee that the transformation is safe if we can prove that OpA
won't overflow when IdxDiff is added to the RHS of OpA.

Review: https://reviews.llvm.org/D79817

Added: 
    llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll

Modified: 
    llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
index 9915e27c17b1..c02b8f889500 100644
--- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -430,20 +430,78 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
 
   // Now we need to prove that adding IdxDiff to ValA won't overflow.
   bool Safe = false;
+  auto CheckFlags = [](Instruction *I, bool Signed) {
+    BinaryOperator *BinOpI = cast<BinaryOperator>(I);
+    return (Signed && BinOpI->hasNoSignedWrap()) ||
+           (!Signed && BinOpI->hasNoUnsignedWrap());
+  };
+
   // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
   // ValA, we're okay.
   if (OpB->getOpcode() == Instruction::Add &&
       isa<ConstantInt>(OpB->getOperand(1)) &&
-      IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue())) {
-    if (Signed)
-      Safe = cast<BinaryOperator>(OpB)->hasNoSignedWrap();
-    else
-      Safe = cast<BinaryOperator>(OpB)->hasNoUnsignedWrap();
+      IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
+      CheckFlags(OpB, Signed))
+    Safe = true;
+
+  // Second attempt: If both OpA and OpB are adds with NSW/NUW and with
+  // the same LHS operand, we can guarantee that the transformation is safe
+  // if we can prove that OpA won't overflow when IdxDiff is added to the RHS
+  // of OpA.
+  // For example:
+  //  %tmp7 = add nsw i32 %tmp2, %v0
+  //  %tmp8 = sext i32 %tmp7 to i64
+  //  ...
+  //  %tmp11 = add nsw i32 %v0, 1
+  //  %tmp12 = add nsw i32 %tmp2, %tmp11
+  //  %tmp13 = sext i32 %tmp12 to i64
+  //
+  //  Both %tmp7 and %tmp12 have the nsw flag and their first operand
+  //  is %tmp2. It's guaranteed that adding 1 to %tmp7 won't overflow
+  //  because %tmp11 adds 1 to %v0 and both %tmp11 and %tmp12 have the
+  //  nsw flag.
+  OpA = dyn_cast<Instruction>(ValA);
+  if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
+      OpB->getOpcode() == Instruction::Add &&
+      OpA->getOperand(0) == OpB->getOperand(0) && CheckFlags(OpA, Signed) &&
+      CheckFlags(OpB, Signed)) {
+    Value *RHSA = OpA->getOperand(1);
+    Value *RHSB = OpB->getOperand(1);
+    Instruction *OpRHSA = dyn_cast<Instruction>(RHSA);
+    Instruction *OpRHSB = dyn_cast<Instruction>(RHSB);
+    // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
+    if (OpRHSB && OpRHSB->getOpcode() == Instruction::Add &&
+        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSB->getOperand(1))) {
+      int64_t CstVal = cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
+      if (OpRHSB->getOperand(0) == RHSA && IdxDiff.getSExtValue() == CstVal)
+        Safe = true;
+    }
+    // Match `x +nsw/nuw (y +nsw/nuw -IdxDiff)` and `x +nsw/nuw y`.
+    if (OpRHSA && OpRHSA->getOpcode() == Instruction::Add &&
+        CheckFlags(OpRHSA, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1))) {
+      int64_t CstVal = cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
+      if (OpRHSA->getOperand(0) == RHSB && IdxDiff.getSExtValue() == -CstVal)
+        Safe = true;
+    }
+    // Match `x +nsw/nuw (y +nsw/nuw c)` and
+    // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
+    if (OpRHSA && OpRHSB && OpRHSA->getOpcode() == Instruction::Add &&
+        OpRHSB->getOpcode() == Instruction::Add && CheckFlags(OpRHSA, Signed) &&
+        CheckFlags(OpRHSB, Signed) && isa<ConstantInt>(OpRHSA->getOperand(1)) &&
+        isa<ConstantInt>(OpRHSB->getOperand(1))) {
+      int64_t CstValA =
+          cast<ConstantInt>(OpRHSA->getOperand(1))->getSExtValue();
+      int64_t CstValB =
+          cast<ConstantInt>(OpRHSB->getOperand(1))->getSExtValue();
+      if (OpRHSA->getOperand(0) == OpRHSB->getOperand(0) &&
+          IdxDiff.getSExtValue() == (CstValB - CstValA))
+        Safe = true;
+    }
   }
 
   unsigned BitWidth = ValA->getType()->getScalarSizeInBits();
 
-  // Second attempt:
+  // Third attempt:
   // If all set bits of IdxDiff or any higher order bit other than the sign bit
   // are known to be zero in ValA, we can add Diff to it while guaranteeing no
   // overflow of any sort.

diff  --git a/llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll b/llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
new file mode 100644
index 000000000000..6ee00ad62976
--- /dev/null
+++ b/llvm/test/Transforms/LoadStoreVectorizer/X86/vectorize-i8-nested-add.ll
@@ -0,0 +1,165 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -o - -S -load-store-vectorizer -dce %s | FileCheck %s
+
+; Make sure LoadStoreVectorizer vectorizes the loads below.
+; In order to prove that the vectorization is safe, it tries to
+; match nested adds and find an expression that adds a constant
+; value to an existing index and the result doesn't overflow.
+
+target triple = "x86_64--"
+
+define void @ld_v4i8_add_nsw(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_nsw(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nsw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
+; CHECK-NEXT:    [[TMP41:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
+; CHECK-NEXT:    [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
+; CHECK-NEXT:    [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP41]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nsw i32 %v0, -1
+  %tmp1 = add nsw i32 %v1, %tmp
+  %tmp2 = sext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add nsw i32 %v1, %v0
+  %tmp6 = sext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nsw i32 %v0, 1
+  %tmp10 = add nsw i32 %v1, %tmp9
+  %tmp11 = sext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nsw i32 %v0, 2
+  %tmp15 = add nsw i32 %v1, %tmp14
+  %tmp16 = sext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}
+
+define void @ld_v4i8_add_nuw(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_nuw(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nuw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add nuw i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, <4 x i8>* [[TMP0]], align 1
+; CHECK-NEXT:    [[TMP41:%.*]] = extractelement <4 x i8> [[TMP1]], i32 0
+; CHECK-NEXT:    [[TMP82:%.*]] = extractelement <4 x i8> [[TMP1]], i32 1
+; CHECK-NEXT:    [[TMP133:%.*]] = extractelement <4 x i8> [[TMP1]], i32 2
+; CHECK-NEXT:    [[TMP184:%.*]] = extractelement <4 x i8> [[TMP1]], i32 3
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP41]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP82]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP133]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP184]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nuw i32 %v0, -1
+  %tmp1 = add nuw i32 %v1, %tmp
+  %tmp2 = zext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add nuw i32 %v1, %v0
+  %tmp6 = zext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nuw i32 %v0, 1
+  %tmp10 = add nuw i32 %v1, %tmp9
+  %tmp11 = zext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nuw i32 %v0, 2
+  %tmp15 = add nuw i32 %v1, %tmp14
+  %tmp16 = zext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}
+
+; Make sure we don't vectorize the loads below because the adds feeding
+; the sext instructions don't have the nsw flag.
+
+define void @ld_v4i8_add_not_safe(i32 %v0, i32 %v1, i8* %src, <4 x i8>* %dst) {
+; CHECK-LABEL: @ld_v4i8_add_not_safe(
+; CHECK-NEXT:  bb:
+; CHECK-NEXT:    [[TMP:%.*]] = add nsw i32 [[V0:%.*]], -1
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[V1:%.*]], [[TMP]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[SRC:%.*]], i64 [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = load i8, i8* [[TMP3]], align 1
+; CHECK-NEXT:    [[TMP5:%.*]] = add i32 [[V1]], [[V0]]
+; CHECK-NEXT:    [[TMP6:%.*]] = sext i32 [[TMP5]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, i8* [[SRC]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = load i8, i8* [[TMP7]], align 1
+; CHECK-NEXT:    [[TMP9:%.*]] = add nsw i32 [[V0]], 1
+; CHECK-NEXT:    [[TMP10:%.*]] = add i32 [[V1]], [[TMP9]]
+; CHECK-NEXT:    [[TMP11:%.*]] = sext i32 [[TMP10]] to i64
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds i8, i8* [[SRC]], i64 [[TMP11]]
+; CHECK-NEXT:    [[TMP13:%.*]] = load i8, i8* [[TMP12]], align 1
+; CHECK-NEXT:    [[TMP14:%.*]] = add nsw i32 [[V0]], 2
+; CHECK-NEXT:    [[TMP15:%.*]] = add i32 [[V1]], [[TMP14]]
+; CHECK-NEXT:    [[TMP16:%.*]] = sext i32 [[TMP15]] to i64
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds i8, i8* [[SRC]], i64 [[TMP16]]
+; CHECK-NEXT:    [[TMP18:%.*]] = load i8, i8* [[TMP17]], align 1
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <4 x i8> undef, i8 [[TMP4]], i32 0
+; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <4 x i8> [[TMP19]], i8 [[TMP8]], i32 1
+; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <4 x i8> [[TMP20]], i8 [[TMP13]], i32 2
+; CHECK-NEXT:    [[TMP22:%.*]] = insertelement <4 x i8> [[TMP21]], i8 [[TMP18]], i32 3
+; CHECK-NEXT:    store <4 x i8> [[TMP22]], <4 x i8>* [[DST:%.*]]
+; CHECK-NEXT:    ret void
+;
+bb:
+  %tmp = add nsw i32 %v0, -1
+  %tmp1 = add i32 %v1, %tmp
+  %tmp2 = sext i32 %tmp1 to i64
+  %tmp3 = getelementptr inbounds i8, i8* %src, i64 %tmp2
+  %tmp4 = load i8, i8* %tmp3, align 1
+  %tmp5 = add i32 %v1, %v0
+  %tmp6 = sext i32 %tmp5 to i64
+  %tmp7 = getelementptr inbounds i8, i8* %src, i64 %tmp6
+  %tmp8 = load i8, i8* %tmp7, align 1
+  %tmp9 = add nsw i32 %v0, 1
+  %tmp10 = add i32 %v1, %tmp9
+  %tmp11 = sext i32 %tmp10 to i64
+  %tmp12 = getelementptr inbounds i8, i8* %src, i64 %tmp11
+  %tmp13 = load i8, i8* %tmp12, align 1
+  %tmp14 = add nsw i32 %v0, 2
+  %tmp15 = add i32 %v1, %tmp14
+  %tmp16 = sext i32 %tmp15 to i64
+  %tmp17 = getelementptr inbounds i8, i8* %src, i64 %tmp16
+  %tmp18 = load i8, i8* %tmp17, align 1
+  %tmp19 = insertelement <4 x i8> undef, i8 %tmp4, i32 0
+  %tmp20 = insertelement <4 x i8> %tmp19, i8 %tmp8, i32 1
+  %tmp21 = insertelement <4 x i8> %tmp20, i8 %tmp13, i32 2
+  %tmp22 = insertelement <4 x i8> %tmp21, i8 %tmp18, i32 3
+  store <4 x i8> %tmp22, <4 x i8>* %dst
+  ret void
+}


        


More information about the llvm-commits mailing list