[llvm] 7b8fd8f - [LLVM][SCEV] Look through common vscale multiplicand when simplifying compares. (#141798)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 19 04:57:17 PDT 2025


Author: Paul Walker
Date: 2025-09-19T12:57:13+01:00
New Revision: 7b8fd8f31bc6d65a59b6e09ebbeb77fdfb95360f

URL: https://github.com/llvm/llvm-project/commit/7b8fd8f31bc6d65a59b6e09ebbeb77fdfb95360f
DIFF: https://github.com/llvm/llvm-project/commit/7b8fd8f31bc6d65a59b6e09ebbeb77fdfb95360f.diff

LOG: [LLVM][SCEV] Look through common vscale multiplicand when simplifying compares. (#141798)

My use case is simplifying the control flow generated by LoopVectorize
when vectorising loops whose trip count is a function of the runtime
vector length. This can be problematic because:

* CSE is a pre-LoopVectorize transform, so it's common for an IR
function to include several calls to llvm.vscale(). (NOTE: Code
generation will typically remove the duplicates.)
* Pre-LoopVectorize InstCombine runs will rewrite some multiplies as
shifts. This leads to a mismatch between the VL-based maths of the
scalar loop and that created for the vector loop, which prevents some
obvious simplifications.

SCEV does not suffer these issues because it effectively performs CSE
during construction and represents shifts as multiplies.
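
For illustration, here is a minimal standalone sketch (not part of this
patch; the IR, value names and expected output are assumptions) showing
SCEV's view of a shift-based and a multiply-based form of the same
vscale-based trip count:

// Sketch: build SCEVs for "vscale << 2" and "vscale * 4" written against
// two separate llvm.vscale calls, then print them. Both are expected to
// print as (4 * vscale), possibly with <nuw>/<nsw> annotations, because
// SCEV uniques the vscale expression and models the shift as a multiply.
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ValueSymbolTable.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TargetParser/Triple.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M = parseAssemblyString(
      "define void @f() {\n"
      "entry:\n"
      "  %vs0 = call i64 @llvm.vscale.i64()\n"
      "  %scalar.tc = shl nuw nsw i64 %vs0, 2\n" // scalar-loop style maths
      "  %vs1 = call i64 @llvm.vscale.i64()\n"   // duplicate vscale call
      "  %vector.tc = mul nuw i64 %vs1, 4\n"     // vector-loop style maths
      "  ret void\n"
      "}\n"
      "declare i64 @llvm.vscale.i64()\n",
      Err, Ctx);
  if (!M) {
    Err.print("sketch", errs());
    return 1;
  }

  Function &F = *M->getFunction("f");
  TargetLibraryInfoImpl TLII(Triple(M->getTargetTriple()));
  TargetLibraryInfo TLI(TLII);
  AssumptionCache AC(F);
  DominatorTree DT(F);
  LoopInfo LI(DT);
  ScalarEvolution SE(F, TLI, AC, DT, LI);

  const SCEV *ScalarTC =
      SE.getSCEV(F.getValueSymbolTable()->lookup("scalar.tc"));
  const SCEV *VectorTC =
      SE.getSCEV(F.getValueSymbolTable()->lookup("vector.tc"));
  ScalarTC->print(errs());
  errs() << "\n";
  VectorTC->print(errs());
  errs() << "\n";
  return 0;
}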

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Transforms/LoopVectorize/AArch64/sve-vscale-based-trip-counts.ll
    llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 75b2bc1c067ec..7a45ae93b185b 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -183,7 +183,8 @@ m_scev_PtrToInt(const Op0_t &Op0) {
 }
 
 /// Match a binary SCEV.
-template <typename SCEVTy, typename Op0_t, typename Op1_t>
+template <typename SCEVTy, typename Op0_t, typename Op1_t,
+          bool Commutable = false>
 struct SCEVBinaryExpr_match {
   Op0_t Op0;
   Op1_t Op1;
@@ -192,15 +193,18 @@ struct SCEVBinaryExpr_match {
 
   bool match(const SCEV *S) const {
     auto *E = dyn_cast<SCEVTy>(S);
-    return E && E->getNumOperands() == 2 && Op0.match(E->getOperand(0)) &&
-           Op1.match(E->getOperand(1));
+    return E && E->getNumOperands() == 2 &&
+           ((Op0.match(E->getOperand(0)) && Op1.match(E->getOperand(1))) ||
+            (Commutable && Op0.match(E->getOperand(1)) &&
+             Op1.match(E->getOperand(0))));
   }
 };
 
-template <typename SCEVTy, typename Op0_t, typename Op1_t>
-inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>
+template <typename SCEVTy, typename Op0_t, typename Op1_t,
+          bool Commutable = false>
+inline SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, Commutable>
 m_scev_Binary(const Op0_t &Op0, const Op1_t &Op1) {
-  return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t>(Op0, Op1);
+  return SCEVBinaryExpr_match<SCEVTy, Op0_t, Op1_t, Commutable>(Op0, Op1);
 }
 
 template <typename Op0_t, typename Op1_t>
@@ -215,6 +219,12 @@ m_scev_Mul(const Op0_t &Op0, const Op1_t &Op1) {
   return m_scev_Binary<SCEVMulExpr>(Op0, Op1);
 }
 
+template <typename Op0_t, typename Op1_t>
+inline SCEVBinaryExpr_match<SCEVMulExpr, Op0_t, Op1_t, true>
+m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_scev_Binary<SCEVMulExpr, Op0_t, Op1_t, true>(Op0, Op1);
+}
+
 template <typename Op0_t, typename Op1_t>
 inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
 m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
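
As an aside, a minimal usage sketch of the new commutative matcher (not
part of this patch; stripVScaleFactor is a hypothetical helper name):

#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"

using namespace llvm;
using namespace llvm::SCEVPatternMatch;

// Matches a two-operand SCEVMulExpr with a vscale operand in either
// position, binding the other operand to Rem.
static bool stripVScaleFactor(const SCEV *S, const SCEV *&Rem) {
  return match(S, m_scev_c_Mul(m_SCEV(Rem), m_SCEVVScale()));
}

With the existing m_scev_Mul the same pattern would only match when SCEV
happened to place the vscale operand second.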

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 079d7da5750b6..b08399b381f34 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10810,6 +10810,25 @@ bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
   if (Depth >= 3)
     return false;
 
+  const SCEV *NewLHS, *NewRHS;
+  if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
+      match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
+    const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
+    const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
+
+    // (X * vscale) pred (Y * vscale) ==> X pred Y
+    //     when both multiples are NSW.
+    // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
+    //     when both multiples are NUW.
+    if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
+        (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
+         !ICmpInst::isSigned(Pred))) {
+      LHS = NewLHS;
+      RHS = NewRHS;
+      Changed = true;
+    }
+  }
+
   // Canonicalize a constant to the right side.
   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
     // Check for both operands constant.
@@ -10984,7 +11003,7 @@ bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
   // Recursively simplify until we either hit a recursion limit or nothing
   // changes.
   if (Changed)
-    return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
+    (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
 
   return Changed;
 }
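
For context, a hedged sketch of the client-visible effect (not part of
this patch; the helper name and the expectation are assumptions).
SimplifyICmpOperands is used to canonicalise the operands of predicate
queries such as isKnownPredicate, so stripping the shared vscale
multiplicand lets a comparison of NSW vscale multiples reduce to a
comparison of their constant factors:

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Type.h"

using namespace llvm;

// Hypothetical helper: (4 * vscale)<nsw> u< (12 * vscale)<nsw> is
// expected to simplify to 4 u< 12 and therefore be provable.
static bool fourVScaleULTTwelveVScale(ScalarEvolution &SE, Type *Ty) {
  const SCEV *VScale = SE.getVScale(Ty);
  const SCEV *VL4 =
      SE.getMulExpr(SE.getConstant(Ty, 4), VScale, SCEV::FlagNSW);
  const SCEV *VL12 =
      SE.getMulExpr(SE.getConstant(Ty, 12), VScale, SCEV::FlagNSW);
  return SE.isKnownPredicate(ICmpInst::ICMP_ULT, VL4, VL12);
}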

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sve-vscale-based-trip-counts.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sve-vscale-based-trip-counts.ll
index 643ff29136d22..df6431c1ee282 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sve-vscale-based-trip-counts.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sve-vscale-based-trip-counts.ll
@@ -9,14 +9,14 @@ define void @vscale_mul_4(ptr noalias noundef readonly captures(none) %a, ptr no
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw nsw i64 [[TMP0]], 2
-; CHECK-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP5:%.*]] = mul nuw i64 [[TMP4]], 4
-; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP1]], [[TMP5]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw i64 [[TMP10]], 4
+; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[TMP1]], [[TMP3]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[TMP1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[A]], align 4
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x float>, ptr [[B]], align 4
-; CHECK-NEXT:    [[TMP10:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
-; CHECK-NEXT:    store <vscale x 4 x float> [[TMP10]], ptr [[B]], align 4
+; CHECK-NEXT:    [[TMP4:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
+; CHECK-NEXT:    store <vscale x 4 x float> [[TMP4]], ptr [[B]], align 4
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[TMP1]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
@@ -121,36 +121,29 @@ define void @vscale_mul_12(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 12
-; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 2
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK:       [[VECTOR_PH]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 4
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x float>, ptr [[TMP7]], align 4
-; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDEX]]
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x float>, ptr [[TMP9]], align 4
-; CHECK-NEXT:    [[TMP11:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD]], [[WIDE_LOAD1]]
-; CHECK-NEXT:    store <vscale x 4 x float> [[TMP11]], ptr [[TMP9]], align 4
+; CHECK-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 4 x float>, ptr [[TMP7]], align 4
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 4 x float>, ptr [[TMP12]], align 4
+; CHECK-NEXT:    [[TMP25:%.*]] = fmul <vscale x 4 x float> [[WIDE_LOAD2]], [[WIDE_LOAD4]]
+; CHECK-NEXT:    store <vscale x 4 x float> [[TMP25]], ptr [[TMP12]], align 4
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP4]]
-; CHECK-NEXT:    [[TMP12:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[TMP12]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK-NEXT:    [[TMP22:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEXT:    br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK:       [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK:       [[SCALAR_PH]]:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[FOR_BODY]]:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    [[TMP13:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]
@@ -188,17 +181,13 @@ define void @vscale_mul_31(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 31
-; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 3
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK:       [[VECTOR_PH]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 8
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP10:%.*]] = shl nuw i64 [[TMP9]], 2
@@ -220,14 +209,11 @@ define void @vscale_mul_31(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:    br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:       [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK:       [[SCALAR_PH]]:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[FOR_BODY]]:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    [[TMP23:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]
@@ -265,17 +251,13 @@ define void @vscale_mul_64(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:  [[ENTRY:.*]]:
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[MUL1:%.*]] = mul nuw nsw i64 [[TMP0]], 64
-; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw nsw i64 [[TMP0]], 3
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[MUL1]], [[TMP2]]
-; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
-; CHECK:       [[VECTOR_PH]]:
 ; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i64 [[TMP3]], 8
 ; CHECK-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[MUL1]], [[TMP4]]
 ; CHECK-NEXT:    [[N_VEC:%.*]] = sub i64 [[MUL1]], [[N_MOD_VF]]
 ; CHECK-NEXT:    br label %[[VECTOR_BODY:.*]]
 ; CHECK:       [[VECTOR_BODY]]:
-; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP10:%.*]] = shl nuw i64 [[TMP9]], 2
@@ -297,14 +279,11 @@ define void @vscale_mul_64(ptr noalias noundef readonly captures(none) %a, ptr n
 ; CHECK-NEXT:    br i1 [[TMP22]], label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
 ; CHECK:       [[MIDDLE_BLOCK]]:
 ; CHECK-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[MUL1]], [[N_VEC]]
-; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[SCALAR_PH]]
-; CHECK:       [[SCALAR_PH]]:
-; CHECK-NEXT:    [[BC_RESUME_VAL:%.*]] = phi i64 [ [[N_VEC]], %[[MIDDLE_BLOCK]] ], [ 0, %[[ENTRY]] ]
-; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK-NEXT:    br i1 [[CMP_N]], label %[[FOR_COND_CLEANUP:.*]], label %[[FOR_BODY:.*]]
 ; CHECK:       [[FOR_COND_CLEANUP]]:
 ; CHECK-NEXT:    ret void
 ; CHECK:       [[FOR_BODY]]:
-; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], %[[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ]
+; CHECK-NEXT:    [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], %[[FOR_BODY]] ], [ [[N_VEC]], %[[MIDDLE_BLOCK]] ]
 ; CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds nuw float, ptr [[A]], i64 [[INDVARS_IV]]
 ; CHECK-NEXT:    [[TMP23:%.*]] = load float, ptr [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds nuw float, ptr [[B]], i64 [[INDVARS_IV]]

diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 678960418d7d7..1a68823b4f254 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1768,4 +1768,141 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering3) {
   SE.getSCEV(Or1);
 }
 
+TEST_F(ScalarEvolutionsTest, SimplifyICmpOperands) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M =
+      parseAssemblyString("define i32 @foo(ptr %loc, i32 %a, i32 %b) {"
+                          "entry: "
+                          "  ret i32 %a "
+                          "} ",
+                          Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  // Remove common factor when there's no signed wrapping.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNSW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNSW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_SLT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_ULT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_EQ);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    // Verify the common factor's position doesn't impede simplification.
+    {
+      const SCEV *C = SE.getConstant(A->getType(), 100);
+      const SCEV *CxVS = SE.getMulExpr(C, VS, SCEV::FlagNSW);
+
+      // Verify common factor is available at different indices.
+      ASSERT_TRUE(isa<SCEVVScale>(cast<SCEVMulExpr>(VSxA)->getOperand(0)) !=
+                  isa<SCEVVScale>(cast<SCEVMulExpr>(CxVS)->getOperand(0)));
+
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = CxVS;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_SLT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, C);
+    }
+  });
+
+  // Remove common factor when there's no unsigned wrapping.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNUW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNUW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_ULT);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_TRUE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+      EXPECT_EQ(NewPred, ICmpInst::ICMP_EQ);
+      EXPECT_EQ(NewLHS, A);
+      EXPECT_EQ(NewRHS, B);
+    }
+  });
+
+  // Do not remove common factor due to wrap flag mismatch.
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    const SCEV *A = SE.getSCEV(getArgByName(F, "a"));
+    const SCEV *B = SE.getSCEV(getArgByName(F, "b"));
+    const SCEV *VS = SE.getVScale(A->getType());
+    const SCEV *VSxA = SE.getMulExpr(VS, A, SCEV::FlagNSW);
+    const SCEV *VSxB = SE.getMulExpr(VS, B, SCEV::FlagNUW);
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_SLT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_ULT;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+
+    {
+      CmpPredicate NewPred = ICmpInst::ICMP_EQ;
+      const SCEV *NewLHS = VSxA;
+      const SCEV *NewRHS = VSxB;
+      EXPECT_FALSE(SE.SimplifyICmpOperands(NewPred, NewLHS, NewRHS));
+    }
+  });
+}
+
 }  // end namespace llvm

More information about the llvm-commits mailing list