[Mlir-commits] [mlir] [ROCDL] Add the global.atomic.fadd intrinsic in ROCDL (PR #94486)
Giuseppe Rossini
llvmlistbot at llvm.org
Wed Jun 5 09:01:57 PDT 2024
https://github.com/giuseros updated https://github.com/llvm/llvm-project/pull/94486
>From a35b0535ce7a9002d4738d3a07d21fdd7179f25a Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Wed, 5 Jun 2024 15:14:43 +0000
Subject: [PATCH 1/2] [ROCDL] Add the global.atomic.fadd intrinsic in ROCDL
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 17 +++++++++++++++--
mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 20 ++++++++++++++++++++
mlir/test/Target/LLVMIR/rocdl.mlir | 9 +++++++++
3 files changed, 44 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 1dabf5d7979b7..c8d4e4c03486e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -165,7 +165,7 @@ def ROCDL_BallotOp :
let summary = "Vote across thread group";
let description = [{
- Ballot provides a bit mask containing the 1-bit predicate value from each lane.
+ Ballot provides a bit mask containing the 1-bit predicate value from each lane.
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];
@@ -516,7 +516,7 @@ def ROCDL_RawBufferAtomicCmpSwap :
}
//===---------------------------------------------------------------------===//
-// MI-100 and MI-200 buffer atomic floating point add intrinsic
+// MI-100, MI-200 and MI-300 global/buffer atomic floating point add intrinsic
def ROCDL_RawBufferAtomicFAddOp :
ROCDL_Op<"raw.buffer.atomic.fadd">,
@@ -534,6 +534,19 @@ def ROCDL_RawBufferAtomicFAddOp :
let hasCustomAssemblyFormat = 1;
}
+def ROCDL_GlobalAtomicFAddOp :
+ ROCDL_Op<"global.atomic.fadd">,
+ Arguments<(ins LLVM_Type:$ptr,
+ LLVM_Type:$vdata)>{
+ string llvmBuilder = [{
+ auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
+ auto ptrType = moduleTranslation.convertType(op.getPtr().getType());
+ createIntrinsicCall(builder,
+ llvm::Intrinsic::amdgcn_global_atomic_fadd, {$ptr, $vdata}, {vdataType, ptrType, vdataType});
+ }];
+ let hasCustomAssemblyFormat = 1;
+}
+
//===---------------------------------------------------------------------===//
// Buffer atomic floating point max intrinsic. GFX9 does not support fp32.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 65b770ae32610..34ebdb2ffd3d0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -157,6 +157,26 @@ void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
p << " " << getOperands() << " : " << getVdata().getType();
}
+// <operation> ::=
+// `llvm.amdgcn.global.atomic.fadd.* %vdata, %ptr
+ParseResult GlobalAtomicFAddOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
+ Type type;
+ if (parser.parseOperandList(ops, 2) || parser.parseColonType(type))
+ return failure();
+
+ auto ptrType = LLVM::LLVMPointerType::get(parser.getContext());
+ if (parser.resolveOperands(ops, {ptrType, type}, parser.getNameLoc(),
+ result.operands))
+ return failure();
+ return success();
+}
+
+void GlobalAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
+ p << " " << getOperands() << " : " << getVdata().getType();
+}
+
// <operation> ::=
// `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc, %offset,
// %soffset, %aux : result_type`
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index ce6b56d48437a..9d22b80748e14 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -494,6 +494,15 @@ llvm.func @rocdl.raw.buffer.atomic.f32(%rsrc : vector<4xi32>,
llvm.return
}
+// CHECK-LABEL: rocdl.global.atomic
+llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr) {
+ // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p0.f32(ptr %{{.*}}, float %{{.*}}
+ rocdl.global.atomic.fadd %ptr, %vdata0: f32
+ // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p0.v2f16(ptr %{{.*}}, <2 x half> %{{.*}})
+ rocdl.global.atomic.fadd %ptr, %vdata1: vector<2xf16>
+ llvm.return
+}
+
llvm.func @rocdl.raw.buffer.atomic.i32(%rsrc : vector<4xi32>,
%offset : i32, %soffset : i32,
%vdata1 : i32) {
>From 9d9cc3d4961a00e418c6fe3640b16ad473181727 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini <giuseppe.rossini at amd.com>
Date: Wed, 5 Jun 2024 16:00:51 +0000
Subject: [PATCH 2/2] Address review feedback
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 7 ++++---
mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 20 --------------------
mlir/test/Target/LLVMIR/rocdl.mlir | 6 +++---
3 files changed, 7 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index c8d4e4c03486e..deadd6caeb7e2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -342,6 +342,7 @@ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]
//===---------------------------------------------------------------------===//
def ROCDLBufferRsrc : LLVM_PointerInAddressSpace<8>;
+def ROCDLGlobalPtr: LLVM_PointerInAddressSpace<1>;
def ROCDL_MakeBufferRsrcOp :
ROCDL_IntrOp<"make.buffer.rsrc", [], [0], [Pure], 1>,
@@ -516,7 +517,7 @@ def ROCDL_RawBufferAtomicCmpSwap :
}
//===---------------------------------------------------------------------===//
-// MI-100, MI-200 and MI-300 global/buffer atomic floating point add intrinsic
+// gfx9x global/buffer atomic floating point add intrinsics
def ROCDL_RawBufferAtomicFAddOp :
ROCDL_Op<"raw.buffer.atomic.fadd">,
@@ -536,7 +537,7 @@ def ROCDL_RawBufferAtomicFAddOp :
def ROCDL_GlobalAtomicFAddOp :
ROCDL_Op<"global.atomic.fadd">,
- Arguments<(ins LLVM_Type:$ptr,
+ Arguments<(ins ROCDLGlobalPtr:$ptr,
LLVM_Type:$vdata)>{
string llvmBuilder = [{
auto vdataType = moduleTranslation.convertType(op.getVdata().getType());
@@ -544,7 +545,7 @@ def ROCDL_GlobalAtomicFAddOp :
createIntrinsicCall(builder,
llvm::Intrinsic::amdgcn_global_atomic_fadd, {$ptr, $vdata}, {vdataType, ptrType, vdataType});
}];
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "operands attr-dict `:` type($vdata)";
}
//===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 34ebdb2ffd3d0..65b770ae32610 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -157,26 +157,6 @@ void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
p << " " << getOperands() << " : " << getVdata().getType();
}
-// <operation> ::=
-// `llvm.amdgcn.global.atomic.fadd.* %vdata, %ptr
-ParseResult GlobalAtomicFAddOp::parse(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
- Type type;
- if (parser.parseOperandList(ops, 2) || parser.parseColonType(type))
- return failure();
-
- auto ptrType = LLVM::LLVMPointerType::get(parser.getContext());
- if (parser.resolveOperands(ops, {ptrType, type}, parser.getNameLoc(),
- result.operands))
- return failure();
- return success();
-}
-
-void GlobalAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
- p << " " << getOperands() << " : " << getVdata().getType();
-}
-
// <operation> ::=
// `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc, %offset,
// %soffset, %aux : result_type`
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 9d22b80748e14..c940d01a0a614 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -495,10 +495,10 @@ llvm.func @rocdl.raw.buffer.atomic.f32(%rsrc : vector<4xi32>,
}
// CHECK-LABEL: rocdl.global.atomic
-llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr) {
- // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p0.f32(ptr %{{.*}}, float %{{.*}}
+llvm.func @rocdl.global.atomic(%vdata0 : f32, %vdata1 : vector<2xf16>, %ptr : !llvm.ptr<1>) {
+ // CHECK: call float @llvm.amdgcn.global.atomic.fadd.f32.p1.f32(ptr addrspace(1) %{{.*}}, float %{{.*}}
rocdl.global.atomic.fadd %ptr, %vdata0: f32
- // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p0.v2f16(ptr %{{.*}}, <2 x half> %{{.*}})
+ // CHECK: call <2 x half> @llvm.amdgcn.global.atomic.fadd.v2f16.p1.v2f16(ptr addrspace(1) %{{.*}}, <2 x half> %{{.*}})
rocdl.global.atomic.fadd %ptr, %vdata1: vector<2xf16>
llvm.return
}
More information about the Mlir-commits
mailing list