[Mlir-commits] [mlir] dae0ef5 - [MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (#133498)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 8 06:18:34 PDT 2025


Author: Alan Li
Date: 2025-04-08T09:18:30-04:00
New Revision: dae0ef53a0b99c6c2b74143baee5896e8bc5c8e7

URL: https://github.com/llvm/llvm-project/commit/dae0ef53a0b99c6c2b74143baee5896e8bc5c8e7
DIFF: https://github.com/llvm/llvm-project/commit/dae0ef53a0b99c6c2b74143baee5896e8bc5c8e7.diff

LOG: [MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (#133498)

Defines a new `amdgpu.gather_to_lds` op, a thin wrapper around the
ROCDL `global_load_lds` intrinsic, along with its lowering to
`rocdl.global.load.lds`.

Added: 
    mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir

Modified: 
    mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
    mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
    mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 9cdd961d96ff5..108d7237ff703 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -767,4 +767,40 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
+// `src` and `dst` may have different ranks, so the two variadic index lists
+// are sized independently via `AttrSizedOperandSegments`.
+def AMDGPU_GatherToLDSOp :
+    AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>,
+    Arguments<(ins
+                   Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
+                   Variadic<Index>:$srcIndices,
+                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+                   Variadic<Index>:$dstIndices,
+                   TypeAttr:$transferType
+                   )>,
+    Results<(outs)> {
+  let summary = "MLIR wrapper for CDNA Gather to LDS instructions";
+  let description = [{
+    The `amdgpu.gather_to_lds` op is a wrapper around the `global_load_lds` instructions.
+
+    Operands:
+    * `$src`: global memory memref to read from.
+    * `$srcIndices`: indices into `$src` to read from for this thread.
+    * `$dst`: LDS memory memref to write to.
+    * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
+      The elements gathered by the subgroup will be written contiguously, in order of
+      lane ID, starting at `$dst[$dstIndices]`.
+    * `$transferType`: type of the data to be transferred by each thread. This is used to determine
+      the size of the data to be transferred and the number of threads in the subgroup.
+      The transfer type must be a scalar type or a vector type with a single element type.
+
+    The `$dst`, along with its indices, points to the memory location the subgroup of this thread
+    will write to.
+
+    Note: only enabled for gfx942 and later.
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // AMDGPU

diff  --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 56d40d6d123bf..5f697bdeef566 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1010,6 +1010,55 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+/// Lowers `amdgpu.gather_to_lds` to `rocdl.global.load.lds`.
+///
+/// The source and destination element pointers are each computed from that
+/// memref's own type and indices, and the per-thread transfer size in bytes is
+/// passed as an i32 constant. The conversion emits an op error on chipsets
+/// older than gfx942 and for transfer sizes other than 1, 2, or 4 bytes.
+struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
+  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
+
+  /// Target chipset; the lowering is rejected below gfx942.
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx942)
+      return op.emitOpError("chipset not supported");
+
+    Location loc = op.getLoc();
+
+    // The destination pointer must be derived from the destination memref
+    // type: its shape, strides, and rank may all differ from the source's.
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
+
+    // TODO: instead of only transferring one element per thread, we could
+    // augment it to transfer multiple elements per thread by issuing multiple
+    // `global_load_lds` instructions.
+    Type transferType = op.getTransferType();
+    // Per-thread transfer size in bytes, for either a vector or a scalar
+    // transfer type.
+    size_t loadWidth = [&]() -> size_t {
+      if (auto transferVectorType = dyn_cast<VectorType>(transferType))
+        return transferVectorType.getNumElements() *
+               (transferVectorType.getElementTypeBitWidth() / 8);
+      return transferType.getIntOrFloatBitWidth() / 8;
+    }();
+
+    // Currently only 1, 2, and 4 byte loads are supported.
+    if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
+      return op.emitOpError("chipset unsupported element size");
+
+    Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
+                                        (adaptor.getSrcIndices()), rewriter);
+    Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
+                                        (adaptor.getDstIndices()), rewriter);
+
+    // The two i32 operands following the load width are both passed as 0.
+    rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
+        op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
+        createI32Constant(rewriter, loc, 0),
+        createI32Constant(rewriter, loc, 0), ArrayAttr{}, ArrayAttr{},
+        ArrayAttr{});
+
+    return success();
+  }
+};
+
 namespace {
 struct ExtPackedFp8OpLowering final
     : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1393,6 +1442,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
            MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
-           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
-                                                                      chipset);
+           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+           GatherToLDSOpLowering>(converter, chipset);
 }

diff  --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 1e482515a4ee0..7f286f938ee60 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -24,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
 
 #include <limits>
 #include <optional>
@@ -112,21 +114,31 @@ LogicalResult FatRawBufferCastOp::verify() {
   return success();
 }
 
+/// Returns true if `memorySpace` denotes global memory: a null attribute,
+/// integer address space 0 or 1, or the GPU dialect's `Global` address space.
+static bool hasGlobalMemorySpace(Attribute memorySpace) {
+  // A memref without an explicit memory space is treated as global.
+  if (!memorySpace)
+    return true;
+  if (auto asInteger = dyn_cast<IntegerAttr>(memorySpace)) {
+    const auto space = asInteger.getInt();
+    return space == 0 || space == 1;
+  }
+  auto asGpuSpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace);
+  return asGpuSpace && asGpuSpace.getValue() == gpu::AddressSpace::Global;
+}
+
+/// Returns true if `memorySpace` denotes workgroup (LDS) memory: integer
+/// address space 3 or the GPU dialect's `Workgroup` address space. A null
+/// attribute (no explicit memory space) is not workgroup memory.
+static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
+  // `dyn_cast` must not be applied to a null attribute, so reject it first.
+  if (!memorySpace)
+    return false;
+  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 3;
+  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // RawBuffer*Op
 //===----------------------------------------------------------------------===//
 template <typename T>
 static LogicalResult verifyRawBufferOp(T &op) {
   MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
-  Attribute memorySpace = bufferType.getMemorySpace();
-  bool isGlobal = false;
-  if (!memorySpace)
-    isGlobal = true;
-  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
-    isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
-  else if (auto gpuMemorySpace =
-               llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
-    isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
 
   if (!isGlobal)
     return op.emitOpError(
@@ -461,6 +473,40 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+/// Verifies `amdgpu.gather_to_lds`:
+///  * the destination memref has a static, contiguous row-major layout;
+///  * source and destination element types match;
+///  * the transfer type is an integer/float scalar or a vector of them, and
+///    is 8, 16, or 32 bits wide in total;
+///  * the source is in global memory and the destination in workgroup memory.
+LogicalResult GatherToLDSOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+  MemRefType dstType = cast<MemRefType>(getDst().getType());
+
+  if (!memref::isStaticShapeAndContiguousRowMajor(dstType))
+    return emitOpError(
+        "destination types must have static shape  and contiguous");
+
+  auto elemType = srcType.getElementType();
+  // Check $src and $dst element types are the same.
+  if (elemType != dstType.getElementType())
+    return emitOpError("source and destination element types must match");
+
+  // `transferType` is an unconstrained type attribute, so make sure a bit
+  // width can actually be queried before asking for it; the bit-width
+  // accessors assert on types that are neither integer nor float.
+  auto transferType = getTransferType();
+  auto vectorTransfer = dyn_cast<VectorType>(transferType);
+  Type transferElemType =
+      vectorTransfer ? vectorTransfer.getElementType() : transferType;
+  if (!transferElemType.isIntOrFloat())
+    return emitOpError("transfer type must be an integer or float scalar, or "
+                       "a vector of integers or floats");
+
+  // copy type sizes should be 1, 2, or 4 bytes.
+  size_t transferSize = transferElemType.getIntOrFloatBitWidth();
+  if (vectorTransfer)
+    transferSize *= vectorTransfer.getNumElements();
+  if (transferSize != 8 && transferSize != 16 && transferSize != 32)
+    return emitOpError("Transfering type size must be 8, 16, or 32 bits");
+
+  if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
+    return emitOpError("source memory address space must be Global");
+
+  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
+    return emitOpError("destination memory address space must be Workgroup");
+
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES

diff  --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
new file mode 100644
index 0000000000000..b1c16bd5db079
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -0,0 +1,143 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s
+
+// Integer memory spaces used by the memrefs in this file: 1 for global
+// memory and 3 for LDS (workgroup) memory.
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+
+// Scalar f32 transfer (4 bytes). The source offset uses the row stride of the
+// 128x72 global buffer; the destination offset uses the row stride of the
+// 64x64 LDS buffer.
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}
+
+// Scalar i8 transfer (1 byte). The destination offset uses the row stride of
+// the 64x64 LDS buffer, not that of the 128x72 source.
+// CHECK-LABEL: func @global_load_to_rocdl_i8
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>)
+func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) {
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C1]]
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : i8, memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
+  func.return
+}
+
+// Vector transfer of 2 x i16 (4 bytes). The destination offset uses the row
+// stride of the 64x128 LDS buffer, not that of the 128x72 source.
+// CHECK-LABEL: func @global_load_to_rocdl_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi16, 1>)
+func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_addrspace>) {
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C128:.*]] = llvm.mlir.constant(128 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C128]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : vector<2 x i16>, memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
+  func.return
+}
+
+
+// Dynamic indices with a 1-D source and a 2-D destination. The destination
+// offset combines both destination indices using the row stride of the 4x64
+// LDS buffer.
+// CHECK-LABEL: func @global_load_to_rocdl_dynamic_indices
+// CHECK-SAME: (%[[ARG0:.*]]: memref<512xi32, 1>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index)
+func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_global_addrspace>, %src_idx : index, %dst_idx : index) {
+  // CHECK: %[[DSTIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[DST_IDX]]
+  // CHECK: %[[SRCIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC_IDX]]
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %[[C0]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[DSTIDX_CAST]], %[[C64]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+  %alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0]
+    : i32, memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
+  func.return
+}


        


More information about the Mlir-commits mailing list