[llvm] [AMDGPU] Handle natively unsupported types in addrspace(7) lowering (PR #110572)

Krzysztof Drewniak via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 27 16:38:00 PDT 2024


================
@@ -576,6 +597,547 @@ bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) {
   return true;
 }
 
+namespace {
+/// Convert loads/stores of types that the buffer intrinsics can't handle into
+/// one or more such loads/stores that consist of legal types.
+///
+/// Do this by
+/// 1. Recursing into structs (and arrays that don't share a memory layout with
+/// vectors) since the intrinsics can't handle complex types.
+/// 2. Converting arrays of non-aggregate, byte-sized types into their
+/// corresponding vectors.
+/// 3. Bitcasting unsupported types, namely overly-long scalars and byte
+/// vectors, into vectors of supported types.
+/// 4. Splitting up excessively long reads/writes into multiple operations.
+///
+/// Note that this doesn't handle complex data structures, but, in the future,
+/// the aggregate load splitter from SROA could be refactored to allow for that
+/// case.
+class LegalizeBufferContentTypesVisitor
+    : public InstVisitor<LegalizeBufferContentTypesVisitor, bool> {
+  friend class InstVisitor<LegalizeBufferContentTypesVisitor, bool>;
+
+  IRBuilder<> IRB;
+
+  const DataLayout &DL;
+
+  /// If T is [N x U], where U is a scalar type, return the vector type
+  /// <N x U>, otherwise, return T.
+  Type *scalarArrayTypeAsVector(Type *MaybeArrayType);
+  Value *arrayToVector(Value *V, Type *TargetType, const Twine &Name); // Element-by-element [N x U] -> <N x U>.
+  Value *vectorToArray(Value *V, Type *OrigType, const Twine &Name); // Element-by-element <N x U> -> [N x U].
+
+  // NOTE(review): this comment has no declaration attached. Breaking up the loads of a struct into the loads of its components is done inside visitLoadImpl().
+
+  /// Convert a vector or scalar type that can't be operated on by buffer
+  /// intrinsics to one that would be legal through bitcasts and/or truncation.
+  /// Uses the widest of i32, i16, or i8 where possible.
+  Type *legalNonAggregateFor(Type *T);
+  Value *makeLegalNonAggregate(Value *V, Type *TargetType, const Twine &Name);
+  Value *makeIllegalNonAggregate(Value *V, Type *OrigType, const Twine &Name);
+
+  struct VecSlice {
+    uint64_t Index;
+    uint64_t Length;
+    VecSlice(uint64_t Index, uint64_t Length) : Index(Index), Length(Length) {}
+  };
+  // Return the [index, length] pairs into which `T` needs to be cut to form
+  // legal buffer load or store operations. Clears `Slices`. Creates an empty
+  // `Slices` for non-vector inputs and creates one slice if no slicing will be
+  // needed.
+  void getVecSlices(Type *T, SmallVectorImpl<VecSlice> &Slices);
+
+  Value *extractSlice(Value *Vec, VecSlice S, const Twine &Name);
+  Value *insertSlice(Value *Whole, Value *Part, VecSlice S, const Twine &Name);
+
+  // In most cases, return `LegalType`. However, when given an input that would
+  // normally be a legal type for the buffer intrinsics to return but that isn't
+  // hooked up through SelectionDAG, return a type of the same width that can be
+  // used with the relevant intrinsics. Specifically, handle the cases:
+  // - <1 x T> => T for all T
+  // - <N x i8> <=> i16, i32, 2xi32, 4xi32 (as needed)
+  // - <N x T> where T is under 32 bits and the total size is 96 bits <=> <3 x
+  // i32>
+  Type *intrinsicTypeFor(Type *LegalType);
+
+  bool visitLoadImpl(LoadInst &OrigLI, Type *PartType,
+                     SmallVectorImpl<uint32_t> &AggIdxs, uint64_t AggByteOffset,
+                     Value *&Result, const Twine &Name);
+  // Return value is (Changed, ModifiedInPlace)
+  std::pair<bool, bool> visitStoreImpl(StoreInst &OrigSI, Type *PartType,
+                                       SmallVectorImpl<uint32_t> &AggIdxs,
+                                       uint64_t AggByteOffset,
+                                       const Twine &Name);
+
+  bool visitInstruction(Instruction &I) { return false; }
+  bool visitLoadInst(LoadInst &LI);
+  bool visitStoreInst(StoreInst &SI);
+
+public:
+  LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx)
+      : IRB(Ctx), DL(DL) {}
+  bool processFunction(Function &F);
+};
+} // namespace
+
+/// Map an array of scalars onto the vector type with the same element type and
+/// element count; any non-array type is returned unchanged.
+///
+/// Arrays whose element is not a single-value type (or is itself a vector),
+/// and arrays whose size differs from their store size, are fatal errors: the
+/// caller is expected to have recursed into such arrays before asking for a
+/// vector equivalent.
+Type *LegalizeBufferContentTypesVisitor::scalarArrayTypeAsVector(Type *T) {
+  ArrayType *AT = dyn_cast<ArrayType>(T);
+  if (!AT)
+    return T;
+  Type *ET = AT->getElementType();
+  if (!ET->isSingleValueType() || isa<VectorType>(ET))
+    report_fatal_error("loading non-scalar arrays from buffer fat pointers "
+                       "should have recursed");
+  if (!DL.typeSizeEqualsStoreSize(AT))
+    report_fatal_error(
+        "loading padded arrays from buffer fat pointers should have recursed");
+  return FixedVectorType::get(ET, AT->getNumElements());
+}
+
+/// Repack the array value `V` into a vector of type `TargetType`, copying the
+/// elements across one at a time on top of an initially-poison vector.
+Value *LegalizeBufferContentTypesVisitor::arrayToVector(Value *V,
+                                                        Type *TargetType,
+                                                        const Twine &Name) {
+  Value *Acc = PoisonValue::get(TargetType);
+  const unsigned NumElems =
+      cast<FixedVectorType>(TargetType)->getNumElements();
+  for (unsigned Idx = 0; Idx != NumElems; ++Idx) {
+    Value *Scalar =
+        IRB.CreateExtractValue(V, Idx, Name + ".elem." + Twine(Idx));
+    Acc = IRB.CreateInsertElement(Acc, Scalar, Idx,
+                                  Name + ".as.vec." + Twine(Idx));
+  }
+  return Acc;
+}
+
+/// Inverse of arrayToVector(): rebuild a value of array type `OrigType` from
+/// the vector `V`, one element at a time, starting from a poison array.
+Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
+                                                        Type *OrigType,
+                                                        const Twine &Name) {
+  Value *Acc = PoisonValue::get(OrigType);
+  const unsigned NumElems = cast<ArrayType>(OrigType)->getNumElements();
+  for (unsigned Idx = 0; Idx != NumElems; ++Idx) {
+    Value *Scalar =
+        IRB.CreateExtractElement(V, Idx, Name + ".elem." + Twine(Idx));
+    Acc = IRB.CreateInsertValue(Acc, Scalar, Idx,
+                                Name + ".as.array." + Twine(Idx));
+  }
+  return Acc;
+}
+
+/// Pick the type that a scalar or vector `T` should be converted to before it
+/// is handed to buffer loads/stores.
+///
+/// Types whose size is not a whole number of bytes are first widened to an
+/// integer of their store size. Pointers, scalable vectors, and [vectors of]
+/// 16/32/64/128-bit elements are returned as they are. Everything else becomes
+/// an integer, or a vector of integers, built from the widest of i32, i16, or
+/// i8 that evenly divides the store size.
+Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
+  const TypeSize StoreBits = DL.getTypeStoreSizeInBits(T);
+  // Non-byte-sized values are implicitly zero-extended to the next byte.
+  if (!DL.typeSizeEqualsStoreSize(T))
+    T = IRB.getIntNTy(StoreBits.getFixedValue());
+
+  auto *FixedVT = dyn_cast<FixedVectorType>(T);
+  Type *ScalarTy = FixedVT ? FixedVT->getElementType() : T;
+
+  // Pointers are always big enough, and scalable vectors are let through so
+  // that they fail in codegen.
+  if (isa<PointerType, ScalableVectorType>(ScalarTy))
+    return T;
+
+  // [vectors of] anything that's 16/32/64/128 bits can be cast and split into
+  // legal buffer operations.
+  const unsigned ScalarBits = DL.getTypeSizeInBits(ScalarTy).getFixedValue();
+  const bool IsDirectlyUsable =
+      ScalarBits >= 16 && ScalarBits <= 128 && isPowerOf2_32(ScalarBits);
+  if (IsDirectlyUsable)
+    return T;
+
+  // Otherwise reinterpret the whole value as chunks of the widest integer type
+  // that divides its store size.
+  Type *ChunkTy = StoreBits.isKnownMultipleOf(32)   ? IRB.getInt32Ty()
+                  : StoreBits.isKnownMultipleOf(16) ? IRB.getInt16Ty()
+                                                    : IRB.getInt8Ty();
+  const unsigned NumChunks =
+      StoreBits.getFixedValue() / ChunkTy->getIntegerBitWidth();
+  if (NumChunks != 1)
+    return FixedVectorType::get(ChunkTy, NumChunks);
+  return ChunkTy;
+}
+
+/// Convert `V` to `TargetType`, the legal type chosen for it.
+///
+/// When the two types have the same bit width this is a single bitcast. When
+/// they differ, `V` is first bitcast to an integer of its own width and
+/// zero-extended to an integer of the target width, and only then bitcast to
+/// `TargetType`.
+Value *LegalizeBufferContentTypesVisitor::makeLegalNonAggregate(
+    Value *V, Type *TargetType, const Twine &Name) {
+  TypeSize SourceBits = DL.getTypeSizeInBits(V->getType());
+  TypeSize TargetBits = DL.getTypeSizeInBits(TargetType);
+  if (SourceBits != TargetBits) {
+    Type *ShortScalarTy = IRB.getIntNTy(SourceBits.getFixedValue());
+    Type *ByteScalarTy = IRB.getIntNTy(TargetBits.getFixedValue());
+    Value *AsScalar = IRB.CreateBitCast(V, ShortScalarTy, Name + ".as.scalar");
+    V = IRB.CreateZExt(AsScalar, ByteScalarTy, Name + ".zext");
+  }
+  return IRB.CreateBitCast(V, TargetType, Name + ".legal");
+}
+
+/// Convert `V`, a value of a legal type, back to `OrigType`.
+///
+/// Equal-width types need only a bitcast. When the widths differ, `V` is
+/// bitcast to an integer of its own width, truncated to an integer of the
+/// original width, and then bitcast to `OrigType`.
+Value *LegalizeBufferContentTypesVisitor::makeIllegalNonAggregate(
+    Value *V, Type *OrigType, const Twine &Name) {
+  const TypeSize LegalBits = DL.getTypeSizeInBits(V->getType());
+  const TypeSize OrigBits = DL.getTypeSizeInBits(OrigType);
+  if (LegalBits == OrigBits)
+    return IRB.CreateBitCast(V, OrigType, Name + ".real.ty");
+
+  Type *NarrowIntTy = IRB.getIntNTy(OrigBits.getFixedValue());
+  Type *WideIntTy = IRB.getIntNTy(LegalBits.getFixedValue());
+  Value *WideInt = IRB.CreateBitCast(V, WideIntTy, Name + ".bytes.cast");
+  Value *NarrowInt = IRB.CreateTrunc(WideInt, NarrowIntTy, Name + ".trunc");
+  return IRB.CreateBitCast(NarrowInt, OrigType, Name + ".orig");
+}
+
+/// Return the type to actually use in the emitted load/store for a value of
+/// `LegalType`.
+///
+/// Non-vectors are returned unchanged. For fixed vectors:
+/// - <1 x T> becomes T;
+/// - a 96-bit vector of sub-32-bit elements becomes <3 x i32>;
+/// - <2/4/8/16 x i8> become i16, i32, <2 x i32>, <4 x i32> respectively;
+/// - anything else (including other i8 vector lengths) is returned unchanged.
+Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) {
+  auto *VT = dyn_cast<FixedVectorType>(LegalType);
+  if (!VT)
+    return LegalType;
+  Type *ET = VT->getElementType();
+  if (VT->getNumElements() == 1)
+    return ET;
+  // This check comes before the i8 handling, so <12 x i8> maps to <3 x i32>.
+  if (DL.getTypeSizeInBits(LegalType) == 96 && DL.getTypeSizeInBits(ET) < 32)
+    return FixedVectorType::get(IRB.getInt32Ty(), 3);
+  if (ET->isIntegerTy(8)) {
+    // Single-element vectors were already returned above, so no `case 1` is
+    // needed here.
+    switch (VT->getNumElements()) {
+    default:
+      return LegalType; // Let it crash later
+    case 2:
+      return IRB.getInt16Ty();
+    case 4:
+      return IRB.getInt32Ty();
+    case 8:
+      return FixedVectorType::get(IRB.getInt32Ty(), 2);
+    case 16:
+      return FixedVectorType::get(IRB.getInt32Ty(), 4);
+    }
+  }
+  return LegalType;
+}
+
+/// Compute the [index, length] slices into which the vector type `T` is cut.
+///
+/// `Slices` is always cleared first. Non-vector types produce no slices. For
+/// a fixed vector, slices are chosen greedily from the front, preferring (in
+/// order) as many elements as fit in 4, 3, 2, or 1 32-bit words, then a
+/// 16-bit short, then a byte. If none of those lengths is usable, a single
+/// element is sliced off so that the loop always makes progress.
+void LegalizeBufferContentTypesVisitor::getVecSlices(
+    Type *T, SmallVectorImpl<VecSlice> &Slices) {
+  Slices.clear();
+  auto *VT = dyn_cast<FixedVectorType>(T);
+  if (!VT)
+    return;
+
+  uint64_t ElemBitWidth =
+      DL.getTypeSizeInBits(VT->getElementType()).getFixedValue();
+
+  uint64_t ElemsPer4Words = 128 / ElemBitWidth;
+  uint64_t ElemsPer2Words = ElemsPer4Words / 2;
+  uint64_t ElemsPerWord = ElemsPer2Words / 2;
+  uint64_t ElemsPerShort = ElemsPerWord / 2;
+  uint64_t ElemsPerByte = ElemsPerShort / 2;
+  // If the elements evenly pack into 32-bit words, we can use 3-word stores,
+  // such as for <6 x bfloat> or <3 x i32>, but we can't do this for, for
+  // example, <3 x i64>, since that's not slicing.
+  uint64_t ElemsPer3Words = ElemsPerWord * 3;
+
+  uint64_t TotalElems = VT->getNumElements();
+  uint64_t Index = 0;
+  // Append a slice of `MaybeLen` elements at `Index` if that length is
+  // non-zero and still fits in the vector; report whether it did.
+  auto TrySlice = [&](uint64_t MaybeLen) {
+    if (MaybeLen > 0 && Index + MaybeLen <= TotalElems) {
+      Slices.emplace_back(/*Index=*/Index, /*Length=*/MaybeLen);
+      Index += MaybeLen;
+      return true;
+    }
+    return false;
+  };
+  while (Index < TotalElems) {
+    bool Sliced = TrySlice(ElemsPer4Words) || TrySlice(ElemsPer3Words) ||
+                  TrySlice(ElemsPer2Words) || TrySlice(ElemsPerWord) ||
+                  TrySlice(ElemsPerShort) || TrySlice(ElemsPerByte);
+    // Every candidate can be zero (elements wider than 128 bits) or too long
+    // for what remains. Without a fallback the loop would never terminate, so
+    // peel off one element and let later stages reject the type if needed.
+    if (!Sliced)
+      TrySlice(1);
+  }
+}
+
+/// Pull the elements described by `S` out of `Vec`.
+///
+/// `Vec` itself is returned when it is not a fixed vector or when `S` spans
+/// all of it. A one-element slice is an extractelement (yielding a scalar);
+/// longer slices are a shufflevector selecting the contiguous lanes.
+Value *LegalizeBufferContentTypesVisitor::extractSlice(Value *Vec, VecSlice S,
+                                                       const Twine &Name) {
+  auto *VecVT = dyn_cast<FixedVectorType>(Vec->getType());
+  const bool CoversWholeValue =
+      !VecVT || (S.Index == 0 && S.Length == VecVT->getNumElements());
+  if (CoversWholeValue)
+    return Vec;
+
+  if (S.Length == 1)
+    return IRB.CreateExtractElement(Vec, S.Index,
+                                    Name + ".slice." + Twine(S.Index));
+
+  // Lanes S.Index, S.Index + 1, ..., S.Index + S.Length - 1.
+  SmallVector<int> Lanes;
+  for (uint64_t Lane = S.Index, End = S.Index + S.Length; Lane != End; ++Lane)
+    Lanes.push_back(static_cast<int>(Lane));
+  return IRB.CreateShuffleVector(Vec, Lanes, Name + ".slice." + Twine(S.Index));
+}
+
+/// Write `Part` into the lanes of `Whole` described by `S` and return the
+/// updated value.
+///
+/// - If `Whole` is not a fixed vector, `Part` is the result.
+/// - If `S` spans all of `Whole` and `Part` already has the type of `Whole`,
+///   `Part` is the result.
+/// - A one-element slice is an insertelement of the scalar `Part`. This also
+///   covers a scalar being placed into a one-element vector, where returning
+///   `Part` directly would produce a value of the wrong type.
+/// - Otherwise `Part` is widened to the length of `Whole` with poison lanes
+///   and blended in with a second shufflevector.
+Value *LegalizeBufferContentTypesVisitor::insertSlice(Value *Whole, Value *Part,
+                                                      VecSlice S,
+                                                      const Twine &Name) {
+  auto *WholeVT = dyn_cast<FixedVectorType>(Whole->getType());
+  if (!WholeVT)
+    return Part;
+  // Only short-circuit when the result type is preserved: a scalar part for a
+  // <1 x T> whole must still go through insertelement below.
+  if (S.Length == WholeVT->getNumElements() && S.Index == 0 &&
+      Part->getType() == WholeVT)
+    return Part;
+  if (S.Length == 1) {
+    return IRB.CreateInsertElement(Whole, Part, S.Index,
+                                   Name + ".slice." + Twine(S.Index));
+  }
+  int NumElems = WholeVT->getNumElements();
+
+  // Extend the slice with poisons to make the main shufflevector happy.
+  SmallVector<int> ExtPartMask(NumElems, -1);
+  for (auto [I, E] : llvm::enumerate(
+           MutableArrayRef<int>(ExtPartMask).take_front(S.Length))) {
+    E = I;
+  }
+  Value *ExtPart = IRB.CreateShuffleVector(Part, ExtPartMask,
+                                           Name + ".ext." + Twine(S.Index));
+
+  // Identity mask over `Whole`, except that the sliced lanes are taken from
+  // the second operand (indices NumElems and up select from `ExtPart`).
+  SmallVector<int> Mask =
+      llvm::to_vector(llvm::iota_range<int>(0, NumElems, /*Inclusive=*/false));
+  for (auto [I, E] :
+       llvm::enumerate(MutableArrayRef<int>(Mask).slice(S.Index, S.Length)))
+    E = I + NumElems;
+  return IRB.CreateShuffleVector(Whole, ExtPart, Mask,
+                                 Name + ".parts." + Twine(S.Index));
+}
+
+bool LegalizeBufferContentTypesVisitor::visitLoadImpl(
+    LoadInst &OrigLI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
+    uint64_t AggByteOff, Value *&Result, const Twine &Name) { // Legalize the load of one (sub)part of OrigLI's value into Result; returns true if replacement IR was emitted.
+  if (auto *ST = dyn_cast<StructType>(PartType)) { // Structs: recurse into each member at its layout offset.
+    const StructLayout *Layout = DL.getStructLayout(ST);
+    bool Changed = false;
+    for (auto [I, ElemTy, Offset] :
+         llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) {
+      AggIdxs.push_back(I);
+      Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs,
+                               AggByteOff + Offset.getFixedValue(), Result,
+                               Name + "." + Twine(I));
+      AggIdxs.pop_back();
+    }
+    return Changed;
+  }
+  if (auto *AT = dyn_cast<ArrayType>(PartType)) { // Arrays: recurse per element unless the element is a non-vector single-value type with no padding.
+    Type *ElemTy = AT->getElementType();
+    TypeSize AllocSize = DL.getTypeAllocSize(ElemTy);
+    if (!(ElemTy->isSingleValueType() &&
+          DL.getTypeSizeInBits(ElemTy) == 8 * AllocSize &&
+          !ElemTy->isVectorTy())) {
+      bool Changed = false;
+      for (auto I : llvm::iota_range<uint32_t>(0, AT->getNumElements(),
+                                               /*Inclusive=*/false)) {
+        AggIdxs.push_back(I);
+        Changed |= visitLoadImpl(OrigLI, ElemTy, AggIdxs,
+                                 AggByteOff + I * AllocSize.getFixedValue(),
+                                 Result, Name + Twine(I));
+        AggIdxs.pop_back();
+      }
+      return Changed;
+    }
+  }
+
+  // Typical case: PartType is a scalar, a vector, or an array that can be handled as a vector.
+
+  Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType);
+  Type *LegalType = legalNonAggregateFor(ArrayAsVecType);
+
+  SmallVector<VecSlice> Slices;
+  getVecSlices(LegalType, Slices);
+  bool HasSlices = Slices.size() > 1; // More than one load is required.
+  bool IsAggPart = !AggIdxs.empty(); // A non-empty index path means we are inside an aggregate being split.
+  Value *LoadsRes;
+  if (!HasSlices && !IsAggPart) { // One load covers the whole value: clone OrigLI with the loadable type.
+    Type *LoadableType = intrinsicTypeFor(LegalType);
+    if (LoadableType == PartType) // Already loadable as-is; report no change.
+      return false;
+
+    IRB.SetInsertPoint(&OrigLI);
+    auto *NLI = cast<LoadInst>(OrigLI.clone());
+    NLI->mutateType(LoadableType);
+    NLI = IRB.Insert(NLI);
+    NLI->setName(Name + ".loadable");
+
+    LoadsRes = IRB.CreateBitCast(NLI, LegalType, Name + ".from.loadable");
+  } else {
+    IRB.SetInsertPoint(&OrigLI);
+    LoadsRes = PoisonValue::get(LegalType);
+    Value *OrigPtr = OrigLI.getPointerOperand();
+    // If we're needing to spill something into more than one load, its legal
+    // type will be a vector (ex. an i256 load will have LegalType = <8 x i32>).
+    // But if we're already a scalar (which can happen if we're splitting up a
+    // struct), the element type will be the legal type itself.
+    Type *ElemType = LegalType;
+    if (auto *VT = dyn_cast<FixedVectorType>(LegalType))
+      ElemType = VT->getElementType();
+    unsigned ElemBytes = DL.getTypeStoreSize(ElemType);
+    AAMDNodes AANodes = OrigLI.getAAMetadata();
+    if (IsAggPart && Slices.empty()) // getVecSlices() produced nothing (non-vector); load it as one single-element slice.
+      Slices.emplace_back(/*Index=*/0, /*Length=*/1);
+    for (VecSlice S : Slices) {
+      Type *SliceType =
+          S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType;
+      int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
+      // You can't reasonably expect loads to wrap around the edge of memory, so the GEP is marked as having no unsigned wrap.
+      Value *NewPtr = IRB.CreateGEP(
+          IRB.getInt8Ty(), OrigLI.getPointerOperand(), IRB.getInt32(ByteOffset),
+          OrigPtr->getName() + ".off.ptr." + Twine(ByteOffset),
+          GEPNoWrapFlags::noUnsignedWrap());
+      Type *LoadableType = intrinsicTypeFor(SliceType);
+      LoadInst *NewLI = IRB.CreateAlignedLoad(
+          LoadableType, NewPtr, commonAlignment(OrigLI.getAlign(), ByteOffset),
+          Name + ".off." + Twine(ByteOffset));
+      copyMetadataForLoad(*NewLI, OrigLI);
+      NewLI->setAAMetadata(
+          AANodes.adjustForAccess(ByteOffset, LoadableType, DL));
+      if (OrigLI.isAtomic()) // Carry the original ordering and sync scope over to each partial load.
+        NewLI->setAtomic(OrigLI.getOrdering(), OrigLI.getSyncScopeID());
+      if (OrigLI.isVolatile())
+        NewLI->setVolatile(OrigLI.isVolatile());
+      Value *Loaded = IRB.CreateBitCast(NewLI, SliceType,
+                                        NewLI->getName() + ".from.loadable");
+      LoadsRes = insertSlice(LoadsRes, Loaded, S, Name);
+    }
+  }
+  if (LegalType != ArrayAsVecType) // Convert from the legal type back to the type that was asked for.
+    LoadsRes = makeIllegalNonAggregate(LoadsRes, ArrayAsVecType, Name);
+  if (ArrayAsVecType != PartType) // The part was an array handled as a vector; rebuild the array.
+    LoadsRes = vectorToArray(LoadsRes, PartType, Name);
+
+  if (IsAggPart) // Place this part at its index path within the aggregate being rebuilt.
+    Result = IRB.CreateInsertValue(Result, LoadsRes, AggIdxs, Name);
+  else
+    Result = LoadsRes;
+  return true;
+}
+
+/// Visitor entry point for loads. Loads whose pointer is not a buffer fat
+/// pointer are ignored. Otherwise the load is handed to visitLoadImpl(), and
+/// if that produced a replacement value, the replacement takes over the
+/// original load's name and uses and the original load is erased.
+bool LegalizeBufferContentTypesVisitor::visitLoadInst(LoadInst &LI) {
+  const bool IsFatPtrLoad =
+      LI.getPointerAddressSpace() == AMDGPUAS::BUFFER_FAT_POINTER;
+  if (!IsFatPtrLoad)
+    return false;
+
+  SmallVector<uint32_t> IndexPath;
+  Type *LoadedTy = LI.getType();
+  Value *Replacement = PoisonValue::get(LoadedTy);
+  if (!visitLoadImpl(LI, LoadedTy, IndexPath, /*AggByteOffset=*/0, Replacement,
+                     LI.getName()))
+    return false;
+
+  Replacement->takeName(&LI);
+  LI.replaceAllUsesWith(Replacement);
+  LI.eraseFromParent();
+  return true;
+}
+
+std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl(
+    StoreInst &OrigSI, Type *PartType, SmallVectorImpl<uint32_t> &AggIdxs,
+    uint64_t AggByteOff, const Twine &Name) {
+  if (auto *ST = dyn_cast<StructType>(PartType)) {
+    const StructLayout *Layout = DL.getStructLayout(ST);
+    bool Changed = false;
+    for (auto [I, ElemTy, Offset] :
+         llvm::enumerate(ST->elements(), Layout->getMemberOffsets())) {
+      AggIdxs.push_back(I);
+      Changed |= std::get<0>(visitStoreImpl(OrigSI, ElemTy, AggIdxs,
+                                            AggByteOff + Offset.getFixedValue(),
+                                            Name + "." + Twine(I)));
+      AggIdxs.pop_back();
+    }
+    return std::make_pair(Changed, /*ModifiedInPlace=*/false);
+  }
+  if (auto *AT = dyn_cast<ArrayType>(PartType)) {
+    Type *ElemTy = AT->getElementType();
+    TypeSize AllocSize = DL.getTypeAllocSize(ElemTy);
+    if (!(ElemTy->isSingleValueType() &&
+          DL.getTypeSizeInBits(ElemTy) == 8 * AllocSize &&
----------------
krzysz00 wrote:

Ah

So what's the condition under which a store (or a load) of a `[N x T]` can be replaced with a store (or a load) of a `<N x T>`?

https://github.com/llvm/llvm-project/pull/110572


More information about the llvm-commits mailing list