[Mlir-commits] [mlir] 303fba0 - [mlir][ROCDL] Add tensor load and store instructions to ROCDL (#165016)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 28 13:05:35 PDT 2025
Author: Justin Rosner
Date: 2025-10-28T20:05:30Z
New Revision: 303fba0214c9bf3e7d892417b788a012ec6427b4
URL: https://github.com/llvm/llvm-project/commit/303fba0214c9bf3e7d892417b788a012ec6427b4
DIFF: https://github.com/llvm/llvm-project/commit/303fba0214c9bf3e7d892417b788a012ec6427b4.diff
LOG: [mlir][ROCDL] Add tensor load and store instructions to ROCDL (#165016)
Add support for `tensor.load.to.lds` and `tensor.store.from.lds`
instructions in ROCDL.
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/test/Dialect/LLVMIR/rocdl.mlir
mlir/test/Target/LLVMIR/rocdl.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index d2df244eb9363..5241f9a6f2b43 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
];
}
+//===----------------------------------------------------------------------===//
+// ROCDL vector types definitions
+//===----------------------------------------------------------------------===//
+
+class ROCDL_ConcreteVector<Type elem, int length> :
+ FixedVectorOfLengthAndType<[length], [elem]>,
+ BuildableType<
+ "::mlir::VectorType::get({" # length # "} ,"
+ # elem.builderCall # ")">;
+
+def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
+def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
+def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
+def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
+def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
+def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
+def ROCDL_V4I32Type : ROCDL_ConcreteVector<I32, 4>;
+def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
+def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
+def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
+def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
+def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
+def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
+def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
+def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
+def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
+def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
+def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
+
//===----------------------------------------------------------------------===//
// Wave-level primitives
//===----------------------------------------------------------------------===//
@@ -663,6 +692,68 @@ def ROCDL_GlobalLoadLDSOp :
}];
}
+//===---------------------------------------------------------------------===//
+// Tensor load/store intrinsics (available in GFX1250)
+//===---------------------------------------------------------------------===//
+
+// Base class for tensor load/store operations with 4 descriptor groups.
+class ROCDL_TensorLDSIntrOp<string mnemonic> :
+ ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [4], ["cachePolicy"]> {
+ dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
+ ROCDL_V4I32Type:$dgroup2, ROCDL_V4I32Type:$dgroup3,
+ I32Attr:$cachePolicy);
+ let arguments = !con(args, baseArgs);
+ let summary = "Base class for ROCDL tensor load/store to/from LDS.";
+ let description = [{
+ Moves tiles of tensor data between global memory and LDS. The tile is
+ described by the $dgroup descriptors. 4 $dgroup descriptors allow for
+ movement of up to 5D tensors. $cachePolicy describes the memory scope and an
+ indicator of expected data re-use.
+
+ This op is for gfx1250+ architectures.
+ }];
+ let assemblyFormat = [{
+ attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
+ }];
+ let extraClassDefinition = [{
+ SmallVector<Value> $cppClass::getAccessedOperands() {
+ return {getDgroup0(), getDgroup1(), getDgroup2(), getDgroup3()};
+ }
+ }];
+}
+
+// Base class for tensor load/store operations with 2 descriptor groups
+// (D2 variant).
+class ROCDL_TensorLDSIntrD2Op<string mnemonic> :
+ ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [2], ["cachePolicy"]> {
+ dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
+ I32Attr:$cachePolicy);
+ let arguments = !con(args, baseArgs);
+ let summary = "Base class for ROCDL tensor load/store to/from LDS (D2 variant).";
+ let description = [{
+ Moves tiles of tensor data between global memory and LDS. The tile is
+ described by the $dgroup descriptors. 2 $dgroup descriptors allow for
+ movement of up to 2D tensors. $cachePolicy describes the memory scope and an
+ indicator of expected data re-use.
+
+ This op is for gfx1250+ architectures.
+ }];
+ let assemblyFormat = [{
+ attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
+ }];
+ let extraClassDefinition = [{
+ SmallVector<Value> $cppClass::getAccessedOperands() {
+ return {getDgroup0(), getDgroup1()};
+ }
+ }];
+}
+
+// Tensor load and store operations
+def ROCDL_TensorLoadToLDSOp : ROCDL_TensorLDSIntrOp<"tensor.load.to.lds">;
+def ROCDL_TensorStoreFromLDSOp : ROCDL_TensorLDSIntrOp<"tensor.store.from.lds">;
+def ROCDL_TensorLoadToLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.load.to.lds.d2">;
+def ROCDL_TensorStoreFromLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.store.from.lds.d2">;
+
//===---------------------------------------------------------------------===//
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
// raw buffer mode).
@@ -932,30 +1023,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
}];
}
-class ROCDL_ConcreteVector<Type elem, int length> :
- FixedVectorOfLengthAndType<[length], [elem]>,
- BuildableType<
- "::mlir::VectorType::get({" # length # "} ,"
- # elem.builderCall # ")">;
-
-def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
-def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
-def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
-def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
-def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
-def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
-def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
-def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
-def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
-def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
-def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
-def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
-def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
-def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
-def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
-def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
-def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
-
//===---------------------------------------------------------------------===//
// 16-bit float intrinsics
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index d270ee8b089aa..e703600c71c8e 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -664,6 +664,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
llvm.return
}
+// CHECK-LABEL: @rocdl.tensor.load.to.lds
+llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK: rocdl.tensor.load.to.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL: @rocdl.tensor.store.from.lds
+llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK: rocdl.tensor.store.from.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL: @rocdl.tensor.load.to.lds.d2
+llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK: rocdl.tensor.load.to.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL: @rocdl.tensor.store.from.lds.d2
+llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK: rocdl.tensor.store.from.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i64,
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 30126f6bff05a..8a848221a50dd 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1040,6 +1040,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
llvm.return
}
+// CHECK-LABEL: rocdl.tensor.load.to.lds
+llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL: rocdl.tensor.store.from.lds
+llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL: rocdl.tensor.load.to.lds.d2
+llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL: rocdl.tensor.store.from.lds.d2
+llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i64,
More information about the Mlir-commits
mailing list