[Mlir-commits] [mlir] [ROCDL] Add the global.atomic.fadd intrinsic in ROCDL (PR #94486)

Giuseppe Rossini llvmlistbot at llvm.org
Wed Jun 5 08:17:06 PDT 2024


https://github.com/giuseros created https://github.com/llvm/llvm-project/pull/94486

This PR adds a `rocdl.global.atomic.fadd` op to the ROCDL dialect, which lowers to the `llvm.amdgcn.global.atomic.fadd` intrinsic (supported data types: `f32` and `vector<2xf16>`).

>From a35b0535ce7a9002d4738d3a07d21fdd7179f25a Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Wed, 5 Jun 2024 15:14:43 +0000
Subject: [PATCH] [ROCDL] Add the global.atomic.fadd intrinsic in ROCDL

---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 17 +++++++++++++++--
 mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp  | 20 ++++++++++++++++++++
 mlir/test/Target/LLVMIR/rocdl.mlir           |  9 +++++++++
 3 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 1dabf5d7979b7..c8d4e4c03486e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
   let summary = "Vote across thread group";
 
   let description = [{
-      Ballot provides a bit mask containing the 1-bit predicate value from each lane. 
+      Ballot provides a bit mask containing the 1-bit predicate value from each lane.
       The nth bit of the result contains the 1 bit contributed by the nth warp lane.
   }];
 
@@ -516,7 +516,7 @@ def ROCDL_RawBufferAtomicCmpSwap :
 }
 
 //===---------------------------------------------------------------------===//
-// MI-100 and MI-200 buffer atomic floating point add intrinsic
+// MI-100, MI-200 and MI-300 global/buffer atomic floating-point add intrinsics
 
 def ROCDL_RawBufferAtomicFAddOp :
   ROCDL_Op<"raw.buffer.atomic.fadd">,
@@ -534,6 +534,19 @@ def ROCDL_RawBufferAtomicFAddOp :
   let hasCustomAssemblyFormat = 1;
 }
 
+def ROCDL_GlobalAtomicFAddOp :
+  ROCDL_Op<"global.atomic.fadd">,
+  Arguments<(ins LLVM_Type:$ptr, LLVM_Type:$vdata)> {
+  let summary = "Atomic floating-point add via llvm.amdgcn.global.atomic.fadd";
+  string llvmBuilder = [{
+      auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
+      auto ptrType = moduleTranslation.convertType(op.getPtr().getType());
+      createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_global_atomic_fadd,
+          {$ptr, $vdata}, {vdataType, ptrType, vdataType});
+  }];
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===---------------------------------------------------------------------===//
 // Buffer atomic floating point max intrinsic. GFX9 does not support fp32.
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 65b770ae32610..34ebdb2ffd3d0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -157,6 +157,26 @@ void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
   p << " " << getOperands() << " : " << getVdata().getType();
 }
 
+// <operation> ::=
+//     `llvm.amdgcn.global.atomic.fadd.* %ptr, %vdata : vdata_type`
+ParseResult GlobalAtomicFAddOp::parse(OpAsmParser &parser,
+                                      OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
+  Type type;
+  if (parser.parseOperandList(ops, 2) || parser.parseColonType(type))
+    return failure();
+
+  // Only the data type is spelled; the pointer is always a plain `!llvm.ptr`.
+  auto ptrType = LLVM::LLVMPointerType::get(parser.getContext());
+  return parser.resolveOperands(ops, {ptrType, type}, parser.getNameLoc(),
+                                result.operands);
+}
+
+// Prints `%ptr, %vdata : vdata_type`, the inverse of GlobalAtomicFAddOp::parse.
+void GlobalAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
+  p << " " << getOperands() << " : " << getVdata().getType();
+}
+
 // <operation> ::=
 //     `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc,  %offset,
 //     %soffset, %aux : result_type`
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index ce6b56d48437a..9d22b80748e14 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -494,6 +494,15 @@ llvm.func @rocdl.raw.buffer.atomic.f32(%rsrc : vector<4xi32>,
   llvm.return
 }
 
+// CHECK-LABEL: rocdl.global.atomic
+llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr) {
+  // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p0.f32(ptr %{{.*}}, float %{{.*}})
+  rocdl.global.atomic.fadd %ptr, %vdata0 : f32
+  // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p0.v2f16(ptr %{{.*}}, <2 x half> %{{.*}})
+  rocdl.global.atomic.fadd %ptr, %vdata1 : vector<2xf16>
+  llvm.return
+}
+
 llvm.func @rocdl.raw.buffer.atomic.i32(%rsrc : vector<4xi32>,
                         %offset : i32, %soffset : i32,
                         %vdata1 : i32) {



More information about the Mlir-commits mailing list