[Mlir-commits] [mlir] [mlir][ROCDL] Add tensor load and store instructions to ROCDL (PR #165016)
Justin Rosner
llvmlistbot at llvm.org
Mon Oct 27 05:28:40 PDT 2025
https://github.com/justinrosner updated https://github.com/llvm/llvm-project/pull/165016
>From 1bea035d508a92e1a70b40e27272e60b8ab39bd9 Mon Sep 17 00:00:00 2001
From: Justin Rosner <justin.rosner at amd.com>
Date: Fri, 24 Oct 2025 17:15:25 +0000
Subject: [PATCH] [mlir][ROCDL] Add tensor load and store instructions to ROCDL
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 44 ++++++++++++++++++++
mlir/test/Dialect/LLVMIR/rocdl.mlir | 30 +++++++++++++
mlir/test/Target/LLVMIR/rocdl.mlir | 30 +++++++++++++
3 files changed, 104 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index d2df244eb9363..7b240f02653d5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -663,6 +663,50 @@ def ROCDL_GlobalLoadLDSOp :
}];
}
+//===---------------------------------------------------------------------===//
+// Tensor load/store intrinsics (available in GFX1250)
+//===---------------------------------------------------------------------===//
+
+// Base class for tensor load/store ops: 4 descriptor groups + cachePolicy attr
+class ROCDL_TensorLDSIntrOp<string mnemonic> :
+ ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [4], ["cachePolicy"]> {
+ dag args = (ins LLVM_Type:$dgroup0, LLVM_Type:$dgroup1, LLVM_Type:$dgroup2,
+ LLVM_Type:$dgroup3, I32Attr:$cachePolicy);
+ let arguments = !con(args, baseArgs);
+ let assemblyFormat = [{
+ $dgroup0 `,` $dgroup1 `,` $dgroup2 `,` $dgroup3 `,` $cachePolicy
+ attr-dict `:` type($dgroup0) `,` type($dgroup1) `,` type($dgroup2) `,` type($dgroup3)
+ }];
+ let extraClassDefinition = [{
+ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
+ return {getDgroup0(), getDgroup1(), getDgroup2(), getDgroup3()};
+ }
+ }];
+}
+
+// Base class for the D2 tensor load/store ops: only dgroup0 and dgroup1 are
+// passed, followed by the I32Attr cachePolicy.
+class ROCDL_TensorLDSIntrD2Op<string mnemonic> :
+ ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [2], ["cachePolicy"]> {
+ dag args = (ins LLVM_Type:$dgroup0, LLVM_Type:$dgroup1, I32Attr:$cachePolicy);
+ let arguments = !con(args, baseArgs);
+ let assemblyFormat = [{
+ $dgroup0 `,` $dgroup1 `,` $cachePolicy
+ attr-dict `:` type($dgroup0) `,` type($dgroup1)
+ }];
+ let extraClassDefinition = [{
+ ::llvm::SmallVector<::mlir::Value> $cppClass::getAccessedOperands() {
+ return {getDgroup0(), getDgroup1()};
+ }
+ }];
+}
+
+// Concrete ops: 4-group load/store, then their 2-group (.d2) counterparts
+def ROCDL_TensorLoadToLDSOp : ROCDL_TensorLDSIntrOp<"tensor.load.to.lds">;
+def ROCDL_TensorStoreFromLDSOp : ROCDL_TensorLDSIntrOp<"tensor.store.from.lds">;
+def ROCDL_TensorLoadToLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.load.to.lds.d2">;
+def ROCDL_TensorStoreFromLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.store.from.lds.d2">;
+
//===---------------------------------------------------------------------===//
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
// raw buffer mode).
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index d270ee8b089aa..13c01ddd6b1f2 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -664,6 +664,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
llvm.return
}
+llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK-LABEL: @rocdl.tensor.load.to.lds
+ // CHECK: rocdl.tensor.load.to.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
+ rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
+ llvm.return
+}
+
+llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK-LABEL: @rocdl.tensor.store.from.lds
+ // CHECK: rocdl.tensor.store.from.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
+ rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
+ llvm.return
+}
+
+llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK-LABEL: @rocdl.tensor.load.to.lds.d2
+ // CHECK: rocdl.tensor.load.to.lds.d2 %{{.*}}, %{{.*}}, 0 : vector<4xi32>, vector<8xi32>
+ rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1, 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK-LABEL: @rocdl.tensor.store.from.lds.d2
+ // CHECK: rocdl.tensor.store.from.lds.d2 %{{.*}}, %{{.*}}, 0 : vector<4xi32>, vector<8xi32>
+ rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1, 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i64,
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 30126f6bff05a..ac5e703458ed1 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1040,6 +1040,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
llvm.return
}
+llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK-LABEL: rocdl.tensor.load.to.lds
+ // CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
+ llvm.return
+}
+
+llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK-LABEL: rocdl.tensor.store.from.lds
+ // CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3, 0 : vector<4xi32>, vector<8xi32>, vector<4xi32>, vector<4xi32>
+ llvm.return
+}
+
+llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK-LABEL: rocdl.tensor.load.to.lds.d2
+ // CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1, 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK-LABEL: rocdl.tensor.store.from.lds.d2
+ // CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1, 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i64,
More information about the Mlir-commits
mailing list