[llvm] [VectorCombine] Enable transform 'scalarizeLoadExtract' for non constant indexes (PR #65445)

Ben Shi via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 25 02:47:10 PDT 2023


https://github.com/benshi001 updated https://github.com/llvm/llvm-project/pull/65445

>From 31a1b0440d20fce15db264547f6a17bda5be50eb Mon Sep 17 00:00:00 2001
From: Ben Shi <2283975856 at qq.com>
Date: Thu, 7 Sep 2023 12:32:02 +0800
Subject: [PATCH] [VectorCombine] Enable transform 'scalarizeLoadExtract' for
 non constant indexes

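Enable the transform when a non-constant index is guaranteed to be in
range via a UREM/AND. If such an index may still be poison, the value
feeding the UREM/AND is frozen before the scalar GEP and load are
emitted.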
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 21 +++--
 .../load-extractelement-scalarization.ll      | 82 ++++++++++++++-----
 2 files changed, 74 insertions(+), 29 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 658cbc59fe6f328..37b59f8270469e8 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
@@ -1253,6 +1254,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
 
   Instruction *LastCheckedInst = LI;
   unsigned NumInstChecked = 0;
+  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
   // Check if all users of the load are extracts with no memory modifications
   // between the load and the extract. Compute the cost of both the original
   // code and the scalarized version.
@@ -1261,9 +1263,6 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     if (!UI || UI->getParent() != LI->getParent())
       return false;
 
-    if (!isGuaranteedNotToBePoison(UI->getOperand(1), &AC, LI, &DT))
-      return false;
-
     // Check if any instruction between the load and the extract may modify
     // memory.
     if (LastCheckedInst->comesBefore(UI)) {
@@ -1279,10 +1278,11 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     }
 
     auto ScalarIdx = canScalarizeAccess(VecTy, UI->getOperand(1), &I, AC, DT);
-    if (!ScalarIdx.isSafe()) {
-      // TODO: Freeze index if it is safe to do so.
-      ScalarIdx.discard();
+    if (ScalarIdx.isUnsafe())
       return false;
+    if (ScalarIdx.isSafeWithFreeze()) {
+      NeedFreeze.try_emplace(UI, std::move(ScalarIdx));
+      ScalarIdx.discard();
     }
 
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
@@ -1302,9 +1302,14 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   // Replace extracts with narrow scalar loads.
   for (User *U : LI->users()) {
     auto *EI = cast<ExtractElementInst>(U);
-    Builder.SetInsertPoint(EI);
-
     Value *Idx = EI->getOperand(1);
+
+    // Insert 'freeze' for poison indexes.
+    auto It = NeedFreeze.find(EI);
+    if (It != NeedFreeze.end())
+      It->second.freeze(Builder, *cast<Instruction>(Idx));
+
+    Builder.SetInsertPoint(EI);
     Value *GEP =
         Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
     auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
index c7e5979aa9e7bd9..f11a136a9bc59aa 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-extractelement-scalarization.ll
@@ -309,9 +309,10 @@ declare void @llvm.assume(i1)
 define i32 @load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_and(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 3
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -324,9 +325,10 @@ entry:
 define i32 @vscale_load_extract_idx_var_i64_known_valid_by_and(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_and(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX:%.*]], 3
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 3
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -384,9 +386,10 @@ entry:
 define i32 @load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT:    [[LV:%.*]] = load <4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX_FROZEN]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -399,9 +402,10 @@ entry:
 define i32 @vscale_load_extract_idx_var_i64_known_valid_by_urem(ptr %x, i64 %idx) {
 ; CHECK-LABEL: @vscale_load_extract_idx_var_i64_known_valid_by_urem(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX:%.*]], 4
-; CHECK-NEXT:    [[LV:%.*]] = load <vscale x 4 x i32>, ptr [[X:%.*]], align 16
-; CHECK-NEXT:    [[R:%.*]] = extractelement <vscale x 4 x i32> [[LV]], i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i64 [[IDX_FROZEN]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds <vscale x 4 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; CHECK-NEXT:    [[R:%.*]] = load i32, ptr [[TMP0]], align 4
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
 entry:
@@ -789,11 +793,14 @@ define i32 @load_multiple_extracts_with_variable_indices_large_vector_only_first
 
 define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and(ptr %x, i64 %idx.0, i64 %idx.1) {
 ; CHECK-LABEL: @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and(
-; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0:%.*]], 15
-; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1:%.*]], 15
-; CHECK-NEXT:    [[LV:%.*]] = load <16 x i32>, ptr [[X:%.*]], align 64
-; CHECK-NEXT:    [[E_0:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_0_CLAMPED]]
-; CHECK-NEXT:    [[E_1:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[IDX_0_FROZEN:%.*]] = freeze i64 [[IDX_0:%.*]]
+; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0_FROZEN]], 15
+; CHECK-NEXT:    [[IDX_1_FROZEN:%.*]] = freeze i64 [[IDX_1:%.*]]
+; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1_FROZEN]], 15
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_0_CLAMPED]]
+; CHECK-NEXT:    [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X]], i32 0, i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[E_1:%.*]] = load i32, ptr [[TMP2]], align 4
 ; CHECK-NEXT:    [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
@@ -807,13 +814,46 @@ define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_
   ret i32 %res
 }
 
+define i32 @load_multiple_extracts_with_unique_variable_indices_large_vector_valid_by_and(ptr %x, ptr %y, i64 %idx) {
+; LIMIT-DEFAULT-LABEL: @load_multiple_extracts_with_unique_variable_indices_large_vector_valid_by_and(
+; LIMIT-DEFAULT-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; LIMIT-DEFAULT-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 15
+; LIMIT-DEFAULT-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; LIMIT-DEFAULT-NEXT:    [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; LIMIT-DEFAULT-NEXT:    [[TMP2:%.*]] = getelementptr inbounds <16 x i32>, ptr [[Y:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; LIMIT-DEFAULT-NEXT:    [[E_1:%.*]] = load i32, ptr [[TMP2]], align 4
+; LIMIT-DEFAULT-NEXT:    [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
+; LIMIT-DEFAULT-NEXT:    ret i32 [[RES]]
+;
+; LIMIT2-LABEL: @load_multiple_extracts_with_unique_variable_indices_large_vector_valid_by_and(
+; LIMIT2-NEXT:    [[IDX_FROZEN:%.*]] = freeze i64 [[IDX:%.*]]
+; LIMIT2-NEXT:    [[IDX_CLAMPED:%.*]] = and i64 [[IDX_FROZEN]], 15
+; LIMIT2-NEXT:    [[LY:%.*]] = load <16 x i32>, ptr [[Y:%.*]], align 64
+; LIMIT2-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_CLAMPED]]
+; LIMIT2-NEXT:    [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; LIMIT2-NEXT:    [[E_1:%.*]] = extractelement <16 x i32> [[LY]], i64 [[IDX_CLAMPED]]
+; LIMIT2-NEXT:    [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
+; LIMIT2-NEXT:    ret i32 [[RES]]
+;
+  %idx.clamped = and i64 %idx, 15
+
+  %lx = load <16 x i32>, ptr %x
+  %ly = load <16 x i32>, ptr %y
+  %e.0 = extractelement <16 x i32> %lx, i64 %idx.clamped
+  %e.1 = extractelement <16 x i32> %ly, i64 %idx.clamped
+  %res = add i32 %e.0, %e.1
+  ret i32 %res
+}
+
 define i32 @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and_some_noundef(ptr %x, i64 %idx.0, i64 noundef %idx.1) {
 ; CHECK-LABEL: @load_multiple_extracts_with_variable_indices_large_vector_all_valid_by_and_some_noundef(
-; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0:%.*]], 15
+; CHECK-NEXT:    [[IDX_0_FROZEN:%.*]] = freeze i64 [[IDX_0:%.*]]
+; CHECK-NEXT:    [[IDX_0_CLAMPED:%.*]] = and i64 [[IDX_0_FROZEN]], 15
 ; CHECK-NEXT:    [[IDX_1_CLAMPED:%.*]] = and i64 [[IDX_1:%.*]], 15
-; CHECK-NEXT:    [[LV:%.*]] = load <16 x i32>, ptr [[X:%.*]], align 64
-; CHECK-NEXT:    [[E_0:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_0_CLAMPED]]
-; CHECK-NEXT:    [[E_1:%.*]] = extractelement <16 x i32> [[LV]], i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X:%.*]], i32 0, i64 [[IDX_0_CLAMPED]]
+; CHECK-NEXT:    [[E_0:%.*]] = load i32, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds <16 x i32>, ptr [[X]], i32 0, i64 [[IDX_1_CLAMPED]]
+; CHECK-NEXT:    [[E_1:%.*]] = load i32, ptr [[TMP2]], align 4
 ; CHECK-NEXT:    [[RES:%.*]] = add i32 [[E_0]], [[E_1]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;



More information about the llvm-commits mailing list