[Mlir-commits] [mlir] [ROCDL][WIP] Added matrix load-transpose ops for gfx1250+ (PR #165564)

Ravil Dorozhinskii llvmlistbot at llvm.org
Wed Oct 29 06:45:54 PDT 2025


https://github.com/ravil-mobile updated https://github.com/llvm/llvm-project/pull/165564

>From 895975f980bf24ef7f70e17f3f416e84ffa51493 Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Wed, 29 Oct 2025 13:40:23 +0000
Subject: [PATCH]  [ROCDL][WIP] Added matrix load-transpose ops for gfx1250+

---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 73 +++++++++++++++++++-
 mlir/test/Dialect/LLVMIR/rocdl.mlir          | 24 +++++++
 mlir/test/Target/LLVMIR/rocdl.mlir           | 24 +++++++
 3 files changed, 119 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 5241f9a6f2b43..59f9ec18e1608 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -321,6 +321,7 @@ def ROCDL_BarrierOp : ROCDL_Op<"barrier"> {
   let assemblyFormat = "attr-dict";
 }
 
+def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>; // Pointer into global memory (address space 1).
 def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
 
 def ROCDL_BarrierInitOp : ROCDL_IntrOp<"s.barrier.init", [], [], [], 0, 0, 0, 0, [1], ["id"]>,
@@ -631,8 +632,6 @@ def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
 
-def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
-
 class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
   ROCDL_IntrOp<mnemonic, [1], [], [], 1, 0, 1> {
   dag args = (ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr);
@@ -650,6 +649,76 @@ def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
 def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
 def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
 
+//===---------------------------------------------------------------------===//
+// Global/LDS load-transpose intrinsics (available in GFX1250+)
+
+class WrapperType<Type t, int w> { // A result element type plus its bitwidth.
+  Type type = t;
+  int bitwidth = w;
+}
+class IType<I t> : WrapperType<t, t.bitwidth> {}
+class FType<F t> : WrapperType<t, t.bitwidth> {}
+def BF16Type : WrapperType<BF16, 16> {}
+
+class AddrKind<string n, int s, string d = "memory"> {
+  string name = n;
+  int space = s;
+  string desc = d; // Human-readable memory name used in the op documentation.
+  LLVM_PointerInAddressSpace type = LLVM_PointerInAddressSpace<s>;
+}
+def GlobalAddrKind : AddrKind<"global", 1, "global memory">;
+def DSAddrKind : AddrKind<"ds", 3, "LDS">;
+
+// inElemBits: width of one matrix element; outElemBits: total result width.
+class ROCDL_TrLoadOpMeta<AddrKind addKind, int inElemBits, int outElemBits, WrapperType outElemType> {
+  string inBits = !cast<string>(inElemBits);
+  string outBits = !cast<string>(outElemBits);
+  string memDesc = addKind.desc;
+  LLVM_PointerInAddressSpace inType = addKind.type;
+  int outNumElem = !div(outElemBits, outElemType.bitwidth);
+  assert !eq(!mul(outNumElem, outElemType.bitwidth), outElemBits), "result width must be a multiple of its element width";
+  ROCDL_ConcreteVector outType = ROCDL_ConcreteVector<outElemType.type, outNumElem>;
+  // Global 8/16-bit variants omit the element width, e.g. `global.load.tr.b64`.
+  string inBitsEnc = !if(!and(!eq(addKind.space, 1), !or(!eq(inElemBits, 8), !eq(inElemBits, 16))), "", inBits);
+  string mnemonic = addKind.name # ".load.tr" # inBitsEnc # ".b" # outBits;
+}
+
+class ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta meta> :
+  ROCDL_IntrOp<meta.mnemonic, [1], [], [], 1, 0, 1> {
+  dag args = (ins Arg<meta.inType, "", [MemRead]>:$ptr);
+  let arguments = !con(args, baseArgs);
+  let results = (outs meta.outType:$res);
+  let summary = "Loads and transposes a matrix from " # meta.memDesc # " to registers (available in gfx1250+).";
+  let description = [{
+    Load a matrix of }] # meta.inBits # [{-bit data from }] # meta.memDesc # [{,
+    transpose data between row-major and column-major order,
+    and store the result into a }] # meta.outBits # [{-bit vector register.
+
+    Available in gfx1250+.
+  }];
+  let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
+  let extraClassDefinition = [{
+    ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
+      return {getPtr()};
+    }
+  }];
+}
+
+def ROCDL_GlobalLoadTr4_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 4, 64, IType<I32>>>;
+def ROCDL_GlobalLoadTr8_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 8, 64, IType<I32>>>;
+def ROCDL_GlobalLoadTr6_3I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 6, 96, IType<I32>>>;
+def ROCDL_GlobalLoadTr8_8I16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128, IType<I16>>>; // NOTE(review): loads 16-bit elements despite the "Tr8" name
+// TODO: f16/bf16 results would reuse the i16 mnemonic; they need a multi-type result instead.
+//def ROCDL_GlobalLoadTr8_8F16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128, FType<F16>>>;
+//def ROCDL_GlobalLoadTr8_8BF16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<GlobalAddrKind, 16, 128, BF16Type>>;
+
+def ROCDL_DsLoadTr4_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 4, 64, IType<I32>>>;
+def ROCDL_DsLoadTr8_2I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 8, 64, IType<I32>>>;
+def ROCDL_DsLoadTr6_3I32 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 6, 96, IType<I32>>>;
+def ROCDL_DsLoadTr16_8I16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128, IType<I16>>>;
+//def ROCDL_DsLoadTr16_8F16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128, FType<F16>>>;
+//def ROCDL_DsLoadTr16_8BF16 : ROCDL_TrLoadOp<ROCDL_TrLoadOpMeta<DSAddrKind, 16, 128, BF16Type>>;
+
 //===---------------------------------------------------------------------===//
 // Load to LDS intrinsic (available in GFX9 and GFX10)
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index e703600c71c8e..f5bae3078a6bb 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -650,6 +650,30 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
   llvm.return %r3 : vector<4xf16>
 }
 
+llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
+  // CHECK-LABEL: @rocdl.load.tr.ops
+  // CHECK-SAME: (%[[GL_PTR:[^:]+]]: !llvm.ptr<1>, %[[DS_PTR:[^:]+]]: !llvm.ptr<3>)
+  // CHECK: %{{.*}} = rocdl.global.load.tr4.b64 %[[GL_PTR]] : <1> -> vector<2xi32>
+  // CHECK: %{{.*}} = rocdl.global.load.tr.b64 %[[GL_PTR]] : <1> -> vector<2xi32>
+  // CHECK: %{{.*}} = rocdl.global.load.tr6.b96 %[[GL_PTR]] : <1> -> vector<3xi32>
+  // CHECK: %{{.*}} = rocdl.global.load.tr.b128 %[[GL_PTR]] : <1> -> vector<8xi16>
+  // CHECK: %{{.*}} = rocdl.ds.load.tr4.b64 %[[DS_PTR]] : <3> -> vector<2xi32>
+  // CHECK: %{{.*}} = rocdl.ds.load.tr8.b64 %[[DS_PTR]] : <3> -> vector<2xi32>
+  // CHECK: %{{.*}} = rocdl.ds.load.tr6.b96 %[[DS_PTR]] : <3> -> vector<3xi32>
+  // CHECK: %{{.*}} = rocdl.ds.load.tr16.b128 %[[DS_PTR]] : <3> -> vector<8xi16>
+
+  rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+  rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+  rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3 x i32>
+  rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x i16>
+
+  rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+  rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+  rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3 x i32>
+  rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x i16>
+  llvm.return
+}
+
 llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
   // CHECK-LABEL @rocdl.load.to.lds
   //CHECK: rocdl.load.to.lds %{{.*}}, %{{.*}}, 4, 0, 0 : <7>
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 8a848221a50dd..2285296a0e306 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1028,6 +1028,30 @@ llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
   llvm.return %r3 : vector<4xf16>
 }
 
+llvm.func @rocdl.load.tr.ops(%gl_ptr : !llvm.ptr<1>, %ds_ptr : !llvm.ptr<3>) {
+  // CHECK-LABEL: @rocdl.load.tr.ops
+  // CHECK-SAME: (ptr addrspace(1) %[[GL_PTR:[^,]+]], ptr addrspace(3) %[[DS_PTR:[^)]+]])
+  // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr4.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
+  // CHECK: call <2 x i32> @llvm.amdgcn.global.load.tr.b64.v2i32(ptr addrspace(1) %[[GL_PTR]])
+  // CHECK: call <3 x i32> @llvm.amdgcn.global.load.tr6.b96.v3i32(ptr addrspace(1) %[[GL_PTR]])
+  // CHECK: call <8 x i16> @llvm.amdgcn.global.load.tr.b128.v8i16(ptr addrspace(1) %[[GL_PTR]])
+  // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr4.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
+  // CHECK: call <2 x i32> @llvm.amdgcn.ds.load.tr8.b64.v2i32(ptr addrspace(3) %[[DS_PTR]])
+  // CHECK: call <3 x i32> @llvm.amdgcn.ds.load.tr6.b96.v3i32(ptr addrspace(3) %[[DS_PTR]])
+  // CHECK: call <8 x i16> @llvm.amdgcn.ds.load.tr16.b128.v8i16(ptr addrspace(3) %[[DS_PTR]])
+
+  rocdl.global.load.tr4.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+  rocdl.global.load.tr.b64 %gl_ptr : !llvm.ptr<1> -> vector<2 x i32>
+  rocdl.global.load.tr6.b96 %gl_ptr : !llvm.ptr<1> -> vector<3 x i32>
+  rocdl.global.load.tr.b128 %gl_ptr : !llvm.ptr<1> -> vector<8 x i16>
+
+  rocdl.ds.load.tr4.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+  rocdl.ds.load.tr8.b64 %ds_ptr : !llvm.ptr<3> -> vector<2 x i32>
+  rocdl.ds.load.tr6.b96 %ds_ptr : !llvm.ptr<3> -> vector<3 x i32>
+  rocdl.ds.load.tr16.b128 %ds_ptr : !llvm.ptr<3> -> vector<8 x i16>
+  llvm.return
+}
+
 llvm.func @rocdl.load.to.lds(%src : !llvm.ptr<7>, %dst: !llvm.ptr<3>) {
   //CHECK: call void @llvm.amdgcn.load.to.lds.p7
   rocdl.load.to.lds %src, %dst, 4, 0, 0 : !llvm.ptr<7>



More information about the Mlir-commits mailing list