[Mlir-commits] [mlir] [ROCDL] Add the global.atomic.fadd intrinsic in ROCDL (PR #94486)
Giuseppe Rossini
llvmlistbot at llvm.org
Wed Jun 5 08:17:06 PDT 2024
https://github.com/giuseros created https://github.com/llvm/llvm-project/pull/94486
This PR adds the `global.atomic.fadd` intrinsic to ROCDL. It supports `f32` and `vector<2xf16>` operands.
>From a35b0535ce7a9002d4738d3a07d21fdd7179f25a Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Wed, 5 Jun 2024 15:14:43 +0000
Subject: [PATCH] [ROCDL] Add the global.atomic.fadd intrinsic in ROCDL
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 17 +++++++++++++++--
mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 20 ++++++++++++++++++++
mlir/test/Target/LLVMIR/rocdl.mlir | 9 +++++++++
3 files changed, 44 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 1dabf5d7979b7..c8d4e4c03486e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
let summary = "Vote across thread group";
let description = [{
- Ballot provides a bit mask containing the 1-bit predicate value from each lane.
+ Ballot provides a bit mask containing the 1-bit predicate value from each lane.
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];
@@ -516,7 +516,7 @@ def ROCDL_RawBufferAtomicCmpSwap :
}
//===---------------------------------------------------------------------===//
-// MI-100 and MI-200 buffer atomic floating point add intrinsic
+// MI-100, MI-200 and MI-300 global/buffer atomic floating point add intrinsics
def ROCDL_RawBufferAtomicFAddOp :
ROCDL_Op<"raw.buffer.atomic.fadd">,
@@ -534,6 +534,19 @@ def ROCDL_RawBufferAtomicFAddOp :
let hasCustomAssemblyFormat = 1;
}
+def ROCDL_GlobalAtomicFAddOp : // Translates to llvm.amdgcn.global.atomic.fadd; the op declares no results.
+ ROCDL_Op<"global.atomic.fadd">,
+ Arguments<(ins LLVM_Type:$ptr,
+ LLVM_Type:$vdata)>{
+ string llvmBuilder = [{
+ auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
+ auto ptrType = moduleTranslation.convertType(op.getPtr().getType());
+ createIntrinsicCall(builder,
+ llvm::Intrinsic::amdgcn_global_atomic_fadd, {$ptr, $vdata}, {vdataType, ptrType, vdataType});
+ }];
+ let hasCustomAssemblyFormat = 1; // parse/print are hand-written in ROCDLDialect.cpp
+}
+
//===---------------------------------------------------------------------===//
// Buffer atomic floating point max intrinsic. GFX9 does not support fp32.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 65b770ae32610..34ebdb2ffd3d0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -157,6 +157,26 @@ void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
p << " " << getOperands() << " : " << getVdata().getType();
}
+// <operation> ::=
+// `llvm.amdgcn.global.atomic.fadd.* %ptr, %vdata : vdata_type`
+ParseResult GlobalAtomicFAddOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
+ Type type;
+ if (parser.parseOperandList(ops, 2) || parser.parseColonType(type))
+ return failure();
+ // Only the vdata type is spelled in the assembly; the pointer operand is always resolved against the pointer type built from the context alone. NOTE(review): pointers in a non-default address space presumably fail to resolve here - confirm.
+ auto ptrType = LLVM::LLVMPointerType::get(parser.getContext());
+ if (parser.resolveOperands(ops, {ptrType, type}, parser.getNameLoc(),
+ result.operands))
+ return failure();
+ return success();
+}
+
+void GlobalAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
+ p << " " << getOperands() << " : " << getVdata().getType(); // all operands, then only the vdata type (mirrors what parse reads)
+}
+
// <operation> ::=
// `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc, %offset,
// %soffset, %aux : result_type`
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index ce6b56d48437a..9d22b80748e14 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -494,6 +494,15 @@ llvm.func @rocdl.raw.buffer.atomic.f32(%rsrc : vector<4xi32>,
llvm.return
}
+// CHECK-LABEL: rocdl.global.atomic
+llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr) {
+ // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p0.f32(ptr %{{.*}}, float %{{.*}})
+ rocdl.global.atomic.fadd %ptr, %vdata0: f32
+ // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p0.v2f16(ptr %{{.*}}, <2 x half> %{{.*}})
+ rocdl.global.atomic.fadd %ptr, %vdata1: vector<2xf16>
+ llvm.return
+}
+
llvm.func @rocdl.raw.buffer.atomic.i32(%rsrc : vector<4xi32>,
%offset : i32, %soffset : i32,
%vdata1 : i32) {
More information about the Mlir-commits
mailing list