[Mlir-commits] [mlir] [MLIR][ROCDL] Add ops for LDS read transpose and global to LDS intrinsics (PR #123530)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 19 13:45:22 PST 2025


https://github.com/plognjen created https://github.com/llvm/llvm-project/pull/123530

This PR adds the missing ds.read.tr4.b64, ds.read.tr8.b64, ds.read.tr6.b96,
ds.read.tr16.b64 and global.load.lds ops to the ROCDL dialect.
The ops are converted to the corresponding intrinsic calls during the
translation from MLIR to LLVM IR.

>From 68987020e86940a2114e58e2a0682d5fbd034a9e Mon Sep 17 00:00:00 2001
From: Ognjen Plavsic <plognjen at amd.com>
Date: Fri, 17 Jan 2025 15:44:32 +0000
Subject: [PATCH 1/2] Add LDS transpose intrinsics

---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 16 ++++++++++++++++
 mlir/test/Dialect/LLVMIR/rocdl.mlir          | 15 +++++++++++++++
 mlir/test/Target/LLVMIR/rocdl.mlir           | 15 +++++++++++++++
 3 files changed, 46 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 71dac3ad39b7b1..94be2f1c01dc9e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -412,6 +412,22 @@ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]
 def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
 def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
 
+//===---------------------------------------------------------------------===//
+// LDS transpose intrinsics
+
+def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
+
+class ROCDL_Ds_Read_Tr_IntrOp<string mnemonic> :
+  ROCDL_IntrOp<mnemonic, [1], [], [], 1>,
+  Arguments<(ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr)>{
+  let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
+  }
+
+def ROCDL_ds_read_tr4_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr4.b64">;
+def ROCDL_ds_read_tr8_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr8.b64">;
+def ROCDL_ds_read_tr6_b96 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr6.b96">;
+def ROCDL_ds_read_tr16_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr16.b64">;
+
 //===---------------------------------------------------------------------===//
 // Operations on raw buffer resources (stride of 0, bounds checks either off or in
 // raw buffer mode).
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 92789246edb4f3..6676219203ddfe 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -227,6 +227,21 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
   llvm.return
 }
 
+llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
+  // CHECK-LABEL: rocdl.ds.read.tr
+  // CHECK: rocdl.ds.read.tr4.b64 {{.*}} : <3> -> vector<2xi32>
+  %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+  // CHECK: rocdl.ds.read.tr6.b96 {{.*}} : <3> -> vector<3xi32>
+  %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
+  // CHECK: rocdl.ds.read.tr8.b64 {{.*}} : <3> -> vector<2xi32>
+  %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+  // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xf16>
+  %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
+  // CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xbf16>
+  %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
+  llvm.return %r3 : vector<4xf16>
+}
+
 llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
                                   %stride : i16,
                                   %numRecords : i32,
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 0620c23b5fdad7..21c4f8ab55469a 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -424,6 +424,21 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
   llvm.return %r0 : vector<8xf32>
 }
 
+llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
+  // CHECK-LABEL: rocdl.ds.read.tr
+  // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr4.b64.v2i32(ptr addrspace(3) %0)
+  %r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+  // CHECK: call <3 x i32> @llvm.amdgcn.ds.read.tr6.b96.v3i32(ptr addrspace(3) %0)
+  %r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
+  // CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr8.b64.v2i32(ptr addrspace(3) %0)
+  %r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
+  // CHECK: call <4 x half> @llvm.amdgcn.ds.read.tr16.b64.v4f16(ptr addrspace(3) %0)
+  %r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
+  // CHECK: call <4 x bfloat> @llvm.amdgcn.ds.read.tr16.b64.v4bf16(ptr addrspace(3) %0)
+  %r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
+  llvm.return %r3 : vector<4xf16>
+}
+
 llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
                                   %stride : i16,
                                   %numRecords : i32,

>From 9dbdf1e35cb02228d96d5dab4f72fb29584ffc23 Mon Sep 17 00:00:00 2001
From: Ognjen Plavsic <plognjen at amd.com>
Date: Fri, 17 Jan 2025 18:47:05 +0000
Subject: [PATCH 2/2] Add global to LDS intrinsic

---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 24 ++++++++++++++++----
 mlir/test/Dialect/LLVMIR/rocdl.mlir          | 11 +++++++++
 mlir/test/Target/LLVMIR/rocdl.mlir           |  9 ++++++++
 3 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 94be2f1c01dc9e..0e23c183b4d923 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -415,18 +415,32 @@ def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8",
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics
 
+def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
 def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
 
-class ROCDL_Ds_Read_Tr_IntrOp<string mnemonic> :
+class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
   ROCDL_IntrOp<mnemonic, [1], [], [], 1>,
   Arguments<(ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr)>{
   let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
   }
 
-def ROCDL_ds_read_tr4_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr4.b64">;
-def ROCDL_ds_read_tr8_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr8.b64">;
-def ROCDL_ds_read_tr6_b96 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr6.b96">;
-def ROCDL_ds_read_tr16_b64 : ROCDL_Ds_Read_Tr_IntrOp<"ds.read.tr16.b64">;
+def ROCDL_ds_read_tr4_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr4.b64">;
+def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
+def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
+def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
+
+//===---------------------------------------------------------------------===//
+// Global load to LDS intrinsic
+
+def ROCDL_GlobalLoadLDSOp :
+  ROCDL_IntrOp<"global.load.lds", [], [], [], 0>,
+  Arguments<(ins Arg<ROCDLGlobalBuffer, "", [MemRead]>:$globalPtr,
+                 Arg<ROCDLBufferLDS, "", [MemWrite]>:$ldsPtr,
+                 I32:$size,
+                 I32:$offset,
+                 I32:$aux)> {
+  let assemblyFormat = "operands attr-dict";
+}
 
 //===---------------------------------------------------------------------===//
 // Operations on raw buffer resources (stride of 0, bounds checks either off or in
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 6676219203ddfe..c80ebebaafe3ad 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -242,6 +242,17 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
   llvm.return %r3 : vector<4xf16>
 }
 
+llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+  %aux = llvm.mlir.constant(0 : i32) : i32
+  %offset = llvm.mlir.constant(0 : i32) : i32
+  %size = llvm.mlir.constant(10 : i32) : i32
+
+  //CHECK: rocdl.global.load.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+  rocdl.global.load.lds %src, %dst, %size, %offset, %aux
+
+  llvm.return
+}
+
 llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
                                   %stride : i16,
                                   %numRecords : i32,
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 21c4f8ab55469a..996e0e34c790c9 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -439,6 +439,15 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
   llvm.return %r3 : vector<4xf16>
 }
 
+llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
+  %aux = llvm.mlir.constant(0 : i32) : i32
+  %offset = llvm.mlir.constant(0 : i32) : i32
+  %size = llvm.mlir.constant(10 : i32) : i32
+  //CHECK: call void @llvm.amdgcn.global.load.lds
+  rocdl.global.load.lds %src, %dst, %size, %offset, %aux
+  llvm.return
+}
+
 llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
                                   %stride : i16,
                                   %numRecords : i32,



More information about the Mlir-commits mailing list