[llvm] 5a9a02f - [SCEV] Compute SCEV for ashr(add(shl(x, n), c), m) instr triplet

Vedant Paranjape via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 24 22:47:41 PDT 2023


Author: Vedant Paranjape
Date: 2023-08-25T05:42:08Z
New Revision: 5a9a02f67b771fb2edcf0662146fb023892d8ec7

URL: https://github.com/llvm/llvm-project/commit/5a9a02f67b771fb2edcf0662146fb023892d8ec7
DIFF: https://github.com/llvm/llvm-project/commit/5a9a02f67b771fb2edcf0662146fb023892d8ec7.diff

LOG: [SCEV] Compute SCEV for ashr(add(shl(x, n), c), m) instr triplet

%x = shl i64 %w, n
%y = add i64 %x, c
%z = ashr i64 %y, m

This instruction triplet appears frequently in generated LLVM IR, but
SCEV is unable to compute a SCEV expression for the AShr instruction in
this case.

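This patch models the two cases of the above instruction pattern using
the following expression:

=> sext(add(mul(trunc(w), 2^(n-m)), c >> m))

where trunc narrows w to an integer of width (bitwidth - m), and sext
extends the result back to the original type.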

1) When n = m, the expression reduces to sext(add(trunc(w), c >> n)),
   since n - m = 0 and multiplying by 2^0 is the identity.

2) When n > m, the expression is used as given above.

The case n < m is not modeled.

It also adds several lit tests to verify that SCEV is able to compute
the expression.

$ opt sext-add-inreg-loop.ll -passes="print<scalar-evolution>"

Comparison of the SCEV analysis output for the ashr instruction:

* SCEV of ashr before change
----------------------------
%idxprom = ashr exact i64 %sext, 32
  -->  %idxprom U: [-2147483648,2147483648) S: [-2147483648,2147483648)
  Exits: 8                LoopDispositions: { %for.body: Variant }

* SCEV of ashr after change
---------------------------
%idxprom = ashr exact i64 %sext, 32
  -->  {0,+,1}<nuw><nsw><%for.body> U: [0,9) S: [0,9)
  Exits: 8                LoopDispositions: { %for.body: Computable }

Before this change, the LoopDisposition of this SCEV was LoopVariant.
With the new modeling, SCEV can compute an add recurrence for the
instruction, so the LoopDisposition becomes LoopComputable.

Differential Revision: https://reviews.llvm.org/D152278

Added: 
    llvm/test/Analysis/ScalarEvolution/sext-add-inreg-loop.ll
    llvm/test/Analysis/ScalarEvolution/sext-add-inreg-unequal.ll
    llvm/test/Analysis/ScalarEvolution/sext-add-inreg.ll

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 8a9cbfe79b2664..7cc964e03c2479 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7854,7 +7854,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       }
       break;
 
-    case Instruction::AShr: {
+    case Instruction::AShr:
       // AShr X, C, where C is a constant.
       ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
       if (!CI)
@@ -7876,37 +7876,69 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
 
       Operator *L = dyn_cast<Operator>(BO->LHS);
-      if (L && L->getOpcode() == Instruction::Shl) {
+      const SCEV *AddTruncateExpr = nullptr;
+      ConstantInt *ShlAmtCI = nullptr;
+      const SCEV *AddConstant = nullptr;
+
+      if (L && L->getOpcode() == Instruction::Add) {
+        // X = Shl A, n
+        // Y = Add X, c
+        // Z = AShr Y, m
+        // n, c and m are constants.
+
+        Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
+        ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
+        if (LShift && LShift->getOpcode() == Instruction::Shl) {
+          if (AddOperandCI) {
+            const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
+            ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
+            // since we truncate to TruncTy, the AddConstant should be of the
+            // same type, so create a new Constant with type same as TruncTy.
+            // Also, the Add constant should be shifted right by AShr amount.
+            APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
+            AddConstant = getConstant(TruncTy, AddOperand.getZExtValue(),
+                                      AddOperand.isSignBitSet());
+            // we model the expression as sext(add(trunc(A), c << n)), since the
+            // sext(trunc) part is already handled below, we create a
+            // AddExpr(TruncExp) which will be used later.
+            AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
+          }
+        }
+      } else if (L && L->getOpcode() == Instruction::Shl) {
         // X = Shl A, n
         // Y = AShr X, m
         // Both n and m are constant.
 
         const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
-        if (L->getOperand(1) == BO->RHS)
-          // For a two-shift sext-inreg, i.e. n = m,
-          // use sext(trunc(x)) as the SCEV expression.
-          return getSignExtendExpr(
-              getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
-
-        ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
-        if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
-          uint64_t ShlAmt = ShlAmtCI->getZExtValue();
-          if (ShlAmt > AShrAmt) {
-            // When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
-            // expression. We already checked that ShlAmt < BitWidth, so
-            // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
-            // ShlAmt - AShrAmt < Amt.
-            APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
-                                            ShlAmt - AShrAmt);
-            return getSignExtendExpr(
-                getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
-                getConstant(Mul)), OuterTy);
-          }
+        ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
+        AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
+      }
+
+      if (AddTruncateExpr && ShlAmtCI) {
+        // We can merge the two given cases into a single SCEV statement,
+        // incase n = m, the mul expression will be 2^0, so it gets resolved to
+        // a simpler case. The following code handles the two cases:
+        //
+        // 1) For a two-shift sext-inreg, i.e. n = m,
+        //    use sext(trunc(x)) as the SCEV expression.
+        //
+        // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
+        //    expression. We already checked that ShlAmt < BitWidth, so
+        //    the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
+        //    ShlAmt - AShrAmt < Amt.
+        uint64_t ShlAmt = ShlAmtCI->getZExtValue();
+        if (ShlAmtCI->getValue().ult(BitWidth) && ShlAmt >= AShrAmt) {
+          APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt, ShlAmt - AShrAmt);
+          const SCEV *CompositeExpr =
+              getMulExpr(AddTruncateExpr, getConstant(Mul));
+          if (L->getOpcode() != Instruction::Shl)
+            CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
+
+          return getSignExtendExpr(CompositeExpr, OuterTy);
         }
       }
       break;
     }
-    }
   }
 
   switch (U->getOpcode()) {

diff  --git a/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-loop.ll b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-loop.ll
new file mode 100644
index 00000000000000..92becb7995919f
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-loop.ll
@@ -0,0 +1,52 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
+
+ at .str = private unnamed_addr constant [3 x i8] c"%x\00", align 1
+
+define dso_local i32 @test_loop(ptr nocapture noundef readonly %x) {
+; CHECK-LABEL: 'test_loop'
+; CHECK-NEXT:  Classifying expressions for: @test_loop
+; CHECK-NEXT:    %i.03 = phi i64 [ 1, %entry ], [ %inc, %for.body ]
+; CHECK-NEXT:    --> {1,+,1}<nuw><nsw><%for.body> U: [1,10) S: [1,10) Exits: 9 LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %conv = shl nuw nsw i64 %i.03, 32
+; CHECK-NEXT:    --> {4294967296,+,4294967296}<nuw><nsw><%for.body> U: [4294967296,38654705665) S: [4294967296,38654705665) Exits: 38654705664 LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %sext = add nsw i64 %conv, -4294967296
+; CHECK-NEXT:    --> {0,+,4294967296}<nuw><nsw><%for.body> U: [0,34359738369) S: [0,34359738369) Exits: 34359738368 LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %idxprom = ashr exact i64 %sext, 32
+; CHECK-NEXT:    --> {0,+,1}<nuw><nsw><%for.body> U: [0,9) S: [0,9) Exits: 8 LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %arrayidx = getelementptr inbounds i32, ptr %x, i64 %idxprom
+; CHECK-NEXT:    --> {%x,+,4}<nuw><%for.body> U: full-set S: full-set Exits: (32 + %x) LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:    %0 = load i32, ptr %arrayidx, align 4
+; CHECK-NEXT:    --> %0 U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Variant }
+; CHECK-NEXT:    %call = tail call i32 (ptr, ...) @printf(ptr noundef nonnull dereferenceable(1) @.str, i32 noundef %0)
+; CHECK-NEXT:    --> %call U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Variant }
+; CHECK-NEXT:    %inc = add nuw nsw i64 %i.03, 1
+; CHECK-NEXT:    --> {2,+,1}<nuw><nsw><%for.body> U: [2,11) S: [2,11) Exits: 10 LoopDispositions: { %for.body: Computable }
+; CHECK-NEXT:  Determining loop execution counts for: @test_loop
+; CHECK-NEXT:  Loop %for.body: backedge-taken count is 8
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is 8
+; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is 8
+; CHECK-NEXT:  Loop %for.body: Predicated backedge-taken count is 8
+; CHECK-NEXT:   Predicates:
+; CHECK:       Loop %for.body: Trip multiple is 9
+;
+entry:
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body
+  ret i32 0
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.03 = phi i64 [ 1, %entry ], [ %inc, %for.body ]
+  %conv = shl nuw nsw i64 %i.03, 32
+  %sext = add nsw i64 %conv, -4294967296
+  %idxprom = ashr exact i64 %sext, 32
+  %arrayidx = getelementptr inbounds i32, ptr %x, i64 %idxprom
+  %0 = load i32, ptr %arrayidx, align 4
+  %call = tail call i32 (ptr, ...) @printf(ptr noundef nonnull dereferenceable(1) @.str, i32 noundef %0)
+  %inc = add nuw nsw i64 %i.03, 1
+  %exitcond.not = icmp eq i64 %inc, 10
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+declare noundef i32 @printf(ptr nocapture noundef readonly, ...)

diff  --git a/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-unequal.ll b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-unequal.ll
new file mode 100644
index 00000000000000..17aa084b242442
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg-unequal.ll
@@ -0,0 +1,53 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
+
+define i64 @test00(i64 %a) {
+; CHECK-LABEL: 'test00'
+; CHECK-NEXT:  Classifying expressions for: @test00
+; CHECK-NEXT:    %add = shl i64 %a, 10
+; CHECK-NEXT:    --> (1024 * %a) U: [0,-1023) S: [-9223372036854775808,9223372036854774785)
+; CHECK-NEXT:    %shl = add i64 %add, 256
+; CHECK-NEXT:    --> (256 + (1024 * %a))<nuw><nsw> U: [256,-767) S: [-9223372036854775552,9223372036854775041)
+; CHECK-NEXT:    %ashr = ashr exact i64 %shl, 8
+; CHECK-NEXT:    --> (1 + (sext i56 (4 * (trunc i64 %a to i56)) to i64))<nuw><nsw> U: [1,-2) S: [-36028797018963967,36028797018963966)
+; CHECK-NEXT:  Determining loop execution counts for: @test00
+;
+  %add = shl i64 %a, 10
+  %shl = add i64 %add, 256
+  %ashr = ashr exact i64 %shl, 8
+  ret i64 %ashr
+}
+
+define i64 @test01(i64 %a) {
+; CHECK-LABEL: 'test01'
+; CHECK-NEXT:  Classifying expressions for: @test01
+; CHECK-NEXT:    %add = shl i64 %a, 6
+; CHECK-NEXT:    --> (64 * %a) U: [0,-63) S: [-9223372036854775808,9223372036854775745)
+; CHECK-NEXT:    %shl = add i64 %add, 256
+; CHECK-NEXT:    --> (256 + (64 * %a)) U: [0,-63) S: [-9223372036854775808,9223372036854775745)
+; CHECK-NEXT:    %ashr = ashr exact i64 %shl, 8
+; CHECK-NEXT:    --> %ashr U: [-36028797018963968,36028797018963968) S: [-36028797018963968,36028797018963968)
+; CHECK-NEXT:  Determining loop execution counts for: @test01
+;
+  %add = shl i64 %a, 6
+  %shl = add i64 %add, 256
+  %ashr = ashr exact i64 %shl, 8
+  ret i64 %ashr
+}
+
+define i64 @test02(i64 %a) {
+; CHECK-LABEL: 'test02'
+; CHECK-NEXT:  Classifying expressions for: @test02
+; CHECK-NEXT:    %add = shl i64 %a, 12
+; CHECK-NEXT:    --> (4096 * %a) U: [0,-4095) S: [-9223372036854775808,9223372036854771713)
+; CHECK-NEXT:    %shl = add i64 %add, 4096
+; CHECK-NEXT:    --> (4096 + (4096 * %a)) U: [0,-4095) S: [-9223372036854775808,9223372036854771713)
+; CHECK-NEXT:    %ashr = ashr exact i64 %shl, 8
+; CHECK-NEXT:    --> (sext i56 (16 + (16 * (trunc i64 %a to i56))) to i64) U: [0,-15) S: [-36028797018963968,36028797018963953)
+; CHECK-NEXT:  Determining loop execution counts for: @test02
+;
+  %add = shl i64 %a, 12
+  %shl = add i64 %add, 4096
+  %ashr = ashr exact i64 %shl, 8
+  ret i64 %ashr
+}

diff  --git a/llvm/test/Analysis/ScalarEvolution/sext-add-inreg.ll b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg.ll
new file mode 100644
index 00000000000000..8aac64ea4f6dc4
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/sext-add-inreg.ll
@@ -0,0 +1,19 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
+
+define i64 @test(i64 %a) {
+; CHECK-LABEL: 'test'
+; CHECK-NEXT:  Classifying expressions for: @test
+; CHECK-NEXT:    %add = shl i64 %a, 8
+; CHECK-NEXT:    --> (256 * %a) U: [0,-255) S: [-9223372036854775808,9223372036854775553)
+; CHECK-NEXT:    %shl = add i64 %add, 256
+; CHECK-NEXT:    --> (256 + (256 * %a)) U: [0,-255) S: [-9223372036854775808,9223372036854775553)
+; CHECK-NEXT:    %ashr = ashr exact i64 %shl, 8
+; CHECK-NEXT:    --> (sext i56 (1 + (trunc i64 %a to i56)) to i64) U: [-36028797018963968,36028797018963968) S: [-36028797018963968,36028797018963968)
+; CHECK-NEXT:  Determining loop execution counts for: @test
+;
+  %add = shl i64 %a, 8
+  %shl = add i64 %add, 256
+  %ashr = ashr exact i64 %shl, 8
+  ret i64 %ashr
+}

diff  --git a/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll b/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll
index 8b271cee7cbd73..8cf4f8e9c129f4 100644
--- a/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll
@@ -4,7 +4,6 @@
 ; see pr42770
 ; REQUIRES: asserts
 ; RUN: opt < %s -loop-reduce -S | FileCheck %s
-
 target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128-ni:1"
 
 define void @foo() {
@@ -12,23 +11,23 @@ define void @foo() {
 ; CHECK-NEXT:  bb:
 ; CHECK-NEXT:    br label [[BB4:%.*]]
 ; CHECK:       bb1:
-; CHECK-NEXT:    [[T3:%.*]] = ashr i64 [[LSR_IV_NEXT:%.*]], 32
+; CHECK-NEXT:    [[T:%.*]] = shl i64 [[T14:%.*]], 32
+; CHECK-NEXT:    [[T2:%.*]] = add i64 [[T]], 1
+; CHECK-NEXT:    [[T3:%.*]] = ashr i64 [[T2]], 32
 ; CHECK-NEXT:    ret void
 ; CHECK:       bb4:
-; CHECK-NEXT:    [[LSR_IV1:%.*]] = phi i16 [ [[LSR_IV_NEXT2:%.*]], [[BB13:%.*]] ], [ 6, [[BB:%.*]] ]
-; CHECK-NEXT:    [[LSR_IV:%.*]] = phi i64 [ [[LSR_IV_NEXT]], [[BB13]] ], [ 8589934593, [[BB]] ]
-; CHECK-NEXT:    [[T5:%.*]] = phi i64 [ 2, [[BB]] ], [ [[T14:%.*]], [[BB13]] ]
+; CHECK-NEXT:    [[LSR_IV:%.*]] = phi i16 [ [[LSR_IV_NEXT:%.*]], [[BB13:%.*]] ], [ 6, [[BB:%.*]] ]
+; CHECK-NEXT:    [[T5:%.*]] = phi i64 [ 2, [[BB]] ], [ [[T14]], [[BB13]] ]
 ; CHECK-NEXT:    [[T6:%.*]] = add i64 [[T5]], 4
 ; CHECK-NEXT:    [[T7:%.*]] = trunc i64 [[T6]] to i16
 ; CHECK-NEXT:    [[T8:%.*]] = urem i16 [[T7]], 3
 ; CHECK-NEXT:    [[T9:%.*]] = mul i16 [[T8]], 2
-; CHECK-NEXT:    [[LSR_IV_NEXT]] = add nuw nsw i64 [[LSR_IV]], 25769803776
-; CHECK-NEXT:    [[LSR_IV_NEXT2]] = add nuw nsw i16 [[LSR_IV1]], 6
+; CHECK-NEXT:    [[LSR_IV_NEXT]] = add nuw nsw i16 [[LSR_IV]], 6
 ; CHECK-NEXT:    [[T14]] = add nuw nsw i64 [[T5]], 6
 ; CHECK-NEXT:    [[T10:%.*]] = icmp eq i16 [[T9]], 1
 ; CHECK-NEXT:    br i1 [[T10]], label [[BB11:%.*]], label [[BB13]]
 ; CHECK:       bb11:
-; CHECK-NEXT:    [[T12:%.*]] = udiv i16 1, [[LSR_IV1]]
+; CHECK-NEXT:    [[T12:%.*]] = udiv i16 1, [[LSR_IV]]
 ; CHECK-NEXT:    unreachable
 ; CHECK:       bb13:
 ; CHECK-NEXT:    br i1 true, label [[BB1:%.*]], label [[BB4]]


        


More information about the llvm-commits mailing list