[llvm] c24fc37 - [VectorCombine] Support AND/UREM indices that require freezing.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 13 03:24:54 PDT 2021


Author: Florian Hahn
Date: 2021-09-13T11:21:45+01:00
New Revision: c24fc37e4773c70a868d8a708b0611c0604c7ab3

URL: https://github.com/llvm/llvm-project/commit/c24fc37e4773c70a868d8a708b0611c0604c7ab3
DIFF: https://github.com/llvm/llvm-project/commit/c24fc37e4773c70a868d8a708b0611c0604c7ab3.diff

LOG: [VectorCombine] Support AND/UREM indices that require freezing.

38b098be6605 limited scalarization to indices that are known non-poison.
For certain patterns that restrict the range of an index, we can insert
a freeze of the original value, to prevent propagation of poison.

Reviewed By: lebedev.ri

Differential Revision: https://reviews.llvm.org/D107580

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/VectorCombine.cpp
    llvm/test/Transforms/VectorCombine/load-insert-store.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index e138e890acb22..2c0a80556b382 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -774,21 +774,91 @@ static bool isMemModifiedBetween(BasicBlock::iterator Begin,
   });
 }
 
+/// Helper class to indicate whether a vector index can be safely scalarized and
+/// if a freeze needs to be inserted.
+class ScalarizationResult {
+  enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
+
+  StatusTy Status;
+  Value *ToFreeze;
+
+  ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
+      : Status(Status), ToFreeze(ToFreeze) {}
+
+public:
+  ScalarizationResult(const ScalarizationResult &Other) = default;
+  ~ScalarizationResult() {
+    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
+  }
+
+  static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
+  static ScalarizationResult safe() { return {StatusTy::Safe}; }
+  static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
+    return {StatusTy::SafeWithFreeze, ToFreeze};
+  }
+
+  /// Returns true if the index can be scalarized without requiring a freeze.
+  bool isSafe() const { return Status == StatusTy::Safe; }
+  /// Returns true if the index cannot be scalarized.
+  bool isUnsafe() const { return Status == StatusTy::Unsafe; }
+  /// Returns true if the index can be scalarized, but requires inserting a
+  /// freeze.
+  bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
+
+  /// Freeze the ToFreeze and update the use in \p UserI to use it.
+  void freeze(IRBuilder<> &Builder, Instruction &UserI) {
+    assert(isSafeWithFreeze() &&
+           "should only be used when freezing is required");
+    assert(is_contained(ToFreeze->users(), &UserI) &&
+           "UserI must be a user of ToFreeze");
+    IRBuilder<>::InsertPointGuard Guard(Builder);
+    Builder.SetInsertPoint(cast<Instruction>(&UserI));
+    Value *Frozen =
+        Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
+    for (Use &U : make_early_inc_range((UserI.operands())))
+      if (U.get() == ToFreeze)
+        U.set(Frozen);
+
+    ToFreeze = nullptr;
+  }
+};
+
 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
 /// Idx. \p Idx must access a valid vector element.
-static bool canScalarizeAccess(FixedVectorType *VecTy, Value *Idx,
-                               Instruction *CtxI, AssumptionCache &AC) {
-  if (auto *C = dyn_cast<ConstantInt>(Idx))
-    return C->getValue().ult(VecTy->getNumElements());
-
-  if (!isGuaranteedNotToBePoison(Idx, &AC))
-    return false;
+static ScalarizationResult canScalarizeAccess(FixedVectorType *VecTy,
+                                              Value *Idx, Instruction *CtxI,
+                                              AssumptionCache &AC) {
+  if (auto *C = dyn_cast<ConstantInt>(Idx)) {
+    if (C->getValue().ult(VecTy->getNumElements()))
+      return ScalarizationResult::safe();
+    return ScalarizationResult::unsafe();
+  }
 
-  APInt Zero(Idx->getType()->getScalarSizeInBits(), 0);
-  APInt MaxElts(Idx->getType()->getScalarSizeInBits(), VecTy->getNumElements());
+  unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
+  APInt Zero(IntWidth, 0);
+  APInt MaxElts(IntWidth, VecTy->getNumElements());
   ConstantRange ValidIndices(Zero, MaxElts);
-  ConstantRange IdxRange = computeConstantRange(Idx, true, &AC, CtxI, 0);
-  return ValidIndices.contains(IdxRange);
+  ConstantRange IdxRange(IntWidth, true);
+
+  if (isGuaranteedNotToBePoison(Idx, &AC)) {
+    if (ValidIndices.contains(computeConstantRange(Idx, true, &AC, CtxI, 0)))
+      return ScalarizationResult::safe();
+    return ScalarizationResult::unsafe();
+  }
+
+  // If the index may be poison, check if we can insert a freeze before the
+  // range of the index is restricted.
+  Value *IdxBase;
+  ConstantInt *CI;
+  if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
+    IdxRange = IdxRange.binaryAnd(CI->getValue());
+  } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
+    IdxRange = IdxRange.urem(CI->getValue());
+  }
+
+  if (ValidIndices.contains(IdxRange))
+    return ScalarizationResult::safeWithFreeze(IdxBase);
+  return ScalarizationResult::unsafe();
 }
 
 /// The memory operation on a vector of \p ScalarType had alignment of
@@ -836,12 +906,17 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
     // modified between, vector type matches store size, and index is inbounds.
     if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
         !DL.typeSizeEqualsStoreSize(Load->getType()) ||
-        !canScalarizeAccess(VecTy, Idx, Load, AC) ||
-        SrcAddr != SI->getPointerOperand()->stripPointerCasts() ||
+        SrcAddr != SI->getPointerOperand()->stripPointerCasts())
+      return false;
+
+    auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC);
+    if (ScalarizableIdx.isUnsafe() ||
         isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
                              MemoryLocation::get(SI), AA))
       return false;
 
+    if (ScalarizableIdx.isSafeWithFreeze())
+      ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
     Value *GEP = Builder.CreateInBoundsGEP(
         SI->getValueOperand()->getType(), SI->getPointerOperand(),
         {ConstantInt::get(Idx->getType(), 0), Idx});
@@ -912,8 +987,11 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
     else if (LastCheckedInst->comesBefore(UI))
       LastCheckedInst = UI;
 
-    if (!canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC))
+    auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC);
+    if (!ScalarIdx.isSafe()) {
+      // TODO: Freeze index if it is safe to do so.
       return false;
+    }
 
     auto *Index = dyn_cast<ConstantInt>(UI->getOperand(1));
     OriginalCost +=

diff  --git a/llvm/test/Transforms/VectorCombine/load-insert-store.ll b/llvm/test/Transforms/VectorCombine/load-insert-store.ll
index 5c3f1659e35fe..636ab238c17cf 100644
--- a/llvm/test/Transforms/VectorCombine/load-insert-store.ll
+++ b/llvm/test/Transforms/VectorCombine/load-insert-store.ll
@@ -310,10 +310,10 @@ entry:
 define void @insert_store_nonconst_index_known_valid_by_and_but_may_be_poison(<16 x i8>* %q, i8 zeroext %s, i32 %idx) {
 ; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_and_but_may_be_poison(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i32 [[IDX:%.*]], 7
-; CHECK-NEXT:    [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]]
-; CHECK-NEXT:    store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze i32 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = and i32 [[TMP0]], 7
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]]
+; CHECK-NEXT:    store i8 [[S:%.*]], i8* [[TMP1]], align 1
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -413,10 +413,10 @@ entry:
 define void @insert_store_nonconst_index_known_valid_by_urem_but_may_be_poison(<16 x i8>* %q, i8 zeroext %s, i32 %idx) {
 ; CHECK-LABEL: @insert_store_nonconst_index_known_valid_by_urem_but_may_be_poison(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, <16 x i8>* [[Q:%.*]], align 16
-; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i32 [[IDX:%.*]], 16
-; CHECK-NEXT:    [[VECINS:%.*]] = insertelement <16 x i8> [[TMP0]], i8 [[S:%.*]], i32 [[IDX_CLAMPED]]
-; CHECK-NEXT:    store <16 x i8> [[VECINS]], <16 x i8>* [[Q]], align 16
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze i32 [[IDX:%.*]]
+; CHECK-NEXT:    [[IDX_CLAMPED:%.*]] = urem i32 [[TMP0]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds <16 x i8>, <16 x i8>* [[Q:%.*]], i32 0, i32 [[IDX_CLAMPED]]
+; CHECK-NEXT:    store i8 [[S:%.*]], i8* [[TMP1]], align 1
 ; CHECK-NEXT:    ret void
 ;
 entry:


        


More information about the llvm-commits mailing list