[Mlir-commits] [mlir] [mlir][AMDGPU] Plumb address space 7 through MLIR, add address_space attr. (PR #125594)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Feb 10 13:14:05 PST 2025
================
@@ -76,11 +83,168 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
  return index ? index : createI32Constant(rewriter, loc, 0);
}
+/// Compute the contents of the `num_records` field for a given memref
+/// descriptor - that is, the byte offset one element past the greatest
+/// valid index into the memref.
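+///
+/// For example, a row-major memref<16x4xf32> has strides [4, 1], so this is
+/// max(16 * 4, 4 * 1) = 64 elements * 4 bytes = 256, covering byte offsets
+/// [0, 256).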
+static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
+                           MemRefType memrefType,
+                           MemRefDescriptor &memrefDescriptor,
+                           ArrayRef<int64_t> strides,
+                           uint32_t elementByteWidth) {
+  if (memrefType.hasStaticShape() &&
+      !llvm::any_of(strides, ShapedType::isDynamic)) {
+    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+    ArrayRef<int64_t> shape = memrefType.getShape();
+    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+      size = std::max(shape[i] * strides[i], size);
+    size = size * elementByteWidth;
+    assert(size < std::numeric_limits<uint32_t>::max() &&
+           "the memref buffer is too large");
+    return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
+  }
+  Value maxIndex;
+  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
+    Value size = memrefDescriptor.size(rewriter, loc, i);
+    Value stride = memrefDescriptor.stride(rewriter, loc, i);
+    Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+    maxIndex = maxIndex
+                   ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+                   : maxThisDim;
+  }
+  return rewriter.create<LLVM::MulOp>(
+      loc, convertUnsignedToI32(rewriter, loc, maxIndex),
+      createI32Constant(rewriter, loc, elementByteWidth));
+}
+
+static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
+                            Value basePointer, Value numRecords,
+                            bool boundsCheck, amdgpu::Chipset chipset,
+                            Value cacheSwizzleStride = nullptr) {
+  // The stride value is generally 0. However, on MI-300 and onward, you can
+  // enable a cache swizzling mode by setting bit 14 of the stride field
+  // and setting that stride to a cache stride.
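+  // For example, a cache swizzle stride of 2048 would produce a stride
+  // field of 2048 | (1 << 14) = 0x4800.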
+  Type i16 = rewriter.getI16Type();
+  Value stride;
+  if (chipset.majorVersion == 9 && chipset >= kGfx940 && cacheSwizzleStride) {
+    Value cacheStrideZext =
+        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
+    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
+        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
+                                         /*isDisjoint=*/true);
+  } else {
+    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
+                                               rewriter.getI16IntegerAttr(0));
+  }
+  // Flag word:
+  // bits 0-11: dst sel, ignored by these intrinsics
+  // bits 12-14: num format (ignored, must be nonzero, 7=float)
+  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
+  // bit 19: In nested heap (0 here)
+  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
+  // bits 21-22: Index stride for swizzles (N/A)
+  // bit 23: Add thread ID (0)
+  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
+  // bits 25-26: Reserved (0)
+  // bit 27: Buffer is non-volatile (CDNA only)
+  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
+  //   none, 3 = either swizzles or testing against offset field) RDNA only
+  // bits 30-31: Type (must be 0)
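+  //
+  // For example, CDNA chips get flags = (7 << 12) | (4 << 15) = 0x27000;
+  // RDNA chips with bounds checking enabled additionally set bit 24 and
+  // oob = 3, giving 0x31027000.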
+  uint32_t flags = (7 << 12) | (4 << 15);
+  if (chipset.majorVersion >= 10) {
+    flags |= (1 << 24);
+    uint32_t oob = boundsCheck ? 3 : 2;
+    flags |= (oob << 28);
+  }
+  Value flagsConst = createI32Constant(rewriter, loc, flags);
+  Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
+  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
+      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
+  return resource;
+}
+
namespace {
-// Define commonly used chipsets versions for convenience.
-constexpr Chipset kGfx908 = Chipset(9, 0, 8);
-constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
-constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+struct FatRawBufferCastLowering
+    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
+  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
+        chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value memRef = adaptor.getSource();
+    Value unconvertedMemref = op.getSource();
+    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
+    MemRefDescriptor descriptor(memRef);
+
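+    // Compute the element size in bytes via the data layout,
+    // e.g. 4 for f32 and 2 for f16.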
+    DataLayout dataLayout = DataLayout::closest(op);
+    int64_t elementByteWidth =
+        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
+
+    int64_t unusedOffset = 0;
+    SmallVector<int64_t, 5> strideVals;
+    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
+      return op.emitOpError("Can't lower non-stride-offset memrefs");
+
+    Value numRecords = adaptor.getValidBytes();
+    if (!numRecords)
+      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
+                                 strideVals, elementByteWidth);
+
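+    // If the offset is being reset, fold the descriptor's offset into the
+    // base pointer; otherwise take the aligned base so the offset remains
+    // meaningful relative to it.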
+    Value basePointer =
+        adaptor.getResetOffset()
+            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+                                   memrefType)
+            : descriptor.alignedPtr(rewriter, loc);
+
+    Value offset;
+    if (adaptor.getResetOffset())
----------------
krzysz00 wrote:
I think this one's clearer as an if-else situation, but I'm willing to defer here; it's not that big a deal.
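For context, the comparison is roughly between these two shapes (the branch
bodies here are placeholders, not the exact code from the patch):

    // if-else form
    Value offset;
    if (adaptor.getResetOffset())
      offset = createI32Constant(rewriter, loc, 0); // placeholder body
    else
      offset = convertUnsignedToI32(rewriter, loc,
                                    descriptor.offset(rewriter, loc));

    // ternary form
    Value offset = adaptor.getResetOffset()
                       ? createI32Constant(rewriter, loc, 0)
                       : convertUnsignedToI32(rewriter, loc,
                                              descriptor.offset(rewriter, loc));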
https://github.com/llvm/llvm-project/pull/125594
More information about the Mlir-commits mailing list