[Mlir-commits] [mlir] [MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (PR #133498)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Apr 1 16:42:29 PDT 2025


================
@@ -903,6 +903,81 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct GlobalLoadLDSOpLowering
+    : public ConvertOpToLLVMPattern<GlobalLoadLDSOp> {
+  GlobalLoadLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<GlobalLoadLDSOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(GlobalLoadLDSOp op, GlobalLoadLDSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    auto elemType = cast<MemRefType>(op.getDst().getType()).getElementType();
+    size_t elemSizeInBits = elemType.getIntOrFloatBitWidth();
+    if (elemSizeInBits % 8 != 0)
+      return op.emitOpError("element size must be a multiple of 8");
+
+    // TODO: instead of only transferring one element per thread, we could
+    // augment it to transfer multiple elements per thread by issuing multiple
+    // `global_load_lds` instructions.
+    auto loadWidth = elemSizeInBits / 8;
+
+    const Chipset GlobalLoadEnabled{9, 0x4, 0x0};
+    if (chipset < GlobalLoadEnabled)
+      return op.emitOpError("chipset not supported");
+
+    // Currently only 1, 2, and 4 byte loads are supported.
+    if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
+      return op.emitOpError("chipset unsupported element size");
+
+    // Return pair of {base pointer, linearized index}.
+    auto getBasePtrAndLinearizedIndex =
+        [&](Value memref, MemRefType memrefType,
+            ValueRange indices) -> std::optional<std::pair<Value, Value>> {
+      MemRefDescriptor memRefDescriptor(memref);
+      int64_t offset = 0;
+      SmallVector<int64_t, 5> strides;
+      if (failed(memrefType.getStridesAndOffset(strides, offset)))
+        return {};
+      return std::make_pair(
+          memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
+                                     memrefType),
+          getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
+    };
+
+    auto optSrcBuffer = getBasePtrAndLinearizedIndex(
+        adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
+        op.getSrcIndices());
+    if (!optSrcBuffer)
+      return op.emitOpError("failed to flatten source memref indices");
+    auto optDstBuffer = getBasePtrAndLinearizedIndex(
+        adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
+        op.getDstIndices());
+    if (!optDstBuffer)
+      return op.emitOpError("failed to flatten destination memref indices");
+
+    Type srcPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+    Type dstPtrType = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+    Value srcPtr = rewriter.create<LLVM::GEPOp>(
----------------
krzysz00 wrote:

This should follow `memref.load` and/or `memref.store`'s lowerings for getting the pointer.

https://github.com/llvm/llvm-project/pull/133498


More information about the Mlir-commits mailing list