[llvm] 575e2af - [VectorCombine] Use constant range info for index scalarization legality.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue May 25 05:59:04 PDT 2021


Author: Florian Hahn
Date: 2021-05-25T13:58:42+01:00
New Revision: 575e2aff5574550d10278d9a41fca2926a5b8409

URL: https://github.com/llvm/llvm-project/commit/575e2aff5574550d10278d9a41fca2926a5b8409
DIFF: https://github.com/llvm/llvm-project/commit/575e2aff5574550d10278d9a41fca2926a5b8409.diff

LOG: [VectorCombine] Use constant range info for index scalarization legality.

We can only scalarize memory accesses if we know the index is valid.

This patch adjusts canScalarizeAccess to fall back to
computeConstantRange to check if the index is known to be valid.

Reviewed By: nlopes

Differential Revision: https://reviews.llvm.org/D102476

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/VectorCombine.cpp
    llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll
    llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
    llvm/test/Transforms/VectorCombine/load-insert-store.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 39130325df2e..50916d0fb3e2 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -14,6 +14,7 @@
 
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/Loads.h"
@@ -60,8 +61,8 @@ namespace {
 class VectorCombine {
 public:
   VectorCombine(Function &F, const TargetTransformInfo &TTI,
-                const DominatorTree &DT, AAResults &AA)
-      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA) {}
+                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC)
+      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC) {}
 
   bool run();
 
@@ -71,6 +72,7 @@ class VectorCombine {
   const TargetTransformInfo &TTI;
   const DominatorTree &DT;
   AAResults &AA;
+  AssumptionCache &AC;
 
   bool vectorizeLoadInsert(Instruction &I);
   ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
@@ -774,8 +776,16 @@ static bool isMemModifiedBetween(BasicBlock::iterator Begin,
 
 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
 /// Idx. \p Idx must access a valid vector element.
-static bool canScalarizeAccess(FixedVectorType *VecTy, ConstantInt *Idx) {
-  return Idx->getValue().ult(VecTy->getNumElements());
+static bool canScalarizeAccess(FixedVectorType *VecTy, Value *Idx,
+                               Instruction *CtxI, AssumptionCache &AC) {
+  if (auto *C = dyn_cast<ConstantInt>(Idx))
+    return C->getValue().ult(VecTy->getNumElements());
+
+  APInt Zero(Idx->getType()->getScalarSizeInBits(), 0);
+  APInt MaxElts(Idx->getType()->getScalarSizeInBits(), VecTy->getNumElements());
+  ConstantRange ValidIndices(Zero, MaxElts);
+  ConstantRange IdxRange = computeConstantRange(Idx, true, &AC, CtxI, 0);
+  return ValidIndices.contains(IdxRange);
 }
 
 // Combine patterns like:
@@ -796,10 +806,10 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
   // TargetTransformInfo.
   Instruction *Source;
   Value *NewElement;
-  ConstantInt *Idx;
+  Value *Idx;
   if (!match(SI->getValueOperand(),
              m_InsertElt(m_Instruction(Source), m_Value(NewElement),
-                         m_ConstantInt(Idx))))
+                         m_Value(Idx))))
     return false;
 
   if (auto *Load = dyn_cast<LoadInst>(Source)) {
@@ -810,7 +820,7 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
     // modified between, vector type matches store size, and index is inbounds.
     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
         !DL.typeSizeEqualsStoreSize(Load->getType()) ||
-        !canScalarizeAccess(VecTy, Idx) ||
+        !canScalarizeAccess(VecTy, Idx, Load, AC) ||
         SrcAddr != SI->getPointerOperand()->stripPointerCasts() ||
         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                              MemoryLocation::get(SI), AA))
@@ -835,8 +845,8 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
 /// Try to scalarize vector loads feeding extractelement instructions.
 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   Value *Ptr;
-  ConstantInt *Idx;
-  if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_ConstantInt(Idx))))
+  Value *Idx;
+  if (!match(&I, m_ExtractElt(m_Load(m_Value(Ptr)), m_Value(Idx))))
     return false;
 
   auto *LI = cast<LoadInst>(I.getOperand(0));
@@ -848,7 +858,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   if (!FixedVT)
     return false;
 
-  if (!canScalarizeAccess(FixedVT, Idx))
+  if (!canScalarizeAccess(FixedVT, Idx, &I, AC))
     return false;
 
   InstructionCost OriginalCost = TTI.getMemoryOpCost(
@@ -971,6 +981,7 @@ class VectorCombineLegacyPass : public FunctionPass {
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addRequired<TargetTransformInfoWrapperPass>();
     AU.addRequired<AAResultsWrapperPass>();
@@ -985,10 +996,11 @@ class VectorCombineLegacyPass : public FunctionPass {
   bool runOnFunction(Function &F) override {
     if (skipFunction(F))
       return false;
+    auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
-    VectorCombine Combiner(F, TTI, DT, AA);
+    VectorCombine Combiner(F, TTI, DT, AA, AC);
     return Combiner.run();
   }
 };
@@ -998,6 +1010,7 @@ char VectorCombineLegacyPass::ID = 0;
 INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                       "Optimize scalar/vector ops", false,
                       false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                     "Optimize scalar/vector ops", false, false)
@@ -1007,10 +1020,11 @@ Pass *llvm::createVectorCombinePass() {
 
 PreservedAnalyses VectorCombinePass::run(Function &F,
                                          FunctionAnalysisManager &FAM) {
+  auto &AC = FAM.getResult<AssumptionAnalysis>(F);
   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
   AAResults &AA = FAM.getResult<AAManager>(F);
-  VectorCombine Combiner(F, TTI, DT, AA);
+  VectorCombine Combiner(F, TTI, DT, AA, AC);
   if (!Combiner.run())
     return PreservedAnalyses::all();
   PreservedAnalyses PA;

diff  --git a/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll
index 31d8a1680887..1089f54a0526 100644
--- a/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll
+++ b/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll
@@ -13,8 +13,8 @@ define void @matrix_extract_insert_scalar(i32 %i, i32 %k, i32 %j, [225 x double]
 ; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i64 [[TMP1]], 225
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP2]])
 ; CHECK-NEXT:    [[TMP3:%.*]] = bitcast [225 x double]* [[A:%.*]] to <225 x double>*
-; CHECK-NEXT:    [[TMP4:%.*]] = load <225 x double>, <225 x double>* [[TMP3]], align 8
-; CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <225 x double> [[TMP4]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds <225 x double>, <225 x double>* [[TMP3]], i64 0, i64 [[TMP1]]
+; CHECK-NEXT:    [[MATRIXEXT:%.*]] = load double, double* [[TMP4]], align 8
 ; CHECK-NEXT:    [[CONV2:%.*]] = zext i32 [[I:%.*]] to i64
 ; CHECK-NEXT:    [[TMP5:%.*]] = add nuw nsw i64 [[TMP0]], [[CONV2]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = icmp ult i64 [[TMP5]], 225
@@ -25,8 +25,8 @@ define void @matrix_extract_insert_scalar(i32 %i, i32 %k, i32 %j, [225 x double]
 ; CHECK-NEXT:    [[MUL:%.*]] = fmul double [[MATRIXEXT]], [[MATRIXEXT4]]
 ; CHECK-NEXT:    [[MATRIXEXT7:%.*]] = extractelement <225 x double> [[TMP8]], i64 [[TMP1]]
 ; CHECK-NEXT:    [[SUB:%.*]] = fsub double [[MATRIXEXT7]], [[MUL]]
-; CHECK-NEXT:    [[MATINS:%.*]] = insertelement <225 x double> [[TMP8]], double [[SUB]], i64 [[TMP1]]
-; CHECK-NEXT:    store <225 x double> [[MATINS]], <225 x double>* [[TMP7]], align 8
+; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds <225 x double>, <225 x double>* [[TMP7]], i64 0, i64 [[TMP1]]
+; CHECK-NEXT:    store double [[SUB]], double* [[TMP9]], align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -112,8 +112,8 @@ define void @matrix_extract_insert_loop(i32 %i, [225 x double]* nonnull align 8
 ; CHECK-NEXT:    [[TMP6:%.*]] = add nuw nsw i64 [[TMP2]], [[CONV_US]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp ult i64 [[TMP6]], 225
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[TMP7]])
-; CHECK-NEXT:    [[TMP8:%.*]] = load <225 x double>, <225 x double>* [[TMP0]], align 8
-; CHECK-NEXT:    [[MATRIXEXT_US:%.*]] = extractelement <225 x double> [[TMP8]], i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds <225 x double>, <225 x double>* [[TMP0]], i64 0, i64 [[TMP6]]
+; CHECK-NEXT:    [[MATRIXEXT_US:%.*]] = load double, double* [[TMP8]], align 8
 ; CHECK-NEXT:    [[MATRIXEXT8_US:%.*]] = extractelement <225 x double> [[TMP5]], i64 [[TMP3]]
 ; CHECK-NEXT:    [[MUL_US:%.*]] = fmul double [[MATRIXEXT_US]], [[MATRIXEXT8_US]]
 ; CHECK-NEXT:    [[MATRIXEXT11_US:%.*]] = extractelement <225 x double> [[TMP5]], i64 [[TMP6]]

diff  --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index 7179909701ee..4ffbae6265b1 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt -vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck --check-prefixes=CHECK,LIMIT-DEFAULT %s
+; RUN: opt -vector-combine -enable-new-pm=false -mtriple=arm64-apple-darwinos -S %s | FileCheck --check-prefixes=CHECK,LIMIT-DEFAULT %s
 ; RUN: opt -vector-combine -mtriple=arm64-apple-darwinos -vector-combine-max-scan-instrs=2 -S %s | FileCheck --check-prefixes=CHECK,LIMIT2 %s
 
 define i32 @load_extract_idx_0(<4 x i32>* %x) {
@@ -90,9 +91,9 @@ define i32 @load_extract_idx_var_i64_known_valid_by_assume(<4 x i32>* %x, i64 %i
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[IDX:%.*]], 4
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
 ; CHECK-NEXT:    call void @maythrow()
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i64 [[IDX]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -146,8 +147,8 @@ define i32 @load_extract_idx_var_i64_known_valid_by_and(<4 x i32>* %x, i64 %idx)
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_and(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -176,8 +177,8 @@ define i32 @load_extract_idx_var_i64_known_valid_by_urem(<4 x i32>* %x, i64 %idx
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, <4 x i32>* [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, <4 x i32>* [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, i32* [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:

diff  --git a/llvm/test/Transforms/VectorCombine/load-insert-store.ll b/llvm/test/Transforms/VectorCombine/load-insert-store.ll
index 611d66978019..1b43834e4805 100644
--- a/llvm/test/Transforms/VectorCombine/load-insert-store.ll
+++ b/llvm/test/Transforms/VectorCombine/load-insert-store.ll
@@ -130,9 +130,8 @@ define void @insert_store_nonconst_index_known_valid_by_assume(<16 x i8>* %q, i8
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IDX:%.*]], 4
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
-; CHECK-NEXT:    [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX]]
-; CHECK-NEXT:    store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX]]
+; CHECK-NEXT:    store i8 [[S:%.*]], i8* [[TMP0]], align 1
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -191,10 +190,9 @@ declare void @llvm.assume(i1)
 define void @insert_store_nonconst_index_known_valid_by_and(<16 x i8>* %q, i8 zeroext %s, i32 %idx) {
 ; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_and(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
 ; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i32 [[IDX:%.*]], 7
-; CHECK-NEXT:    [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]]
-; CHECK-NEXT:    store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]]
+; CHECK-NEXT:    store i8 [[S:%.*]], i8* [[TMP0]], align 1
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -225,10 +223,9 @@ entry:
 define void @insert_store_nonconst_index_known_valid_by_urem(<16 x i8>* %q, i8 zeroext %s, i32 %idx) {
 ; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
 ; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i32 [[IDX:%.*]], 16
-; CHECK-NEXT:    [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]]
-; CHECK-NEXT:    store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]]
+; CHECK-NEXT:    store i8 [[S:%.*]], i8* [[TMP0]], align 1
 ; CHECK-NEXT:    ret void
 ;
 entry:


        


More information about the llvm-commits mailing list