[llvm] r298631 - Model ashr(shl(x, n), m) as mul(x, 2^(n-m)) when n > m

Zhaoshi Zheng via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 23 11:06:09 PDT 2017


Author: zzheng
Date: Thu Mar 23 13:06:09 2017
New Revision: 298631

URL: http://llvm.org/viewvc/llvm-project?rev=298631&view=rev
Log:
Model ashr(shl(x, n), m) as mul(x, 2^(n-m)) when n > m

Given the case below:

  %y = shl %x, n
  %z = ashr %y, m

When n = m, SCEV models it as sext(trunc(x)). This patch handles the
case where n > m by using sext(mul(trunc(x), 2^(n-m))) as the SCEV
expression.

Added:
    llvm/trunk/test/Analysis/ScalarEvolution/sext-mul.ll
    llvm/trunk/test/Analysis/ScalarEvolution/sext-zero.ll
Modified:
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=298631&r1=298630&r2=298631&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Thu Mar 23 13:06:09 2017
@@ -5356,28 +5356,55 @@ const SCEV *ScalarEvolution::createSCEV(
     break;
 
     case Instruction::AShr:
-      // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS))
-        if (Operator *L = dyn_cast<Operator>(BO->LHS))
-          if (L->getOpcode() == Instruction::Shl &&
-              L->getOperand(1) == BO->RHS) {
-            uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType());
-
-            // If the shift count is not less than the bitwidth, the result of
-            // the shift is undefined. Don't try to analyze it, because the
-            // resolution chosen here may differ from the resolution chosen in
-            // other parts of the compiler.
-            if (CI->getValue().uge(BitWidth))
-              break;
-
-            uint64_t Amt = BitWidth - CI->getZExtValue();
-            if (Amt == BitWidth)
-              return getSCEV(L->getOperand(0)); // shift by zero --> noop
+      // AShr X, C, where C is a constant.
+      ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
+      if (!CI)
+        break;
+
+      Type *OuterTy = BO->LHS->getType();
+      uint64_t BitWidth = getTypeSizeInBits(OuterTy);
+      // If the shift count is not less than the bitwidth, the result of
+      // the shift is undefined. Don't try to analyze it, because the
+      // resolution chosen here may differ from the resolution chosen in
+      // other parts of the compiler.
+      if (CI->getValue().uge(BitWidth))
+        break;
+
+      if (CI->isNullValue())
+        return getSCEV(BO->LHS); // shift by zero --> noop
+
+      uint64_t AShrAmt = CI->getZExtValue();
+      Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
+
+      Operator *L = dyn_cast<Operator>(BO->LHS);
+      if (L && L->getOpcode() == Instruction::Shl) {
+        // X = Shl A, n
+        // Y = AShr X, m
+        // Both n and m are constant.
+
+        const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
+        if (L->getOperand(1) == BO->RHS)
+          // For a two-shift sext-inreg, i.e. n = m,
+          // use sext(trunc(x)) as the SCEV expression.
+          return getSignExtendExpr(
+              getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
+
+        ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
+        if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
+          uint64_t ShlAmt = ShlAmtCI->getZExtValue();
+          if (ShlAmt > AShrAmt) {
+            // When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
+            // expression. We already checked that ShlAmt < BitWidth, so
+            // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
+            // ShlAmt - AShrAmt < BitWidth - AShrAmt.
+            APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
+                                            ShlAmt - AShrAmt);
             return getSignExtendExpr(
-                getTruncateExpr(getSCEV(L->getOperand(0)),
-                                IntegerType::get(getContext(), Amt)),
-                BO->LHS->getType());
+                getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
+                getConstant(Mul)), OuterTy);
           }
+        }
+      }
       break;
     }
   }

Added: llvm/trunk/test/Analysis/ScalarEvolution/sext-mul.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Analysis/ScalarEvolution/sext-mul.ll?rev=298631&view=auto
==============================================================================
--- llvm/trunk/test/Analysis/ScalarEvolution/sext-mul.ll (added)
+++ llvm/trunk/test/Analysis/ScalarEvolution/sext-mul.ll Thu Mar 23 13:06:09 2017
@@ -0,0 +1,89 @@
+; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s
+
+; CHECK: %tmp9 = shl i64 %tmp8, 33
+; CHECK-NEXT: --> {{.*}} Exits: (-8589934592 + (8589934592 * (zext i32 %arg2 to i64)))
+; CHECK: %tmp10 = ashr exact i64 %tmp9, 32
+; CHECK-NEXT: --> {{.*}} Exits: (sext i32 (-2 + (2 * %arg2)) to i64)
+; CHECK: %tmp11 = getelementptr inbounds i32, i32* %arg, i64 %tmp10
+; CHECK-NEXT: --> {{.*}} Exits: ((4 * (sext i32 (-2 + (2 * %arg2)) to i64)) + %arg)
+; CHECK:  %tmp14 = or i64 %tmp10, 1
+; CHECK-NEXT: --> {{.*}} Exits: (1 + (sext i32 (-2 + (2 * %arg2)) to i64))<nsw>
+; CHECK: %tmp15 = getelementptr inbounds i32, i32* %arg, i64 %tmp14
+; CHECK-NEXT: --> {{.*}} Exits: (4 + (4 * (sext i32 (-2 + (2 * %arg2)) to i64)) + %arg)
+; CHECK:Loop %bb7: backedge-taken count is (-1 + (zext i32 %arg2 to i64))<nsw>
+; CHECK-NEXT:Loop %bb7: max backedge-taken count is -1
+; CHECK-NEXT:Loop %bb7: Predicated backedge-taken count is (-1 + (zext i32 %arg2 to i64))<nsw>
+
+define void @foo(i32* nocapture %arg, i32 %arg1, i32 %arg2) {
+bb:
+  %tmp = icmp sgt i32 %arg2, 0
+  br i1 %tmp, label %bb3, label %bb6
+
+bb3:                                              ; preds = %bb
+  %tmp4 = zext i32 %arg2 to i64
+  br label %bb7
+
+bb5:                                              ; preds = %bb7
+  br label %bb6
+
+bb6:                                              ; preds = %bb5, %bb
+  ret void
+
+bb7:                                              ; preds = %bb7, %bb3
+  %tmp8 = phi i64 [ %tmp18, %bb7 ], [ 0, %bb3 ]
+  %tmp9 = shl i64 %tmp8, 33
+  %tmp10 = ashr exact i64 %tmp9, 32
+  %tmp11 = getelementptr inbounds i32, i32* %arg, i64 %tmp10
+  %tmp12 = load i32, i32* %tmp11, align 4
+  %tmp13 = sub nsw i32 %tmp12, %arg1
+  store i32 %tmp13, i32* %tmp11, align 4
+  %tmp14 = or i64 %tmp10, 1
+  %tmp15 = getelementptr inbounds i32, i32* %arg, i64 %tmp14
+  %tmp16 = load i32, i32* %tmp15, align 4
+  %tmp17 = mul nsw i32 %tmp16, %arg1
+  store i32 %tmp17, i32* %tmp15, align 4
+  %tmp18 = add nuw nsw i64 %tmp8, 1
+  %tmp19 = icmp eq i64 %tmp18, %tmp4
+  br i1 %tmp19, label %bb5, label %bb7
+}
+
+; CHECK: %t10 = ashr exact i128 %t9, 1
+; CHECK-NEXT: --> {{.*}} Exits: (sext i127 (-633825300114114700748351602688 + (633825300114114700748351602688 * (zext i32 %arg5 to i127))) to i128)
+; CHECK: %t14 = or i128 %t10, 1
+; CHECK-NEXT: --> {{.*}} Exits: (1 + (sext i127 (-633825300114114700748351602688 + (633825300114114700748351602688 * (zext i32 %arg5 to i127))) to i128))<nsw>
+; CHECK: Loop %bb7: backedge-taken count is (-1 + (zext i32 %arg5 to i128))<nsw>
+; CHECK-NEXT: Loop %bb7: max backedge-taken count is -1
+; CHECK-NEXT: Loop %bb7: Predicated backedge-taken count is (-1 + (zext i32 %arg5 to i128))<nsw>
+
+define void @goo(i32* nocapture %arg3, i32 %arg4, i32 %arg5) {
+bb:
+  %t = icmp sgt i32 %arg5, 0
+  br i1 %t, label %bb3, label %bb6
+
+bb3:                                              ; preds = %bb
+  %t4 = zext i32 %arg5 to i128
+  br label %bb7
+
+bb5:                                              ; preds = %bb7
+  br label %bb6
+
+bb6:                                              ; preds = %bb5, %bb
+  ret void
+
+bb7:                                              ; preds = %bb7, %bb3
+  %t8 = phi i128 [ %t18, %bb7 ], [ 0, %bb3 ]
+  %t9 = shl i128 %t8, 100
+  %t10 = ashr exact i128 %t9, 1
+  %t11 = getelementptr inbounds i32, i32* %arg3, i128 %t10
+  %t12 = load i32, i32* %t11, align 4
+  %t13 = sub nsw i32 %t12, %arg4
+  store i32 %t13, i32* %t11, align 4
+  %t14 = or i128 %t10, 1
+  %t15 = getelementptr inbounds i32, i32* %arg3, i128 %t14
+  %t16 = load i32, i32* %t15, align 4
+  %t17 = mul nsw i32 %t16, %arg4
+  store i32 %t17, i32* %t15, align 4
+  %t18 = add nuw nsw i128 %t8, 1
+  %t19 = icmp eq i128 %t18, %t4
+  br i1 %t19, label %bb5, label %bb7
+}

Added: llvm/trunk/test/Analysis/ScalarEvolution/sext-zero.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Analysis/ScalarEvolution/sext-zero.ll?rev=298631&view=auto
==============================================================================
--- llvm/trunk/test/Analysis/ScalarEvolution/sext-zero.ll (added)
+++ llvm/trunk/test/Analysis/ScalarEvolution/sext-zero.ll Thu Mar 23 13:06:09 2017
@@ -0,0 +1,39 @@
+; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s
+
+; CHECK:  %tmp9 = shl i64 %tmp8, 33
+; CHECK-NEXT:  -->  {{.*}} Exits: (-8589934592 + (8589934592 * (zext i32 %arg2 to i64)))
+; CHECK-NEXT:  %tmp10 = ashr exact i64 %tmp9, 0
+; CHECK-NEXT:  -->  {{.*}} Exits: (-8589934592 + (8589934592 * (zext i32 %arg2 to i64)))
+
+define void @foo(i32* nocapture %arg, i32 %arg1, i32 %arg2) {
+bb:
+  %tmp = icmp sgt i32 %arg2, 0
+  br i1 %tmp, label %bb3, label %bb6
+
+bb3:                                              ; preds = %bb
+  %tmp4 = zext i32 %arg2 to i64
+  br label %bb7
+
+bb5:                                              ; preds = %bb7
+  br label %bb6
+
+bb6:                                              ; preds = %bb5, %bb
+  ret void
+
+bb7:                                              ; preds = %bb7, %bb3
+  %tmp8 = phi i64 [ %tmp18, %bb7 ], [ 0, %bb3 ]
+  %tmp9 = shl i64 %tmp8, 33
+  %tmp10 = ashr exact i64 %tmp9, 0
+  %tmp11 = getelementptr inbounds i32, i32* %arg, i64 %tmp10
+  %tmp12 = load i32, i32* %tmp11, align 4
+  %tmp13 = sub nsw i32 %tmp12, %arg1
+  store i32 %tmp13, i32* %tmp11, align 4
+  %tmp14 = or i64 %tmp10, 1
+  %tmp15 = getelementptr inbounds i32, i32* %arg, i64 %tmp14
+  %tmp16 = load i32, i32* %tmp15, align 4
+  %tmp17 = mul nsw i32 %tmp16, %arg1
+  store i32 %tmp17, i32* %tmp15, align 4
+  %tmp18 = add nuw nsw i64 %tmp8, 1
+  %tmp19 = icmp eq i64 %tmp18, %tmp4
+  br i1 %tmp19, label %bb5, label %bb7
+}




More information about the llvm-commits mailing list