[Mlir-commits] [mlir] [ROCDL] Added tensor load/store ops (PR #165390)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 28 06:32:55 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Ravil Dorozhinskii (ravil-mobile)

<details>
<summary>Changes</summary>

This patch introduces tensor load/store ops to the ROCDL dialect.

Specifically:

    tensor load/store ops in two variants: up to 2D (`.d2`) and up to 5D

Tests:

    Added lit tests for the MLIR round trip and the MLIR -> LLVM IR translation

---
Full diff: https://github.com/llvm/llvm-project/pull/165390.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+88-24) 
- (modified) mlir/test/Dialect/LLVMIR/rocdl.mlir (+20) 
- (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+20) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index d2df244eb9363..6bb968c24027f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ROCDL fixed-length vector type constraints, buildable via VectorType::get
+//===----------------------------------------------------------------------===//
+
+class ROCDL_ConcreteVector<Type elem, int length> :
+  FixedVectorOfLengthAndType<[length], [elem]>,
+  BuildableType<
+    "::mlir::VectorType::get({" # length # "} ,"
+      # elem.builderCall # ")">;
+
+def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
+def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
+def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
+def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
+def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
+def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
+def ROCDL_V4I32Type : ROCDL_ConcreteVector<I32, 4>;
+def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
+def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
+def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
+def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
+def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
+def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
+def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
+def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
+def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
+def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
+def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
+
 //===----------------------------------------------------------------------===//
 // Wave-level primitives
 //===----------------------------------------------------------------------===//
@@ -805,6 +834,65 @@ def ROCDL_RawBufferAtomicCmpSwap :
   }];
 }
 
+//===---------------------------------------------------------------------===//
+// Raw tensor load/store intrinsics: gfx12+
+
+def ROCDL_TensorLoadToLds :
+  ROCDL_IntrOp<"tensor.load.to.lds", [], [], [], 0, 0, 0, 0, [4], ["cpol"]>,
+  Arguments<(ins ROCDL_V4I32Type:$desc0,
+                 ROCDL_V8I32Type:$desc1,
+                 ROCDL_V4I32Type:$desc2,
+                 ROCDL_V4I32Type:$desc3,
+                 I32Attr:$cpol)>{
+  let description = [{
+      Loads up to 5D tensor data from global memory to LDS. Available on gfx12+.
+  }];
+  let assemblyFormat = [{
+    attr-dict operands `cachepolicy` $cpol
+  }];
+}
+
+def ROCDL_TensorLoadToLdsD2 :
+  ROCDL_IntrOp<"tensor.load.to.lds.d2", [], [], [], 0, 0, 0, 0, [2], ["cpol"]>,
+  Arguments<(ins ROCDL_V4I32Type:$desc0,
+                 ROCDL_V8I32Type:$desc1,
+                 I32Attr:$cpol)>{
+  let description = [{
+      Loads up to 2D tensor data from global memory to LDS. Available on gfx12+.
+  }];
+  let assemblyFormat = [{
+    attr-dict operands `cachepolicy` $cpol
+  }];
+}
+
+def ROCDL_TensorStoreFromLds :
+  ROCDL_IntrOp<"tensor.store.from.lds", [], [], [], 0, 0, 0, 0, [4], ["cpol"]>,
+  Arguments<(ins ROCDL_V4I32Type:$desc0,
+                 ROCDL_V8I32Type:$desc1,
+                 ROCDL_V4I32Type:$desc2,
+                 ROCDL_V4I32Type:$desc3,
+                 I32Attr:$cpol)>{
+  let description = [{
+      Stores up to 5D tensor data from LDS to global memory. Available on gfx12+.
+  }];
+  let assemblyFormat = [{
+    attr-dict operands `cachepolicy` $cpol
+  }];
+}
+
+def ROCDL_TensorStoreFromLdsD2 :
+  ROCDL_IntrOp<"tensor.store.from.lds.d2", [], [], [], 0, 0, 0, 0, [2], ["cpol"]>,
+  Arguments<(ins ROCDL_V4I32Type:$desc0,
+                 ROCDL_V8I32Type:$desc1,
+                 I32Attr:$cpol)>{
+  let description = [{
+      Stores up to 2D tensor data from LDS to global memory. Available on gfx12+.
+  }];
+  let assemblyFormat = [{
+    attr-dict operands `cachepolicy` $cpol
+  }];
+}
+
 //===---------------------------------------------------------------------===//
 // MI-100 and MI-200 buffer atomic floating point add intrinsic
 
@@ -932,30 +1020,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
   }];
 }
 
-class ROCDL_ConcreteVector<Type elem, int length> :
-  FixedVectorOfLengthAndType<[length], [elem]>,
-  BuildableType<
-    "::mlir::VectorType::get({" # length # "} ,"
-      # elem.builderCall # ")">;
-
-def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
-def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
-def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
-def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
-def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
-def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
-def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
-def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
-def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
-def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
-def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
-def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
-def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
-def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
-def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
-def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
-def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
-
 //===---------------------------------------------------------------------===//
 // 16-bit float intrinsics
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index d270ee8b089aa..0de5f38071791 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -776,6 +776,26 @@ llvm.func @rocdl.raw.buffer.i32(%rsrc : vector<4xi32>,
   llvm.return
 }
 
+llvm.func @rocdl.tensor.load.store.ops(
+        %desc0 : vector<4xi32>,
+        %desc1 : vector<8xi32>,
+        %desc2 : vector<4xi32>,
+        %desc3 : vector<4xi32>) {
+  // CHECK-LABEL: @rocdl.tensor.load.store.ops(
+  // CHECK-SAME: %[[DESC0:.*]]: vector<4xi32>, %[[DESC1:.*]]: vector<8xi32>, %[[DESC2:.*]]: vector<4xi32>, %[[DESC3:.*]]: vector<4xi32>)
+  // CHECK: rocdl.tensor.load.to.lds %[[DESC0]], %[[DESC1]], %[[DESC2]], %[[DESC3]] cachepolicy 0
+  // CHECK: rocdl.tensor.load.to.lds.d2 %[[DESC0]], %[[DESC1]] cachepolicy 1
+  // CHECK: rocdl.tensor.store.from.lds %[[DESC0]], %[[DESC1]], %[[DESC2]], %[[DESC3]] cachepolicy 2
+  // CHECK: rocdl.tensor.store.from.lds.d2 %[[DESC0]], %[[DESC1]] cachepolicy 3
+  // Distinct cache policies check that each op round-trips its own immediate.
+  rocdl.tensor.load.to.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 0
+  rocdl.tensor.load.to.lds.d2 %desc0, %desc1 cachepolicy 1
+
+  rocdl.tensor.store.from.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 2
+  rocdl.tensor.store.from.lds.d2 %desc0, %desc1 cachepolicy 3
+  llvm.return
+}
+
 llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %stoch: i32) -> i32 {
 // CHECK-LABEL: @rocdl_8bit_floats
 // CHECK: rocdl.cvt.f32.bf8
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 30126f6bff05a..eac58929795db 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1250,6 +1250,26 @@ llvm.func @rocdl.raw.buffer.atomic.cmpswap(%rsrc : vector<4xi32>,
   llvm.return %val : i32
 }
 
+llvm.func @rocdl.tensor.load.store.ops(
+        %desc0 : vector<4xi32>,
+        %desc1 : vector<8xi32>,
+        %desc2 : vector<4xi32>,
+        %desc3 : vector<4xi32>) {
+  // CHECK-LABEL: @rocdl.tensor.load.store.ops(
+  // CHECK-SAME: <4 x i32> %[[DESC0:.*]], <8 x i32> %[[DESC1:.*]], <4 x i32> %[[DESC2:.*]], <4 x i32> %[[DESC3:.*]])
+  // CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], <4 x i32> %[[DESC2]], <4 x i32> %[[DESC3]], i32 0)
+  // CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], i32 1)
+  // CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], <4 x i32> %[[DESC2]], <4 x i32> %[[DESC3]], i32 2)
+  // CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %[[DESC0]], <8 x i32> %[[DESC1]], i32 3)
+  // CHECK: ret void
+  // Distinct cache policies check that each call receives its own immediate.
+  rocdl.tensor.load.to.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 0
+  rocdl.tensor.load.to.lds.d2 %desc0, %desc1 cachepolicy 1
+  rocdl.tensor.store.from.lds %desc0, %desc1, %desc2, %desc3 cachepolicy 2
+  rocdl.tensor.store.from.lds.d2 %desc0, %desc1 cachepolicy 3
+  llvm.return
+}
+
 llvm.func @rocdl_8bit_floats(%source: i32, %source_half: f16, %source_bfloat: bf16, %source_packed: vector<2xf16>, %stoch: i32) -> i32 {
 // CHECK-LABEL: @rocdl_8bit_floats
 // CHECK: call float @llvm.amdgcn.cvt.f32.bf8(i32 %{{.+}}, i32 0)

``````````

</details>


https://github.com/llvm/llvm-project/pull/165390


More information about the Mlir-commits mailing list