[llvm] 968980e - [GlobalISel] NFC: Change LLT::scalarOrVector to take ElementCount.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 25 03:27:19 PDT 2021


Author: Sander de Smalen
Date: 2021-06-25T11:26:16+01:00
New Revision: 968980ef08955ee03f406e8078089b1f2eb571e3

URL: https://github.com/llvm/llvm-project/commit/968980ef08955ee03f406e8078089b1f2eb571e3
DIFF: https://github.com/llvm/llvm-project/commit/968980ef08955ee03f406e8078089b1f2eb571e3.diff

LOG: [GlobalISel] NFC: Change LLT::scalarOrVector to take ElementCount.

Reviewed By: aemerson

Differential Revision: https://reviews.llvm.org/D104452

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
    llvm/include/llvm/Support/LowLevelTypeImpl.h
    llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
    llvm/lib/CodeGen/GlobalISel/Utils.cpp
    llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
    llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
    llvm/unittests/CodeGen/LowLevelTypeTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 722edb3938666..a7e6d37419a84 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -1038,7 +1038,8 @@ class LegalizeRuleSet {
         },
         [=](const LegalityQuery &Query) {
           LLT VecTy = Query.Types[TypeIdx];
-          LLT NewTy = LLT::scalarOrVector(MaxElements, VecTy.getElementType());
+          LLT NewTy = LLT::scalarOrVector(ElementCount::getFixed(MaxElements),
+                                          VecTy.getElementType());
           return std::make_pair(TypeIdx, NewTy);
         });
   }

diff  --git a/llvm/include/llvm/Support/LowLevelTypeImpl.h b/llvm/include/llvm/Support/LowLevelTypeImpl.h
index fb5aff66b0a4f..674d8f4d2b138 100644
--- a/llvm/include/llvm/Support/LowLevelTypeImpl.h
+++ b/llvm/include/llvm/Support/LowLevelTypeImpl.h
@@ -96,14 +96,12 @@ class LLT {
     return vector(ElementCount::getScalable(MinNumElements), ScalarTy);
   }
 
-  static LLT scalarOrVector(uint16_t NumElements, LLT ScalarTy) {
-    // FIXME: Migrate interface to use ElementCount
-    return NumElements == 1 ? ScalarTy
-                            : LLT::fixed_vector(NumElements, ScalarTy);
+  static LLT scalarOrVector(ElementCount EC, LLT ScalarTy) {
+    return EC.isScalar() ? ScalarTy : LLT::vector(EC, ScalarTy);
   }
 
-  static LLT scalarOrVector(uint16_t NumElements, unsigned ScalarSize) {
-    return scalarOrVector(NumElements, LLT::scalar(ScalarSize));
+  static LLT scalarOrVector(ElementCount EC, unsigned ScalarSize) {
+    return scalarOrVector(EC, LLT::scalar(ScalarSize));
   }
 
   explicit LLT(bool isPointer, bool isVector, ElementCount EC,
@@ -189,7 +187,8 @@ class LLT {
   LLT changeNumElements(unsigned NewNumElts) const {
     assert((!isVector() || !isScalable()) &&
            "Cannot use changeNumElements on a scalable vector");
-    return LLT::scalarOrVector(NewNumElts, getScalarType());
+    return LLT::scalarOrVector(ElementCount::getFixed(NewNumElts),
+                               getScalarType());
   }
 
   /// Return a type that is \p Factor times smaller. Reduces the number of
@@ -199,7 +198,8 @@ class LLT {
     assert(Factor != 1);
     if (isVector()) {
       assert(getNumElements() % Factor == 0);
-      return scalarOrVector(getNumElements() / Factor, getElementType());
+      return scalarOrVector(getElementCount().divideCoefficientBy(Factor),
+                            getElementType());
     }
 
     assert(getSizeInBits() % Factor == 0);

diff  --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 340fec0993984..a19fafdd4fbe4 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -60,7 +60,8 @@ getNarrowTypeBreakDown(LLT OrigTy, LLT NarrowTy, LLT &LeftoverTy) {
     unsigned EltSize = OrigTy.getScalarSizeInBits();
     if (LeftoverSize % EltSize != 0)
       return {-1, -1};
-    LeftoverTy = LLT::scalarOrVector(LeftoverSize / EltSize, EltSize);
+    LeftoverTy = LLT::scalarOrVector(
+        ElementCount::getFixed(LeftoverSize / EltSize), EltSize);
   } else {
     LeftoverTy = LLT::scalar(LeftoverSize);
   }
@@ -178,7 +179,8 @@ bool LegalizerHelper::extractParts(Register Reg, LLT RegTy,
     unsigned EltSize = MainTy.getScalarSizeInBits();
     if (LeftoverSize % EltSize != 0)
       return false;
-    LeftoverTy = LLT::scalarOrVector(LeftoverSize / EltSize, EltSize);
+    LeftoverTy = LLT::scalarOrVector(
+        ElementCount::getFixed(LeftoverSize / EltSize), EltSize);
   } else {
     LeftoverTy = LLT::scalar(LeftoverSize);
   }
@@ -2572,7 +2574,8 @@ LegalizerHelper::bitcastExtractVectorElt(MachineInstr &MI, unsigned TypeIdx,
 
     // Type of the intermediate result vector.
     const unsigned NewEltsPerOldElt = NewNumElts / OldNumElts;
-    LLT MidTy = LLT::scalarOrVector(NewEltsPerOldElt, NewEltTy);
+    LLT MidTy =
+        LLT::scalarOrVector(ElementCount::getFixed(NewEltsPerOldElt), NewEltTy);
 
     auto NewEltsPerOldEltK = MIRBuilder.buildConstant(IdxTy, NewEltsPerOldElt);
 
@@ -3300,9 +3303,6 @@ LegalizerHelper::fewerElementsVectorMultiEltType(
     return UnableToLegalize;
 
   const LLT NarrowTy0 = NarrowTyArg;
-  const unsigned NewNumElts =
-      NarrowTy0.isVector() ? NarrowTy0.getNumElements() : 1;
-
   const Register DstReg = MI.getOperand(0).getReg();
   LLT DstTy = MRI.getType(DstReg);
   LLT LeftoverTy0;
@@ -3322,7 +3322,9 @@ LegalizerHelper::fewerElementsVectorMultiEltType(
   for (unsigned I = 1, E = MI.getNumOperands(); I != E; ++I) {
     Register SrcReg = MI.getOperand(I).getReg();
     LLT SrcTyI = MRI.getType(SrcReg);
-    LLT NarrowTyI = LLT::scalarOrVector(NewNumElts, SrcTyI.getScalarType());
+    const auto NewEC = NarrowTy0.isVector() ? NarrowTy0.getElementCount()
+                                            : ElementCount::getFixed(1);
+    LLT NarrowTyI = LLT::scalarOrVector(NewEC, SrcTyI.getScalarType());
     LLT LeftoverTyI;
 
     // Split this operand into the requested typed registers, and any leftover
@@ -3685,7 +3687,7 @@ LegalizerHelper::fewerElementsVectorMulo(MachineInstr &MI, unsigned TypeIdx,
 
   LLT ElementType = SrcTy.getElementType();
   LLT OverflowElementTy = MRI.getType(Overflow).getElementType();
-  const int NumResult = SrcTy.getNumElements();
+  const ElementCount NumResult = SrcTy.getElementCount();
   LLT GCDTy = getGCDType(SrcTy, NarrowTy);
 
   // Unmerge the operands to smaller parts of GCD type.
@@ -3693,7 +3695,7 @@ LegalizerHelper::fewerElementsVectorMulo(MachineInstr &MI, unsigned TypeIdx,
   auto UnmergeRHS = MIRBuilder.buildUnmerge(GCDTy, RHS);
 
   const int NumOps = UnmergeLHS->getNumOperands() - 1;
-  const int PartsPerUnmerge = NumResult / NumOps;
+  const ElementCount PartsPerUnmerge = NumResult.divideCoefficientBy(NumOps);
   LLT OverflowTy = LLT::scalarOrVector(PartsPerUnmerge, OverflowElementTy);
   LLT ResultTy = LLT::scalarOrVector(PartsPerUnmerge, ElementType);
 
@@ -3711,7 +3713,7 @@ LegalizerHelper::fewerElementsVectorMulo(MachineInstr &MI, unsigned TypeIdx,
 
   LLT ResultLCMTy = buildLCMMergePieces(SrcTy, NarrowTy, GCDTy, ResultParts);
   LLT OverflowLCMTy =
-      LLT::scalarOrVector(ResultLCMTy.getNumElements(), OverflowElementTy);
+      LLT::scalarOrVector(ResultLCMTy.getElementCount(), OverflowElementTy);
 
   // Recombine the pieces to the original result and overflow registers.
   buildWidenedRemergeToDst(Result, ResultLCMTy, ResultParts);
@@ -3957,8 +3959,6 @@ LegalizerHelper::reduceOperationWidth(MachineInstr &MI, unsigned int TypeIdx,
   SmallVector<Register, 8> ExtractedRegs[3];
   SmallVector<Register, 8> Parts;
 
-  unsigned NarrowElts = NarrowTy.isVector() ? NarrowTy.getNumElements() : 1;
-
   // Break down all the sources into NarrowTy pieces we can operate on. This may
   // involve creating merges to a wider type, padded with undef.
   for (int I = 0; I != NumSrcOps; ++I) {
@@ -3979,7 +3979,9 @@ LegalizerHelper::reduceOperationWidth(MachineInstr &MI, unsigned int TypeIdx,
         SrcReg = MIRBuilder.buildBitcast(SrcTy, SrcReg).getReg(0);
       }
     } else {
-      OpNarrowTy = LLT::scalarOrVector(NarrowElts, SrcTy.getScalarType());
+      auto NarrowEC = NarrowTy.isVector() ? NarrowTy.getElementCount()
+                                          : ElementCount::getFixed(1);
+      OpNarrowTy = LLT::scalarOrVector(NarrowEC, SrcTy.getScalarType());
     }
 
     LLT GCDTy = extractGCDType(ExtractedRegs[I], SrcTy, OpNarrowTy, SrcReg);

diff  --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 7b4780dd04c45..73e42927f8e63 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -816,7 +816,7 @@ LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
       if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
         int GCD = greatestCommonDivisor(OrigTy.getNumElements(),
                                         TargetTy.getNumElements());
-        return LLT::scalarOrVector(GCD, OrigElt);
+        return LLT::scalarOrVector(ElementCount::getFixed(GCD), OrigElt);
       }
     } else {
       // If the source is a vector of pointers, return a pointer element.

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 47e63c6cf4d42..9eeb7954ac7af 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -107,7 +107,9 @@ static LegalizeMutation fewerEltsToSize64Vector(unsigned TypeIdx) {
     unsigned Size = Ty.getSizeInBits();
     unsigned Pieces = (Size + 63) / 64;
     unsigned NewNumElts = (Ty.getNumElements() + 1) / Pieces;
-    return std::make_pair(TypeIdx, LLT::scalarOrVector(NewNumElts, EltTy));
+    return std::make_pair(
+        TypeIdx,
+        LLT::scalarOrVector(ElementCount::getFixed(NewNumElts), EltTy));
   };
 }
 
@@ -139,7 +141,7 @@ static LLT getBitcastRegisterType(const LLT Ty) {
     return LLT::scalar(Size);
   }
 
-  return LLT::scalarOrVector(Size / 32, 32);
+  return LLT::scalarOrVector(ElementCount::getFixed(Size / 32), 32);
 }
 
 static LegalizeMutation bitcastToRegisterType(unsigned TypeIdx) {
@@ -154,7 +156,8 @@ static LegalizeMutation bitcastToVectorElement32(unsigned TypeIdx) {
     const LLT Ty = Query.Types[TypeIdx];
     unsigned Size = Ty.getSizeInBits();
     assert(Size % 32 == 0);
-    return std::make_pair(TypeIdx, LLT::scalarOrVector(Size / 32, 32));
+    return std::make_pair(
+        TypeIdx, LLT::scalarOrVector(ElementCount::getFixed(Size / 32), 32));
   };
 }
 
@@ -1214,7 +1217,8 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
 
                 if (MaxSize % EltSize == 0) {
                   return std::make_pair(
-                    0, LLT::scalarOrVector(MaxSize / EltSize, EltTy));
+                      0, LLT::scalarOrVector(
+                             ElementCount::getFixed(MaxSize / EltSize), EltTy));
                 }
 
                 unsigned NumPieces = Query.MMODescrs[0].SizeInBits / MaxSize;
@@ -1242,7 +1246,8 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
                 // should be OK, since the new parts will be further legalized.
                 unsigned FloorSize = PowerOf2Floor(DstSize);
                 return std::make_pair(
-                  0, LLT::scalarOrVector(FloorSize / EltSize, EltTy));
+                    0, LLT::scalarOrVector(
+                           ElementCount::getFixed(FloorSize / EltSize), EltTy));
               }
 
               // Need to split because of alignment.
@@ -4448,14 +4453,16 @@ bool AMDGPULegalizerInfo::legalizeImageIntrinsic(
   LLT RegTy;
 
   if (IsD16 && ST.hasUnpackedD16VMem()) {
-    RoundedTy = LLT::scalarOrVector(AdjustedNumElts, 32);
+    RoundedTy =
+        LLT::scalarOrVector(ElementCount::getFixed(AdjustedNumElts), 32);
     TFETy = LLT::fixed_vector(AdjustedNumElts + 1, 32);
     RegTy = S32;
   } else {
     unsigned EltSize = EltTy.getSizeInBits();
     unsigned RoundedElts = (AdjustedTy.getSizeInBits() + 31) / 32;
     unsigned RoundedSize = 32 * RoundedElts;
-    RoundedTy = LLT::scalarOrVector(RoundedSize / EltSize, EltSize);
+    RoundedTy = LLT::scalarOrVector(
+        ElementCount::getFixed(RoundedSize / EltSize), EltSize);
     TFETy = LLT::fixed_vector(RoundedSize / 32 + 1, S32);
     RegTy = !IsTFE && EltSize == 16 ? V2S16 : S32;
   }

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index e5dafbfb43eaf..4fbd4618e4317 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -675,12 +675,13 @@ static void setRegsToType(MachineRegisterInfo &MRI, ArrayRef<Register> Regs,
 
 static LLT getHalfSizedType(LLT Ty) {
   if (Ty.isVector()) {
-    assert(Ty.getNumElements() % 2 == 0);
-    return LLT::scalarOrVector(Ty.getNumElements() / 2, Ty.getElementType());
+    assert(Ty.getElementCount().isKnownMultipleOf(2));
+    return LLT::scalarOrVector(Ty.getElementCount().divideCoefficientBy(2),
+                               Ty.getElementType());
   }
 
-  assert(Ty.getSizeInBits() % 2 == 0);
-  return LLT::scalar(Ty.getSizeInBits() / 2);
+  assert(Ty.getScalarSizeInBits() % 2 == 0);
+  return LLT::scalar(Ty.getScalarSizeInBits() / 2);
 }
 
 /// Legalize instruction \p MI where operands in \p OpIndices must be SGPRs. If
@@ -1123,8 +1124,8 @@ static std::pair<LLT, LLT> splitUnequalType(LLT Ty, unsigned FirstSize) {
   unsigned FirstPartNumElts = FirstSize / EltSize;
   unsigned RemainderElts = (TotalSize - FirstSize) / EltSize;
 
-  return {LLT::scalarOrVector(FirstPartNumElts, EltTy),
-          LLT::scalarOrVector(RemainderElts, EltTy)};
+  return {LLT::scalarOrVector(ElementCount::getFixed(FirstPartNumElts), EltTy),
+          LLT::scalarOrVector(ElementCount::getFixed(RemainderElts), EltTy)};
 }
 
 static LLT widen96To128(LLT Ty) {

diff  --git a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
index bac203e180e25..28eeb52fff7c7 100644
--- a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
+++ b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
@@ -100,17 +100,25 @@ TEST(LowLevelTypeTest, Vector) {
 
 TEST(LowLevelTypeTest, ScalarOrVector) {
   // Test version with number of bits for scalar type.
-  EXPECT_EQ(LLT::scalar(32), LLT::scalarOrVector(1, 32));
-  EXPECT_EQ(LLT::fixed_vector(2, 32), LLT::scalarOrVector(2, 32));
+  EXPECT_EQ(LLT::scalar(32),
+            LLT::scalarOrVector(ElementCount::getFixed(1), 32));
+  EXPECT_EQ(LLT::fixed_vector(2, 32),
+            LLT::scalarOrVector(ElementCount::getFixed(2), 32));
+  EXPECT_EQ(LLT::scalable_vector(1, 32),
+            LLT::scalarOrVector(ElementCount::getScalable(1), 32));
 
   // Test version with LLT for scalar type.
-  EXPECT_EQ(LLT::scalar(32), LLT::scalarOrVector(1, LLT::scalar(32)));
-  EXPECT_EQ(LLT::fixed_vector(2, 32), LLT::scalarOrVector(2, LLT::scalar(32)));
+  EXPECT_EQ(LLT::scalar(32),
+            LLT::scalarOrVector(ElementCount::getFixed(1), LLT::scalar(32)));
+  EXPECT_EQ(LLT::fixed_vector(2, 32),
+            LLT::scalarOrVector(ElementCount::getFixed(2), LLT::scalar(32)));
 
   // Test with pointer elements.
-  EXPECT_EQ(LLT::pointer(1, 32), LLT::scalarOrVector(1, LLT::pointer(1, 32)));
-  EXPECT_EQ(LLT::fixed_vector(2, LLT::pointer(1, 32)),
-            LLT::scalarOrVector(2, LLT::pointer(1, 32)));
+  EXPECT_EQ(LLT::pointer(1, 32), LLT::scalarOrVector(ElementCount::getFixed(1),
+                                                     LLT::pointer(1, 32)));
+  EXPECT_EQ(
+      LLT::fixed_vector(2, LLT::pointer(1, 32)),
+      LLT::scalarOrVector(ElementCount::getFixed(2), LLT::pointer(1, 32)));
 }
 
 TEST(LowLevelTypeTest, ChangeElementType) {


        


More information about the llvm-commits mailing list