[Mlir-commits] [mlir] 303fba0 - [mlir][ROCDL] Add tensor load and store instructions to ROCDL (#165016)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Tue Oct 28 13:05:35 PDT 2025
    
    
  
Author: Justin Rosner
Date: 2025-10-28T20:05:30Z
New Revision: 303fba0214c9bf3e7d892417b788a012ec6427b4
URL: https://github.com/llvm/llvm-project/commit/303fba0214c9bf3e7d892417b788a012ec6427b4
DIFF: https://github.com/llvm/llvm-project/commit/303fba0214c9bf3e7d892417b788a012ec6427b4.diff
LOG: [mlir][ROCDL] Add tensor load and store instructions to ROCDL (#165016)
Add support for `tensor.load.to.lds` and `tensor.store.from.lds`
instructions in ROCDL.
Added: 
    
Modified: 
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/test/Dialect/LLVMIR/rocdl.mlir
    mlir/test/Target/LLVMIR/rocdl.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index d2df244eb9363..5241f9a6f2b43 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -146,6 +146,35 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ROCDL vector types definitions
+//===----------------------------------------------------------------------===//
+
+class ROCDL_ConcreteVector<Type elem, int length> :
+  FixedVectorOfLengthAndType<[length], [elem]>,
+  BuildableType<
+    "::mlir::VectorType::get({" # length # "} ,"
+      # elem.builderCall # ")">;
+
+def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
+def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
+def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
+def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
+def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
+def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
+def ROCDL_V4I32Type : ROCDL_ConcreteVector<I32, 4>;
+def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
+def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
+def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
+def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
+def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
+def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
+def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
+def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
+def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
+def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
+def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
+
 //===----------------------------------------------------------------------===//
 // Wave-level primitives
 //===----------------------------------------------------------------------===//
@@ -663,6 +692,68 @@ def ROCDL_GlobalLoadLDSOp :
   }];
 }
 
+//===---------------------------------------------------------------------===//
+// Tensor load/store intrinsics (available in GFX1250)
+//===---------------------------------------------------------------------===//
+
+// Base class for tensor load/store operations with 4 descriptor groups.
+class ROCDL_TensorLDSIntrOp<string mnemonic> :
+  ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [4], ["cachePolicy"]> {
+  dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
+                  ROCDL_V4I32Type:$dgroup2, ROCDL_V4I32Type:$dgroup3,
+                  I32Attr:$cachePolicy);
+  let arguments = !con(args, baseArgs);
+  let summary = "Move a tensor tile between global memory and LDS.";
+  let description = [{
+    Moves tiles of tensor data between global memory and LDS. The tile is
+    described by the $dgroup descriptors. Four $dgroup descriptors allow for
+    movement of up to 5D tensors. $cachePolicy describes the memory scope and an
+    indicator of expected data re-use.
+
+    This op is for gfx1250+ architectures.
+  }];
+  let assemblyFormat = [{
+    attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
+  }];
+  let extraClassDefinition = [{
+    SmallVector<Value> $cppClass::getAccessedOperands() {
+      return {getDgroup0(), getDgroup1(), getDgroup2(), getDgroup3()};
+    }
+  }];
+}
+
+// Base class for tensor load/store operations with 2 descriptor groups
+// (D2 variant).
+class ROCDL_TensorLDSIntrD2Op<string mnemonic> :
+  ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [2], ["cachePolicy"]> {
+  dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
+                  I32Attr:$cachePolicy);
+  let arguments = !con(args, baseArgs);
+  let summary = "Move a tensor tile between global memory and LDS (D2 variant).";
+  let description = [{
+    Moves tiles of tensor data between global memory and LDS. The tile is
+    described by the $dgroup descriptors. Two $dgroup descriptors allow for
+    movement of up to 2D tensors. $cachePolicy describes the memory scope and an
+    indicator of expected data re-use.
+
+    This op is for gfx1250+ architectures.
+  }];
+  let assemblyFormat = [{
+    attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
+  }];
+  let extraClassDefinition = [{
+    SmallVector<Value> $cppClass::getAccessedOperands() {
+      return {getDgroup0(), getDgroup1()};
+    }
+  }];
+}
+
+// Tensor load and store operations
+def ROCDL_TensorLoadToLDSOp : ROCDL_TensorLDSIntrOp<"tensor.load.to.lds">;
+def ROCDL_TensorStoreFromLDSOp : ROCDL_TensorLDSIntrOp<"tensor.store.from.lds">;
+def ROCDL_TensorLoadToLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.load.to.lds.d2">;
+def ROCDL_TensorStoreFromLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.store.from.lds.d2">;
+
 //===---------------------------------------------------------------------===//
 // Operations on raw buffer resources (stride of 0, bounds checks either off or in
 // raw buffer mode).
@@ -932,30 +1023,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
   }];
 }
 
-class ROCDL_ConcreteVector<Type elem, int length> :
-  FixedVectorOfLengthAndType<[length], [elem]>,
-  BuildableType<
-    "::mlir::VectorType::get({" # length # "} ,"
-      # elem.builderCall # ")">;
-
-def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
-def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
-def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
-def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
-def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
-def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
-def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
-def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
-def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
-def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
-def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
-def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
-def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
-def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
-def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
-def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
-def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
-
 //===---------------------------------------------------------------------===//
 // 16-bit float intrinsics
 //===---------------------------------------------------------------------===//
diff  --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index d270ee8b089aa..e703600c71c8e 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -664,6 +664,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
   llvm.return
 }
 
+// CHECK-LABEL: @rocdl.tensor.load.to.lds
+llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+                                    %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+  // CHECK: rocdl.tensor.load.to.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  llvm.return
+}
+
+// CHECK-LABEL: @rocdl.tensor.store.from.lds
+llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+                                       %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+  // CHECK: rocdl.tensor.store.from.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  llvm.return
+}
+
+// CHECK-LABEL: @rocdl.tensor.load.to.lds.d2
+llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+  // CHECK: rocdl.tensor.load.to.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  llvm.return
+}
+
+// CHECK-LABEL: @rocdl.tensor.store.from.lds.d2
+llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+  // CHECK: rocdl.tensor.store.from.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  llvm.return
+}
+
 llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
                                   %stride : i16,
                                   %numRecords : i64,
diff  --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 30126f6bff05a..8a848221a50dd 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1040,6 +1040,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
   llvm.return
 }
 
+// CHECK-LABEL: rocdl.tensor.load.to.lds
+llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+                                    %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+  // CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
+  rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  llvm.return
+}
+
+// CHECK-LABEL: rocdl.tensor.store.from.lds
+llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+                                       %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+  // CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
+  rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  llvm.return
+}
+
+// CHECK-LABEL: rocdl.tensor.load.to.lds.d2
+llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+  // CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
+  rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  llvm.return
+}
+
+// CHECK-LABEL: rocdl.tensor.store.from.lds.d2
+llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+  // CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
+  rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+  llvm.return
+}
+
 llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
                                   %stride : i16,
                                   %numRecords : i64,
        
    
    
More information about the Mlir-commits
mailing list