[Mlir-commits] [mlir] 27f15ad - [MLIR][ROCDL] Add ops for LDS read transpose and global to LDS intrinsics (#123530)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 21 08:40:50 PST 2025
Author: plognjen
Date: 2025-01-21T16:40:46Z
New Revision: 27f15add7c82efb99c15051a1c9b2c660843b356
URL: https://github.com/llvm/llvm-project/commit/27f15add7c82efb99c15051a1c9b2c660843b356
DIFF: https://github.com/llvm/llvm-project/commit/27f15add7c82efb99c15051a1c9b2c660843b356.diff
LOG: [MLIR][ROCDL] Add ops for LDS read transpose and global to LDS intrinsics (#123530)
This PR adds the missing ds.read.tr4.b64, ds.read.tr8.b64,
ds.read.tr6.b96,
ds.read.tr16.b64 and global.load.lds ops to
the ROCDL dialect.
The ops are converted to the corresponding intrinsic calls during the
translation from MLIR to LLVM IR.
---------
Co-authored-by: Ognjen Plavsic <plognjen at amd.com>
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/test/Dialect/LLVMIR/rocdl.mlir
mlir/test/Target/LLVMIR/rocdl.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 71dac3ad39b7b1..0b8c0f7f381c4a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -412,6 +412,36 @@ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]
def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
+//===---------------------------------------------------------------------===//
+// LDS transpose intrinsics (available in GFX950)
+
+def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
+def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
+
+class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
+ ROCDL_IntrOp<mnemonic, [1], [], [], 1>,
+ Arguments<(ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr)>{
+ let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
+}
+
+def ROCDL_ds_read_tr4_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr4.b64">;
+def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
+def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
+def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
+
+//===---------------------------------------------------------------------===//
+// Global load to LDS intrinsic (available in GFX950)
+
+def ROCDL_GlobalLoadLDSOp :
+ ROCDL_IntrOp<"global.load.lds", [], [], [], 0>,
+ Arguments<(ins Arg<ROCDLGlobalBuffer, "", [MemRead]>:$globalPtr,
+ Arg<ROCDLBufferLDS, "", [MemWrite]>:$ldsPtr,
+ I32:$size,
+ I32:$offset,
+ I32:$aux)> {
+ let assemblyFormat = "operands attr-dict";
+}
+
//===---------------------------------------------------------------------===//
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
// raw buffer mode).
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 92789246edb4f3..c80ebebaafe3ad 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -227,6 +227,32 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
llvm.return
}
+llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
+ // CHECK-LABEL: rocdl.ds.read.tr
+ // CHECK: rocdl.ds.read.tr4.b64 {{.*}} : <3> -> vector<2xi32>
+ %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: rocdl.ds.read.tr6.b96 {{.*}} : <3> -> vector<3xi32>
+ %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
+ // CHECK: rocdl.ds.read.tr8.b64 {{.*}} : <3> -> vector<2xi32>
+ %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xf16>
+ %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
+ // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xbf16>
+ %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
+ llvm.return %r3 : vector<4xf16>
+}
+
+llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+ %aux = llvm.mlir.constant(0 : i32) : i32
+ %offset = llvm.mlir.constant(0 : i32) : i32
+ %size = llvm.mlir.constant(10 : i32) : i32
+
+ //CHECK: rocdl.global.load.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+ rocdl.global.load.lds %src, %dst, %size, %offset, %aux
+
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i32,
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 0620c23b5fdad7..996e0e34c790c9 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -424,6 +424,30 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
llvm.return %r0 : vector<8xf32>
}
+llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
+ // CHECK-LABEL: rocdl.ds.read.tr
+ // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr4.b64.v2i32(ptr addrspace(3) %0)
+ %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: call <3 x i32> @llvm.amdgcn.ds.read.tr6.b96.v3i32(ptr addrspace(3) %0)
+ %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr8.b64.v2i32(ptr addrspace(3) %0)
+ %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: call <4 x half> @llvm.amdgcn.ds.read.tr16.b64.v4f16(ptr addrspace(3) %0)
+ %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
+ // CHECK: call <4 x bfloat> @llvm.amdgcn.ds.read.tr16.b64.v4bf16(ptr addrspace(3) %0)
+ %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
+ llvm.return %r3 : vector<4xf16>
+}
+
+llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+ %aux = llvm.mlir.constant(0 : i32) : i32
+ %offset = llvm.mlir.constant(0 : i32) : i32
+ %size = llvm.mlir.constant(10 : i32) : i32
+ //CHECK: call void @llvm.amdgcn.global.load.lds
+ rocdl.global.load.lds %src, %dst, %size, %offset, %aux
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i32,
More information about the Mlir-commits
mailing list