[Mlir-commits] [mlir] [MLIR][ROCDL] Add ops for LDS read transpose and global to LDS intrinsics (PR #123530)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 20 02:51:12 PST 2025
https://github.com/plognjen updated https://github.com/llvm/llvm-project/pull/123530
>From 68987020e86940a2114e58e2a0682d5fbd034a9e Mon Sep 17 00:00:00 2001
From: Ognjen Plavsic <plognjen at amd.com>
Date: Fri, 17 Jan 2025 15:44:32 +0000
Subject: [PATCH 1/3] Add LDS transpose intrinsics
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 16 ++++++++++++++++
mlir/test/Dialect/LLVMIR/rocdl.mlir | 15 +++++++++++++++
mlir/test/Target/LLVMIR/rocdl.mlir | 15 +++++++++++++++
3 files changed, 46 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 71dac3ad39b7b1..94be2f1c01dc9e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -412,6 +412,22 @@ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]
def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
+//===---------------------------------------------------------------------===//
+// LDS transpose intrinsics
+
+def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
+
+class ROCDL_Ds_Read_Tr_IntrOp<string mnemonic> :
+ ROCDL_IntrOp<mnemonic, [1], [], [], 1>,
+ Arguments<(ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr)>{
+ let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
+ }
+
+def ROCDL_ds_read_tr4_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr4.b64">;
+def ROCDL_ds_read_tr8_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr8.b64">;
+def ROCDL_ds_read_tr6_b96 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr6.b96">;
+def ROCDL_ds_read_tr16_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr16.b64">;
+
//===---------------------------------------------------------------------===//
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
// raw buffer mode).
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 92789246edb4f3..6676219203ddfe 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -227,6 +227,21 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
llvm.return
}
+llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
+ // CHECK-LABEL: rocdl.ds.read.tr
+ // CHECK: rocdl.ds.read.tr4.b64 {{.*}} : <3> -> vector<2xi32>
+ %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: rocdl.ds.read.tr6.b96 {{.*}} : <3> -> vector<3xi32>
+ %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
+ // CHECK: rocdl.ds.read.tr8.b64 {{.*}} : <3> -> vector<2xi32>
+ %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xf16>
+ %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
+ // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xbf16>
+ %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
+ llvm.return %r3 : vector<4xf16>
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i32,
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 0620c23b5fdad7..21c4f8ab55469a 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -424,6 +424,21 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
llvm.return %r0 : vector<8xf32>
}
+llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
+ // CHECK-LABEL: rocdl.ds.read.tr
+ // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr4.b64.v2i32(ptr addrspace(3) %0)
+ %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: call <3 x i32> @llvm.amdgcn.ds.read.tr6.b96.v3i32(ptr addrspace(3) %0)
+ %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
+ // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr8.b64.v2i32(ptr addrspace(3) %0)
+ %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+ // CHECK: call <4 x half> @llvm.amdgcn.ds.read.tr16.b64.v4f16(ptr addrspace(3) %0)
+ %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
+ // CHECK: call <4 x bfloat> @llvm.amdgcn.ds.read.tr16.b64.v4bf16(ptr addrspace(3) %0)
+ %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
+ llvm.return %r3 : vector<4xf16>
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i32,
>From 9dbdf1e35cb02228d96d5dab4f72fb29584ffc23 Mon Sep 17 00:00:00 2001
From: Ognjen Plavsic <plognjen at amd.com>
Date: Fri, 17 Jan 2025 18:47:05 +0000
Subject: [PATCH 2/3] Add global to LDS intrinsic
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 24 ++++++++++++++++----
mlir/test/Dialect/LLVMIR/rocdl.mlir | 11 +++++++++
mlir/test/Target/LLVMIR/rocdl.mlir | 9 ++++++++
3 files changed, 39 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 94be2f1c01dc9e..0e23c183b4d923 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -415,18 +415,32 @@ def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8",
//===---------------------------------------------------------------------===//
// LDS transpose intrinsics
+def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
-class ROCDL_Ds_Read_Tr_IntrOp<string mnemonic> :
+class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
ROCDL_IntrOp<mnemonic, [1], [], [], 1>,
Arguments<(ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr)>{
let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
}
-def ROCDL_ds_read_tr4_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr4.b64">;
-def ROCDL_ds_read_tr8_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr8.b64">;
-def ROCDL_ds_read_tr6_b96 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr6.b96">;
-def ROCDL_ds_read_tr16_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr16.b64">;
+def ROCDL_ds_read_tr4_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr4.b64">;
+def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
+def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
+def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
+
+//===---------------------------------------------------------------------===//
+// Global load to LDS intrinsic
+
+def ROCDL_GlobalLoadLDSOp :
+ ROCDL_IntrOp<"global.load.lds", [], [], [], 0>,
+ Arguments<(ins Arg<ROCDLGlobalBuffer, "", [MemRead]>:$globalPtr,
+ Arg<ROCDLBufferLDS, "", [MemWrite]>:$ldsPtr,
+ I32:$size,
+ I32:$offset,
+ I32:$aux)> {
+ let assemblyFormat = "operands attr-dict";
+}
//===---------------------------------------------------------------------===//
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 6676219203ddfe..c80ebebaafe3ad 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -242,6 +242,17 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
llvm.return %r3 : vector<4xf16>
}
+llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+ %aux = llvm.mlir.constant(0 : i32) : i32
+ %offset = llvm.mlir.constant(0 : i32) : i32
+ %size = llvm.mlir.constant(10 : i32) : i32
+
+ //CHECK: rocdl.global.load.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+ rocdl.global.load.lds %src, %dst, %size, %offset, %aux
+
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i32,
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 21c4f8ab55469a..996e0e34c790c9 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -439,6 +439,15 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
llvm.return %r3 : vector<4xf16>
}
+llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+ %aux = llvm.mlir.constant(0 : i32) : i32
+ %offset = llvm.mlir.constant(0 : i32) : i32
+ %size = llvm.mlir.constant(10 : i32) : i32
+ //CHECK: call void @llvm.amdgcn.global.load.lds
+ rocdl.global.load.lds %src, %dst, %size, %offset, %aux
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i32,
>From 05471984c75ae6186bffe43790c436d7d563b20d Mon Sep 17 00:00:00 2001
From: Ognjen Plavsic <plognjen at amd.com>
Date: Mon, 20 Jan 2025 10:50:31 +0000
Subject: [PATCH 3/3] Address review comments
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 0e23c183b4d923..0b8c0f7f381c4a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -413,7 +413,7 @@ def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8",
def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
//===---------------------------------------------------------------------===//
-// LDS transpose intrinsics
+// LDS transpose intrinsics (available in GFX950)
def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
@@ -422,7 +422,7 @@ class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
ROCDL_IntrOp<mnemonic, [1], [], [], 1>,
Arguments<(ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr)>{
let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
- }
+}
def ROCDL_ds_read_tr4_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr4.b64">;
def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
@@ -430,7 +430,7 @@ def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
//===---------------------------------------------------------------------===//
-// Global load to LDS intrinsic
+// Global load to LDS intrinsic (available in GFX950)
def ROCDL_GlobalLoadLDSOp :
ROCDL_IntrOp<"global.load.lds", [], [], [], 0>,
More information about the Mlir-commits
mailing list