[Mlir-commits] [mlir] [mlir][AMDGPU] Plumb address space 7 through MLIR, add address_space attr. (PR #125594)

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Feb 10 13:14:05 PST 2025


================
@@ -76,11 +83,168 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
   return index ? index : createI32Constant(rewriter, loc, 0);
 }
 
+/// Compute the `num_records` value for a memref: the largest `size * stride`
+/// over all dimensions (1 for rank 0), scaled by `elementByteWidth`. Emits an
+/// i32 constant when shape and strides are static, else ops on the descriptor.
+static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
+                           MemRefType memrefType,
+                           MemRefDescriptor &memrefDescriptor,
+                           ArrayRef<int64_t> strides,
+                           uint32_t elementByteWidth) {
+  if (memrefType.hasStaticShape() &&
+      !llvm::any_of(strides, ShapedType::isDynamic)) {
+    int64_t size = memrefType.getRank() == 0 ? 1 : 0; // rank 0: one element
+    ArrayRef<int64_t> shape = memrefType.getShape();
+    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+      size = std::max(shape[i] * strides[i], size);
+    size = size * elementByteWidth; // elements -> bytes
+    assert(size < std::numeric_limits<uint32_t>::max() &&
+           "the memref buffer is too large");
+    return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
+  }
+  Value maxIndex; // null until the first dimension is seen
+  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
+    Value size = memrefDescriptor.size(rewriter, loc, i);
+    Value stride = memrefDescriptor.stride(rewriter, loc, i);
+    Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+    maxIndex = maxIndex
+                   ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+                   : maxThisDim;
+  }
+  return rewriter.create<LLVM::MulOp>(
+      loc, convertUnsignedToI32(rewriter, loc, maxIndex),
+      createI32Constant(rewriter, loc, elementByteWidth));
+}
+
+static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
+                            Value basePointer, Value numRecords,
+                            bool boundsCheck, amdgpu::Chipset chipset,
+                            Value cacheSwizzleStride = nullptr) {
+  // The stride is generally 0. When a cache swizzle stride is given and the
+  // chipset is major version 9 and >= gfx940 (MI-300 onward), it is instead
+  // that stride zero-extended to i16 with bit 14 set (cache swizzling mode).
+  Type i16 = rewriter.getI16Type();
+  Value stride;
+  if (chipset.majorVersion == 9 && chipset >= kGfx940 && cacheSwizzleStride) {
+    Value cacheStrideZext =
+        rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
+    Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
+        loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+    stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
+                                         /*isDisjoint=*/true);
+  } else {
+    stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
+                                               rewriter.getI16IntegerAttr(0));
+  }
+  // Assemble the flag word passed as the last MakeBufferRsrcOp operand.
+  // Flag word:
+  // bits 0-11: dst sel, ignored by these intrinsics
+  // bits 12-14: data format (ignored, must be nonzero, 7=float)
+  // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
+  // bit 19: In nested heap (0 here)
+  // bit 20: Behavior on unmap (0 means "return 0 / ignore")
+  // bits 21-22: Index stride for swizzles (N/A)
+  // bit 23: Add thread ID (0)
+  // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
+  // bits 25-26: Reserved (0)
+  // bit 27: Buffer is non-volatile (CDNA only)
+  // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
+  //  none, 3 = either swizzles or testing against offset field) RDNA only
+  // bits 30-31: Type (must be 0)
+  uint32_t flags = (7 << 12) | (4 << 15);
+  if (chipset.majorVersion >= 10) { // fields the table marks as RDNA
+    flags |= (1 << 24);
+    uint32_t oob = boundsCheck ? 3 : 2; // see bits 28-29 above
+    flags |= (oob << 28);
+  }
+  Value flagsConst = createI32Constant(rewriter, loc, flags);
+  Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
+  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
+      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
+  return resource;
+}
+
 namespace {
-// Define commonly used chipsets versions for convenience.
-constexpr Chipset kGfx908 = Chipset(9, 0, 8);
-constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
-constexpr Chipset kGfx940 = Chipset(9, 4, 0);
+struct FatRawBufferCastLowering
+    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
+  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
+        chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value memRef = adaptor.getSource();
+    Value unconvertedMemref = op.getSource();
+    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
+    MemRefDescriptor descriptor(memRef);
+
+    DataLayout dataLayout = DataLayout::closest(op);
+    int64_t elementByteWidth =
+        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
+
+    int64_t unusedOffset = 0;
+    SmallVector<int64_t, 5> strideVals;
+    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
+      return op.emitOpError("Can't lower non-stride-offset memrefs");
+
+    Value numRecords = adaptor.getValidBytes();
+    if (!numRecords)
+      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
+                                 strideVals, elementByteWidth);
+
+    Value basePointer =
+        adaptor.getResetOffset()
+            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+                                   memrefType)
+            : descriptor.alignedPtr(rewriter, loc);
+
+    Value offset;
+    if (adaptor.getResetOffset())
----------------
krzysz00 wrote:

I think this one's clearer as an if-else, but I'm willing to defer here; it's not that big a deal.

https://github.com/llvm/llvm-project/pull/125594


More information about the Mlir-commits mailing list