[Mlir-commits] [mlir] [ROCDL] Added tensor load/store ops (PR #165390)
Ravil Dorozhinskii
llvmlistbot at llvm.org
Tue Oct 28 06:32:21 PDT 2025
https://github.com/ravil-mobile created https://github.com/llvm/llvm-project/pull/165390
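This patch introduces tensor load/store ops in the ROCDL dialect.
Specifically:
- tensor load/store ops in two variants: up to 2D (`.d2`) and up to 5D;
- the `ROCDL_ConcreteVector` type definitions are moved ahead of their first use, and `ROCDL_V4I32Type` is added.
Tests:
- Added lit tests covering op round-tripping (parse/print) and MLIR -> LLVM IR translation.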
>From 5954d73f153c2e55a8808ba4dcf4c510b909004c Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Tue, 28 Oct 2025 13:27:38 +0000
Subject: [PATCH] [ROCDL] Added tensor load/store ops
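This patch introduces tensor load/store ops in the ROCDL dialect.
Specifically:
- tensor load/store ops in two variants: up to 2D (`.d2`) and up to 5D;
- the `ROCDL_ConcreteVector` type definitions are moved ahead of their first use, and `ROCDL_V4I32Type` is added.
Tests:
- Added lit tests covering op round-tripping (parse/print) and MLIR -> LLVM IR translation.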
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 112 +++++++++++++++----
mlir/test/Dialect/LLVMIR/rocdl.mlir | 20 ++++
mlir/test/Target/LLVMIR/rocdl.mlir | 20 ++++
3 files changed, 128 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index d2df244eb9363..6bb968c24027f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
];
}
+//===----------------------------------------------------------------------===//
+// ROCDL vector type definitions
+//===----------------------------------------------------------------------===//
+
+// vector<length x elem> constraint; buildable so assembly formats can elide it.
+class ROCDL_ConcreteVector<Type elem, int length> :
+  FixedVectorOfLengthAndType<[length], [elem]>,
+  BuildableType<"::mlir::VectorType::get({" # length # "}, "
+                # elem.builderCall # ")">;
+
+def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
+def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
+def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
+def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
+def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
+def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
+def ROCDL_V4I32Type : ROCDL_ConcreteVector<I32, 4>;
+def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
+def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
+def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
+def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
+def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
+def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
+def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
+def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
+def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
+def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
+def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
+
//===----------------------------------------------------------------------===//
// Wave-level primitives
//===----------------------------------------------------------------------===//
@@ -805,6 +834,65 @@ def ROCDL_RawBufferAtomicCmpSwap :
}];
}
+//===---------------------------------------------------------------------===//
+// Raw tensor load/store intrinsics: gfx12+
+
+def ROCDL_TensorLoadToLds :
+  ROCDL_IntrOp<"tensor.load.to.lds", [], [], [], 0, 0, 0, 0, [4], ["cpol"]>,
+  Arguments<(ins ROCDL_V4I32Type:$desc0, ROCDL_V8I32Type:$desc1,
+                 ROCDL_V4I32Type:$desc2, ROCDL_V4I32Type:$desc3,
+                 I32Attr:$cpol)> {
+  let description = [{
+    Copies tensor data from global memory into LDS. Available on gfx12+.
+
+    `$desc0`..`$desc3` are the tensor descriptor operands; `$cpol` is the
+    cache-policy immediate.
+  }];
+  let assemblyFormat =
+    "attr-dict operands `cachepolicy` $cpol";
+}
+
+def ROCDL_TensorLoadToLdsD2 :
+  ROCDL_IntrOp<"tensor.load.to.lds.d2", [], [], [], 0, 0, 0, 0, [2], ["cpol"]>,
+  Arguments<(ins ROCDL_V4I32Type:$desc0,
+                 ROCDL_V8I32Type:$desc1,
+                 I32Attr:$cpol)> {
+  let description = [{
+    Loads tensor data (up to 2D) from global memory to LDS. Available on gfx12+.
+  }];
+  let assemblyFormat = [{
+    attr-dict operands `cachepolicy` $cpol
+  }];
+}
+
+def ROCDL_TensorStoreFromLds :
+  ROCDL_IntrOp<"tensor.store.from.lds", [], [], [], 0, 0, 0, 0, [4], ["cpol"]>,
+  Arguments<(ins ROCDL_V4I32Type:$desc0,
+                 ROCDL_V8I32Type:$desc1,
+                 ROCDL_V4I32Type:$desc2,
+                 ROCDL_V4I32Type:$desc3,
+                 I32Attr:$cpol)> {
+  let description = [{
+    Stores tensor data from LDS to global memory. Available on gfx12+.
+  }];
+  let assemblyFormat = [{
+    attr-dict operands `cachepolicy` $cpol
+  }];
+}
+
+def ROCDL_TensorStoreFromLdsD2 :
+  ROCDL_IntrOp<"tensor.store.from.lds.d2", [], [], [], 0, 0, 0, 0, [2], ["cpol"]>,
+  Arguments<(ins ROCDL_V4I32Type:$desc0,
+                 ROCDL_V8I32Type:$desc1,
+                 I32Attr:$cpol)> {
+  let description = [{
+    Stores tensor data (up to 2D) from LDS to global memory. Available on gfx12+.
+  }];
+  let assemblyFormat = [{
+    attr-dict operands `cachepolicy` $cpol
+  }];
+}
+
//===---------------------------------------------------------------------===//
// MI-100 and MI-200 buffer atomic floating point add intrinsic
@@ -932,30 +1020,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
}];
}
-class ROCDL_ConcreteVector<Type elem, int length> :
- FixedVectorOfLengthAndType<[length], [elem]>,
- BuildableType<
- "::mlir::VectorType::get({" # length # "} ,"
- # elem.builderCall # ")">;
-
-def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
-def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
-def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
-def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
-def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
-def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
-def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
-def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
-def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
-def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
-def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
-def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
-def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
-def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
-def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
-def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
-def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
-
//===---------------------------------------------------------------------===//
// 16-bit float intrinsics
//===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index d270ee8b089aa..0de5f38071791 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -776,6 +776,26 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
llvm.return
}
+llvm.func @rocdl.tensor.load.store.ops(
+  %desc0 : vector<4xi32>,
+  %desc1 : vector<8xi32>,
+  %desc2 : vector<4xi32>,
+  %desc3 : vector<4xi32>) {
+  // CHECK-LABEL: @rocdl.tensor.load.store.ops(
+  // CHECK-SAME: %[[DESC0:.*]]: vector<4xi32>, %[[DESC1:.*]]: vector<8xi32>, %[[DESC2:.*]]: vector<4xi32>, %[[DESC3:.*]]: vector<4xi32>)
+  // CHECK: rocdl.tensor.load.to.lds %[[DESC0]], %[[DESC1]], %[[DESC2]], %[[DESC3]] cachepolicy 1
+  // CHECK: rocdl.tensor.load.to.lds.d2 %[[DESC0]], %[[DESC1]] cachepolicy 2
+  // CHECK: rocdl.tensor.store.from.lds %[[DESC0]], %[[DESC1]], %[[DESC2]], %[[DESC3]] cachepolicy 3
+  // CHECK: rocdl.tensor.store.from.lds.d2 %[[DESC0]], %[[DESC1]] cachepolicy 4
+  // Distinct non-zero cache policies verify each op round-trips its own value.
+  rocdl.tensor.load.to.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 1
+  rocdl.tensor.load.to.lds.d2 %desc0, %desc1 cachepolicy 2
+
+  rocdl.tensor.store.from.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 3
+  rocdl.tensor.store.from.lds.d2 %desc0, %desc1 cachepolicy 4
+  llvm.return
+}
+
llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %stoch: i32) -> i32 {
// CHECK-LABEL: @rocdl_8bit_floats
// CHECK: rocdl.cvt.f32.bf8
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 30126f6bff05a..eac58929795db 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1250,6 +1250,26 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
llvm.return %val : i32
}
+llvm.func @rocdl.tensor.load.store.ops(
+  %desc0 : vector<4xi32>,
+  %desc1 : vector<8xi32>,
+  %desc2 : vector<4xi32>,
+  %desc3 : vector<4xi32>) {
+  // CHECK-LABEL: @rocdl.tensor.load.store.ops(
+  // CHECK-SAME: <4 x i32> %[[DESC0:.*]], <8 x i32> %[[DESC1:.*]], <4 x i32> %[[DESC2:.*]], <4 x i32> %[[DESC3:.*]])
+  // CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], <4 x i32> %[[DESC2]], <4 x i32> %[[DESC3]], i32 1)
+  // CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], i32 2)
+  // CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], <4 x i32> %[[DESC2]], <4 x i32> %[[DESC3]], i32 3)
+  // CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], i32 4)
+  // CHECK: ret void
+  // Distinct non-zero cache policies verify each immarg lands as the trailing i32.
+  rocdl.tensor.load.to.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 1
+  rocdl.tensor.load.to.lds.d2 %desc0, %desc1 cachepolicy 2
+  rocdl.tensor.store.from.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 3
+  rocdl.tensor.store.from.lds.d2 %desc0, %desc1 cachepolicy 4
+  llvm.return
+}
+
llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %source_packed: vector<2xf16>, %stoch: i32) -> i32 {
// CHECK-LABEL: @rocdl_8bit_floats
// CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)
More information about the Mlir-commits
mailing list