[Mlir-commits] [mlir] Allow 16 bit floating point operand for LLVM_AtomicRMWOp (PR #110553)
Ilya V
llvmlistbot at llvm.org
Wed Oct 2 02:43:22 PDT 2024
https://github.com/joviliast updated https://github.com/llvm/llvm-project/pull/110553
>From f8c56ecbb8523e617a1b662bc05c5f03ee3b9d9b Mon Sep 17 00:00:00 2001
From: Ilya Veselov <iveselov.nn at gmail.com>
Date: Wed, 25 Sep 2024 19:38:29 +0200
Subject: [PATCH] Allow 16 bit floating point operand for LLVM_AtomicRMWOp
Since the AMDGPU target supports vectorized atomicrmw operations,
allow constructing LLVM_AtomicRMWOp with vectors of 16-bit floating-point values.
This patch enables building an `fadd` LLVM_AtomicRMWOp whose operand is a
fixed vector of 16-bit floating-point values.
See also: #94845, #95393, #95394
Signed-off-by: Ilya Veselov <iveselov.nn at gmail.com>
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 10 ++++++++++
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 3 ++-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 14 +++++++++++---
mlir/test/Dialect/LLVMIR/invalid.mlir | 8 ++++++++
mlir/test/Dialect/LLVMIR/roundtrip.mlir | 4 +++-
mlir/test/Target/LLVMIR/llvmir.mlir | 7 +++++--
6 files changed, 39 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index c3d352d8d0dd48..fa16f098cc6a2f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -139,6 +139,16 @@ class LLVM_VectorOf<Type element> : Type<
class LLVM_ScalarOrVectorOf<Type element> :
AnyTypeOf<[element, LLVM_VectorOf<element>]>;
+// Type constraint accepting an LLVM dialect-compatible fixed vector type whose
+// element type (as returned by getVectorElementType) satisfies `element`.
+class LLVM_FixedVectorOf<Type element> : Type<
+ And<[LLVM_AnyFixedVector.predicate,
+ SubstLeaves<
+ "$_self",
+ "::mlir::LLVM::getVectorElementType($_self)",
+ element.predicate>]>,
+ "LLVM dialect-compatible fixed vector of " # element.summary>;
+
// Base class for LLVM operations. Defines the interface to the llvm::IRBuilder
// used to translate to proper LLVM IR and the interface to the mlir::OpBuilder
// used to import from LLVM IR.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 030160821bd823..beb8723f3bcd3b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1737,7 +1737,8 @@ def LLVM_ConstantOp
// Atomic operations.
//
-def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>;
+def LLVM_AtomicRMWType // FP, pointer, signless int, or fixed vector of FP; vectors are narrowed further in AtomicRMWOp::verify().
+ : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger, LLVM_FixedVectorOf<LLVM_AnyFloat>]>;
def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
TypesMatchWith<"result #0 and operand #1 have the same type",
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0561c364c7d591..ccdc4c79d7d189 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3008,9 +3008,17 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
LogicalResult AtomicRMWOp::verify() {
auto valType = getVal().getType();
- if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
- getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
- if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
+ if (getBinOp() == AtomicBinOp::fadd && isCompatibleVectorType(valType)) {
+ Type elemType = getVectorElementType(valType);
+ // Only 16 bit floating point elements are supported for now.
+ if (!(isCompatibleFloatingPointType(elemType) &&
+ elemType.getIntOrFloatBitWidth() == 16))
+ return emitOpError("unexpected LLVM IR type for vector element");
+ } else if (getBinOp() == AtomicBinOp::fadd ||
+ getBinOp() == AtomicBinOp::fsub ||
+ getBinOp() == AtomicBinOp::fmin ||
+ getBinOp() == AtomicBinOp::fmax) {
+ if (!isCompatibleFloatingPointType(valType))
return emitOpError("expected LLVM IR floating point type");
} else if (getBinOp() == AtomicBinOp::xchg) {
DataLayout dataLayout = DataLayout::closest(*this);
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 9388d7ef24936e..0e8c473fc3257d 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -643,6 +643,14 @@ func.func @atomicrmw_expected_float(%i32_ptr : !llvm.ptr, %i32 : i32) {
// -----
+func.func @atomicrmw_unexpected_vector_element(%ptr : !llvm.ptr, %f32_vec : vector<3xf32>) {
+ // expected-error@+1 {{unexpected LLVM IR type for vector element}}
+ %0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<3xf32>
+ llvm.return
+}
+
+// -----
+
func.func @atomicrmw_unexpected_xchg_type(%i1_ptr : !llvm.ptr, %i1 : i1) {
// expected-error at +1 {{unexpected LLVM IR type for 'xchg' bin_op}}
%0 = llvm.atomicrmw xchg %i1_ptr, %i1 unordered : !llvm.ptr, i1
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 62f1de2b7fe7d4..4a0647588b3353 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -420,11 +420,13 @@ func.func @atomic_store(%val : f32, %large_val : i256, %ptr : !llvm.ptr) {
}
// CHECK-LABEL: @atomicrmw
-func.func @atomicrmw(%ptr : !llvm.ptr, %val : f32) {
+func.func @atomicrmw(%ptr : !llvm.ptr, %val : f32, %f16_vec : vector<2xf16>) {
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr, f32
%0 = llvm.atomicrmw fadd %ptr, %val monotonic : !llvm.ptr, f32
// CHECK: llvm.atomicrmw volatile fsub %{{.*}}, %{{.*}} syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
%1 = llvm.atomicrmw volatile fsub %ptr, %val syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
+ // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr, vector<2xf16>
+ %2 = llvm.atomicrmw fadd %ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 007284d0ca4435..3791e5f2757fa0 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1496,7 +1496,8 @@ llvm.func @elements_constant_3d_array() -> !llvm.array<2 x array<2 x array<2 x i
// CHECK-LABEL: @atomicrmw
llvm.func @atomicrmw(
%f32_ptr : !llvm.ptr, %f32 : f32,
- %i32_ptr : !llvm.ptr, %i32 : i32) {
+ %i32_ptr : !llvm.ptr, %i32 : i32,
+ %f16_vec_ptr : !llvm.ptr, %f16_vec : vector<2xf16>) {
// CHECK: atomicrmw fadd ptr %{{.*}}, float %{{.*}} monotonic
%0 = llvm.atomicrmw fadd %f32_ptr, %f32 monotonic : !llvm.ptr, f32
// CHECK: atomicrmw fsub ptr %{{.*}}, float %{{.*}} monotonic
@@ -1535,11 +1536,13 @@ llvm.func @atomicrmw(
%17 = llvm.atomicrmw usub_cond %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw usub_sat ptr %{{.*}}, i32 %{{.*}} monotonic
%18 = llvm.atomicrmw usub_sat %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+ // CHECK: atomicrmw fadd ptr %{{.*}}, <2 x half> %{{.*}} monotonic
+ %19 = llvm.atomicrmw fadd %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
// CHECK: atomicrmw volatile
// CHECK-SAME: syncscope("singlethread")
// CHECK-SAME: align 8
- %19 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
+ %20 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
llvm.return
}
More information about the Mlir-commits
mailing list