[llvm] 087bb0f - Allow scalable vectors in computeKnownBits

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 30 09:00:46 PDT 2022


Author: Philip Reames
Date: 2022-10-30T08:44:37-07:00
New Revision: 087bb0f1fe1c5a3f17eeea3537218b2be5d18f34

URL: https://github.com/llvm/llvm-project/commit/087bb0f1fe1c5a3f17eeea3537218b2be5d18f34
DIFF: https://github.com/llvm/llvm-project/commit/087bb0f1fe1c5a3f17eeea3537218b2be5d18f34.diff

LOG: Allow scalable vectors in computeKnownBits

This extends the computeKnownBits analysis to support scalable vectors. The critical detail is in deciding how to represent the demanded elements of a vector whose length is unknown at compile time.

For this patch, I adopt the convention that we track one bit which corresponds to all lanes. That is, that bit is implicitly broadcast to all lanes of the scalable vector, resulting in all lanes being demanded. This is the same convention we use in getSplatValue in SelectionDAG.

Note that this convention doesn't actually impact much. Most of the code is agnostic to the interpretation of the demanded elements, and the few cases which actually care need case-by-case handling anyway. In this patch, I just bail out of those cases.

A prior patch (D128159) proposed using a different convention in SDAG. I don't see any strong reason to prefer one scheme over the other, so I propose we go with this one as it's conceptually the simplest. Getting known and demanded bit optimizations unblocked at all is a significant win.

I've locally implemented this scheme in reasonably large parts of ValueTracking.cpp and the SelectionDAG equivalents, and have not hit any blockers. If this is approved, I plan to post a series of patches plumbing this through all the relevant parts.

In the discussion on that patch, a preference was expressed for introducing some form of abstraction around the demanded elements. I'll note that I've played with several variations on that idea locally, and have yet to find anything which results in more readable code. If anyone has concrete ideas in this area, I'm happy to explore in follow-up patches. I'd strongly prefer to be making API changes in an NFC manner with tests in place.

Differential Revision: https://reviews.llvm.org/D136470

Added: 
    

Modified: 
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/test/Transforms/InstCombine/add.ll
    llvm/test/Transforms/InstCombine/intrinsics.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 563929a9771c7..64594d9012440 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -166,10 +166,11 @@ static const Instruction *safeCxtI(const Value *V1, const Value *V2, const Instr
 static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
                                    const APInt &DemandedElts,
                                    APInt &DemandedLHS, APInt &DemandedRHS) {
-  // The length of scalable vectors is unknown at compile time, thus we
-  // cannot check their values
-  if (isa<ScalableVectorType>(Shuf->getType()))
-    return false;
+  if (isa<ScalableVectorType>(Shuf->getType())) {
+    assert(DemandedElts == APInt(1,1));
+    DemandedLHS = DemandedRHS = DemandedElts;
+    return true;
+  }
 
   int NumElts =
       cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
@@ -206,13 +207,9 @@ static void computeKnownBits(const Value *V, const APInt &DemandedElts,
 
 static void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
                              const Query &Q) {
-  // FIXME: We currently have no way to represent the DemandedElts of a scalable
-  // vector
-  if (isa<ScalableVectorType>(V->getType())) {
-    Known.resetAll();
-    return;
-  }
-
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes.  This means
+  // that all lanes in a scalable vector are considered demanded.
   auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
   APInt DemandedElts =
       FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
@@ -1238,7 +1235,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
     // Look through a cast from narrow vector elements to wider type.
     // Examples: v4i32 -> v2i64, v3i8 -> v24
     unsigned SubBitWidth = SrcVecTy->getScalarSizeInBits();
-    if (BitWidth % SubBitWidth == 0) {
+    if (BitWidth % SubBitWidth == 0 && !isa<ScalableVectorType>(I->getType())) {
       // Known bits are automatically intersected across demanded elements of a
       // vector. So for example, if a bit is computed as known zero, it must be
       // zero across all demanded elements of the vector.
@@ -1832,6 +1829,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
     break;
   }
   case Instruction::InsertElement: {
+    if (isa<ScalableVectorType>(I->getType())) {
+      Known.resetAll();
+      return;
+    }
     const Value *Vec = I->getOperand(0);
     const Value *Elt = I->getOperand(1);
     auto *CIdx = dyn_cast<ConstantInt>(I->getOperand(2));
@@ -1948,9 +1949,8 @@ KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) {
 /// for all of the demanded elements in the vector specified by DemandedElts.
 void computeKnownBits(const Value *V, const APInt &DemandedElts,
                       KnownBits &Known, unsigned Depth, const Query &Q) {
-  if (!DemandedElts || isa<ScalableVectorType>(V->getType())) {
-    // No demanded elts or V is a scalable vector, better to assume we don't
-    // know anything.
+  if (!DemandedElts) {
+    // No demanded elts, better to assume we don't know anything.
     Known.resetAll();
     return;
   }
@@ -1971,7 +1971,7 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
         "DemandedElt width should equal the fixed vector number of elements");
   } else {
     assert(DemandedElts == APInt(1, 1) &&
-           "DemandedElt width should be 1 for scalars");
+           "DemandedElt width should be 1 for scalars or scalable vectors");
   }
 
   Type *ScalarTy = Ty->getScalarType();
@@ -1998,6 +1998,7 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
   // Handle a constant vector by taking the intersection of the known bits of
   // each element.
   if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(V)) {
+    assert(!isa<ScalableVectorType>(V->getType()));
     // We know that CDV must be a vector of integers. Take the intersection of
     // each element.
     Known.Zero.setAllBits(); Known.One.setAllBits();
@@ -2012,6 +2013,7 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
   }
 
   if (const auto *CV = dyn_cast<ConstantVector>(V)) {
+    assert(!isa<ScalableVectorType>(V->getType()));
     // We know that CV must be a vector of integers. Take the intersection of
     // each element.
     Known.Zero.setAllBits(); Known.One.setAllBits();

diff  --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index b5e7746bdcf32..3706720c90c38 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -2338,7 +2338,7 @@ define i8 @mul_not_negpow2(i8 %x) {
 define <vscale x 1 x i32> @add_to_or_scalable(<vscale x 1 x i32> %in) {
 ; CHECK-LABEL: @add_to_or_scalable(
 ; CHECK-NEXT:    [[SHL:%.*]] = shl <vscale x 1 x i32> [[IN:%.*]], shufflevector (<vscale x 1 x i32> insertelement (<vscale x 1 x i32> poison, i32 1, i32 0), <vscale x 1 x i32> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT:    [[ADD:%.*]] = add <vscale x 1 x i32> [[SHL]], shufflevector (<vscale x 1 x i32> insertelement (<vscale x 1 x i32> poison, i32 1, i32 0), <vscale x 1 x i32> poison, <vscale x 1 x i32> zeroinitializer)
+; CHECK-NEXT:    [[ADD:%.*]] = or <vscale x 1 x i32> [[SHL]], shufflevector (<vscale x 1 x i32> insertelement (<vscale x 1 x i32> poison, i32 1, i32 0), <vscale x 1 x i32> poison, <vscale x 1 x i32> zeroinitializer)
 ; CHECK-NEXT:    ret <vscale x 1 x i32> [[ADD]]
 ;
   %shl = shl <vscale x 1 x i32> %in, shufflevector (<vscale x 1 x i32> insertelement (<vscale x 1 x i32> poison, i32 1, i32 0), <vscale x 1 x i32> poison, <vscale x 1 x i32> zeroinitializer)

diff  --git a/llvm/test/Transforms/InstCombine/intrinsics.ll b/llvm/test/Transforms/InstCombine/intrinsics.ll
index 143136e3c77ab..dd9fd68660604 100644
--- a/llvm/test/Transforms/InstCombine/intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/intrinsics.ll
@@ -127,9 +127,7 @@ define <2 x i1> @cttz_knownbits_vec(<2 x i32> %arg) {
 
 define <vscale x 1 x i1> @cttz_knownbits_scalable_vec(<vscale x 1 x i32> %arg) {
 ; CHECK-LABEL: @cttz_knownbits_scalable_vec(
-; CHECK-NEXT:    [[OR:%.*]] = and <vscale x 1 x i32> [[ARG:%.*]], shufflevector (<vscale x 1 x i32> insertelement (<vscale x 1 x i32> poison, i32 27, i32 0), <vscale x 1 x i32> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq <vscale x 1 x i32> [[OR]], shufflevector (<vscale x 1 x i32> insertelement (<vscale x 1 x i32> poison, i32 20, i32 0), <vscale x 1 x i32> poison, <vscale x 1 x i32> zeroinitializer)
-; CHECK-NEXT:    ret <vscale x 1 x i1> [[RES]]
+; CHECK-NEXT:    ret <vscale x 1 x i1> zeroinitializer
 ;
   %or = or <vscale x 1 x i32> %arg, shufflevector (<vscale x 1 x i32> insertelement (<vscale x 1 x i32> poison, i32 4, i32 0), <vscale x 1 x i32> poison, <vscale x 1 x i32> zeroinitializer)
   %cnt = call <vscale x 1 x i32> @llvm.cttz.nxv1i32(<vscale x 1 x i32> %or, i1 true) nounwind readnone


        


More information about the llvm-commits mailing list