[Mlir-commits] [mlir] [MLIR][ROCDL] Add ops for LDS read transpose and global to LDS intrinsics (PR #123530)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 19 13:46:12 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm
Author: None (plognjen)
<details>
<summary>Changes</summary>
This PR adds the missing ds.read.tr4.b64, ds.read.tr8.b64, ds.read.tr6.b96,
ds.read.tr16.b64 and global.load.lds ops to
the ROCDL dialect.
The ops are converted to the corresponding intrinsic calls during
translation from MLIR to LLVM IR.
---
Full diff: https://github.com/llvm/llvm-project/pull/123530.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+30)
- (modified) mlir/test/Dialect/LLVMIR/rocdl.mlir (+26)
- (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+24)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 71dac3ad39b7b1..0e23c183b4d923 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -412,6 +412,36 @@ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]
def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
+//===---------------------------------------------------------------------===//
+// LDS transpose intrinsics
+
+def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
+def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
+
+class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
+ ROCDL_IntrOp<mnemonic, [1], [], [], 1>,
+ Arguments<(ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr)>{
+ let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
+ }
+
+def ROCDL_ds_read_tr4_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr4.b64">;
+def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
+def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
+def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
+
+//===---------------------------------------------------------------------===//
+// Global load to LDS intrinsic
+
+def ROCDL_GlobalLoadLDSOp :
+ ROCDL_IntrOp<"global.load.lds", [], [], [], 0>,
+ Arguments<(ins Arg<ROCDLGlobalBuffer, "", [MemRead]>:$globalPtr,
+ Arg<ROCDLBufferLDS, "", [MemWrite]>:$ldsPtr,
+ I32:$size,
+ I32:$offset,
+ I32:$aux)> {
+ let assemblyFormat = "operands attr-dict";
+}
+
//===---------------------------------------------------------------------===//
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
// raw buffer mode).
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 92789246edb4f3..c80ebebaafe3ad 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -227,6 +227,32 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
llvm.return
}
+llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
+ // CHECK-LABEL: rocdl.ds.read.tr
+ // CHECK: rocdl.ds.read.tr4.b64 {{.*}} : <3> -> vector<2xi32>
+ %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: rocdl.ds.read.tr6.b96 {{.*}} : <3> -> vector<3xi32>
+ %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
+ // CHECK: rocdl.ds.read.tr8.b64 {{.*}} : <3> -> vector<2xi32>
+ %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xf16>
+ %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
+ // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xbf16>
+ %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
+ llvm.return %r3 : vector<4xf16>
+}
+
+llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+ %aux = llvm.mlir.constant(0 : i32) : i32
+ %offset = llvm.mlir.constant(0 : i32) : i32
+ %size = llvm.mlir.constant(10 : i32) : i32
+
+ //CHECK: rocdl.global.load.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+ rocdl.global.load.lds %src, %dst, %size, %offset, %aux
+
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i32,
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 0620c23b5fdad7..996e0e34c790c9 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -424,6 +424,30 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
llvm.return %r0 : vector<8xf32>
}
+llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
+ // CHECK-LABEL: rocdl.ds.read.tr
+ // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr4.b64.v2i32(ptr addrspace(3) %0)
+ %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: call <3 x i32> @llvm.amdgcn.ds.read.tr6.b96.v3i32(ptr addrspace(3) %0)
+ %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr8.b64.v2i32(ptr addrspace(3) %0)
+ %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: call <4 x half> @llvm.amdgcn.ds.read.tr16.b64.v4f16(ptr addrspace(3) %0)
+ %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
+ // CHECK: call <4 x bfloat> @llvm.amdgcn.ds.read.tr16.b64.v4bf16(ptr addrspace(3) %0)
+ %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
+ llvm.return %r3 : vector<4xf16>
+}
+
+llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+ %aux = llvm.mlir.constant(0 : i32) : i32
+ %offset = llvm.mlir.constant(0 : i32) : i32
+ %size = llvm.mlir.constant(10 : i32) : i32
+ //CHECK: call void @llvm.amdgcn.global.load.lds
+ rocdl.global.load.lds %src, %dst, %size, %offset, %aux
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i32,
``````````
</details>
https://github.com/llvm/llvm-project/pull/123530
More information about the Mlir-commits
mailing list