[llvm] [AMDGPU] Handle natively unsupported types in addrspace(7) lowering (PR #110572)
Krzysztof Drewniak via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 23 11:05:09 PDT 2024
================
@@ -576,6 +597,547 @@ bool StoreFatPtrsAsIntsVisitor::visitStoreInst(StoreInst &SI) {
return true;
}
+namespace {
+/// Convert loads/stores of types that the buffer intrinsics can't handle into
+/// one or more such loads/stores that consist of legal types.
+///
+/// Do this by
+/// 1. Recursing into structs (and arrays that don't share a memory layout with
+/// vectors) since the intrinsics can't handle complex types.
+/// 2. Converting arrays of non-aggregate, byte-sized types into their
+/// corresponding vectors
+/// 3. Bitcasting unsupported types, namely overly-long scalars and byte
+/// vectors, into vectors of supported types.
+/// 4. Splitting up excessively long reads/writes into multiple operations.
+///
+/// Note that this doesn't handle complex data structures, but, in the future,
+/// the aggregate load splitter from SROA could be refactored to allow for that
+/// case.
+class LegalizeBufferContentTypesVisitor
+ : public InstVisitor<LegalizeBufferContentTypesVisitor, bool> {
+ friend class InstVisitor<LegalizeBufferContentTypesVisitor, bool>;
+
+ IRBuilder<> IRB;
+
+ const DataLayout &DL;
+
+ /// If T is [N x U], where U is a scalar type, return the vector type
+ /// <N x U>, otherwise, return T.
+ Type *scalarArrayTypeAsVector(Type *MaybeArrayType);
+ /// Rebuild the array value V as a value of the fixed vector type
+ /// TargetType, one element at a time (extractvalue + insertelement).
+ Value *arrayToVector(Value *V, Type *TargetType, const Twine &Name);
+ /// Inverse of arrayToVector: rebuild the vector value V as a value of the
+ /// array type OrigType, one element at a time (extractelement +
+ /// insertvalue).
+ Value *vectorToArray(Value *V, Type *OrigType, const Twine &Name);
+
+ /// Convert a vector or scalar type that can't be operated on by buffer
+ /// intrinsics to one that would be legal through bitcasts and/or padding to
+ /// a whole number of bytes. Uses the widest of i32, i16, or i8 that evenly
+ /// divides the (byte-rounded) size.
+ Type *legalNonAggregateFor(Type *T);
+ /// Cast V to TargetType. If TargetType is wider than V, V is first
+ /// reinterpreted as an integer and zero-extended to TargetType's width.
+ Value *makeLegalNonAggregate(Value *V, Type *TargetType, const Twine &Name);
+ /// Undo makeLegalNonAggregate: cast V back to OrigType, reinterpreting it
+ /// as an integer and truncating it first if OrigType is narrower than V.
+ Value *makeIllegalNonAggregate(Value *V, Type *OrigType, const Twine &Name);
+
+ struct VecSlice {
+ uint64_t Index;
+ uint64_t Length;
+ VecSlice(uint64_t Index, uint64_t Length) : Index(Index), Length(Length) {}
+ };
+ // Return the [index, length] pairs into which `T` needs to be cut to form
+ // legal buffer load or store operations. Clears `Slices`. Creates an empty
+ // `Slices` for non-vector inputs and creates one slice if no slicing will be
+ // needed.
+ void getVecSlices(Type *T, SmallVectorImpl<VecSlice> &Slices);
+
+ Value *extractSlice(Value *Vec, VecSlice S, const Twine &Name);
+ Value *insertSlice(Value *Whole, Value *Part, VecSlice S, const Twine &Name);
+
+ // In most cases, return `LegalType`. However, when given an input that would
+ // normally be a legal type for the buffer intrinsics to return but that isn't
+ // hooked up through SelectionDAG, return a type of the same width that can be
+ // used with the relevant intrinsics. Specifically, handle the cases:
+ // - <1 x T> => T for all T
+ // - <N x i8> <=> i16, i32, 2xi32, 4xi32 (as needed)
+ // - <N x T> where T is under 32 bits and the total size is 96 bits <=> <3 x
+ // i32>
+ Type *intrinsicTypeFor(Type *LegalType);
+
+ bool visitLoadImpl(LoadInst &OrigLI, Type *PartType,
+ SmallVectorImpl<uint32_t> &AggIdxs, uint64_t AggByteOffset,
+ Value *&Result, const Twine &Name);
+ // Return value is (Changed, ModifiedInPlace)
+ std::pair<bool, bool> visitStoreImpl(StoreInst &OrigSI, Type *PartType,
+ SmallVectorImpl<uint32_t> &AggIdxs,
+ uint64_t AggByteOffset,
+ const Twine &Name);
+
+ /// Instructions other than loads and stores are left untouched.
+ bool visitInstruction(Instruction &I) { return false; }
+ bool visitLoadInst(LoadInst &LI);
+ bool visitStoreInst(StoreInst &SI);
+
+public:
+ LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx)
+ : IRB(Ctx), DL(DL) {}
+ bool processFunction(Function &F);
+};
+} // namespace
+
+/// If T is an array of scalars whose in-memory layout matches that of the
+/// corresponding vector, return that vector type; non-array types are
+/// returned unchanged. Arrays that can't be reinterpreted this way are a
+/// fatal error: the caller is expected to have recursed into them instead.
+Type *LegalizeBufferContentTypesVisitor::scalarArrayTypeAsVector(Type *T) {
+  auto *AT = dyn_cast<ArrayType>(T);
+  if (!AT)
+    return T;
+  Type *ET = AT->getElementType();
+  if (!ET->isSingleValueType() || isa<VectorType>(ET))
+    report_fatal_error("loading non-scalar arrays from buffer fat pointers "
+                       "should have recursed");
+  auto *VT = FixedVectorType::get(ET, AT->getNumElements());
+  // An array places its elements at their alloc size, while a vector packs
+  // them at their bit size, so the two only share a memory layout when those
+  // sizes agree (e.g. with a 4-byte-aligned i24, [3 x i24] occupies 12 bytes
+  // but <3 x i24> only 9). Comparing the array's size with its store size
+  // can't detect this, because an array's size is always whole bytes.
+  if (DL.getTypeSizeInBits(AT) != DL.getTypeSizeInBits(VT))
+    report_fatal_error(
+        "loading padded arrays from buffer fat pointers should have recursed");
+  return VT;
+}
+
+/// Rebuild the array value V as a value of the fixed vector type TargetType
+/// by moving each array element into the lane with the same index.
+Value *LegalizeBufferContentTypesVisitor::arrayToVector(Value *V,
+                                                        Type *TargetType,
+                                                        const Twine &Name) {
+  auto *VecTy = cast<FixedVectorType>(TargetType);
+  Value *Acc = PoisonValue::get(TargetType);
+  for (unsigned Idx = 0, End = VecTy->getNumElements(); Idx != End; ++Idx) {
+    Value *Lane =
+        IRB.CreateExtractValue(V, Idx, Name + ".elem." + Twine(Idx));
+    Acc = IRB.CreateInsertElement(Acc, Lane, Idx,
+                                  Name + ".as.vec." + Twine(Idx));
+  }
+  return Acc;
+}
+
+/// Rebuild the vector value V as a value of the array type OrigType by moving
+/// each lane into the array element with the same index.
+Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
+                                                        Type *OrigType,
+                                                        const Twine &Name) {
+  auto *ArrTy = cast<ArrayType>(OrigType);
+  Value *Acc = PoisonValue::get(OrigType);
+  for (unsigned Idx = 0, End = ArrTy->getNumElements(); Idx != End; ++Idx) {
+    Value *Lane =
+        IRB.CreateExtractElement(V, Idx, Name + ".elem." + Twine(Idx));
+    Acc = IRB.CreateInsertValue(Acc, Lane, Idx,
+                                Name + ".as.array." + Twine(Idx));
+  }
+  return Acc;
+}
+
+/// Map a scalar or vector type that the buffer intrinsics can't handle to a
+/// type of the same (byte-rounded) width that they can, reached through
+/// bitcasts and zero-extension to a whole number of bytes.
+Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
+  // Let scalable vectors through to fail in codegen. This has to come before
+  // the size computations below, which need a fixed size: a scalable vector
+  // that isn't byte-sized (e.g. <vscale x 1 x i1>) would otherwise reach
+  // TypeSize::getFixedValue() on a scalable size.
+  if (isa<ScalableVectorType>(T))
+    return T;
+  TypeSize Size = DL.getTypeStoreSizeInBits(T);
+  // Implicitly zero-extend to the next byte if needed.
+  if (!DL.typeSizeEqualsStoreSize(T))
+    T = IRB.getIntNTy(Size.getFixedValue());
+  Type *ElemTy = T->getScalarType();
+  // Pointers are always big enough.
+  if (ElemTy->isPointerTy())
+    return T;
+  unsigned ElemSize = DL.getTypeSizeInBits(ElemTy).getFixedValue();
+  // [Vectors of] anything that's 16/32/64/128 bits can be cast and split into
+  // legal buffer operations.
+  if (isPowerOf2_32(ElemSize) && ElemSize >= 16 && ElemSize <= 128)
+    return T;
+  // Otherwise, reinterpret the value as the widest of i32/i16/i8 that evenly
+  // divides its byte-rounded size, using a vector when more than one is
+  // needed.
+  Type *BestVectorElemType = nullptr;
+  if (Size.isKnownMultipleOf(32))
+    BestVectorElemType = IRB.getInt32Ty();
+  else if (Size.isKnownMultipleOf(16))
+    BestVectorElemType = IRB.getInt16Ty();
+  else
+    BestVectorElemType = IRB.getInt8Ty();
+  unsigned NumCastElems =
+      Size.getFixedValue() / BestVectorElemType->getIntegerBitWidth();
+  if (NumCastElems == 1)
+    return BestVectorElemType;
+  return FixedVectorType::get(BestVectorElemType, NumCastElems);
+}
+
+/// Cast V to TargetType. When the two differ in width, TargetType is expected
+/// to be the wider one (zext requires it): V is reinterpreted as an integer of
+/// its own width and zero-extended to TargetType's width before the final
+/// bitcast, so the padding bits are zero.
+Value *LegalizeBufferContentTypesVisitor::makeLegalNonAggregate(
+    Value *V, Type *TargetType, const Twine &Name) {
+  TypeSize SourceBits = DL.getTypeSizeInBits(V->getType());
+  TypeSize TargetBits = DL.getTypeSizeInBits(TargetType);
+  if (SourceBits != TargetBits) {
+    Type *SourceIntTy = IRB.getIntNTy(SourceBits.getFixedValue());
+    Type *TargetIntTy = IRB.getIntNTy(TargetBits.getFixedValue());
+    Value *AsScalar = IRB.CreateBitCast(V, SourceIntTy, Name + ".as.scalar");
+    V = IRB.CreateZExt(AsScalar, TargetIntTy, Name + ".zext");
+  }
+  return IRB.CreateBitCast(V, TargetType, Name + ".legal");
+}
+
+/// Cast the legalized value V back to OrigType. If V is wider than OrigType,
+/// it is reinterpreted as an integer, truncated to OrigType's width (dropping
+/// the high bits), and then cast to OrigType.
+Value *LegalizeBufferContentTypesVisitor::makeIllegalNonAggregate(
+    Value *V, Type *OrigType, const Twine &Name) {
+  TypeSize LegalBits = DL.getTypeSizeInBits(V->getType());
+  TypeSize OrigBits = DL.getTypeSizeInBits(OrigType);
+  // Same width: a single bitcast restores the original type.
+  if (LegalBits == OrigBits)
+    return IRB.CreateBitCast(V, OrigType, Name + ".real.ty");
+  Type *WideIntTy = IRB.getIntNTy(LegalBits.getFixedValue());
+  Type *NarrowIntTy = IRB.getIntNTy(OrigBits.getFixedValue());
+  Value *Wide = IRB.CreateBitCast(V, WideIntTy, Name + ".bytes.cast");
+  Value *Narrow = IRB.CreateTrunc(Wide, NarrowIntTy, Name + ".trunc");
+  return IRB.CreateBitCast(Narrow, OrigType, Name + ".orig");
+}
+
+Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) {
+ auto *VT = dyn_cast<FixedVectorType>(LegalType);
+ if (!VT)
+ return LegalType;
+ Type *ET = VT->getElementType();
+ if (VT->getNumElements() == 1)
+ return ET;
----------------
krzysz00 wrote:
Except not, because for `<1 x T>` we want to return `T` here
https://github.com/llvm/llvm-project/pull/110572
More information about the llvm-commits
mailing list