[Mlir-commits] [mlir] Allow 16 bit floating point operand for LLVM_AtomicRMWOp (PR #110553)
Ilya V
llvmlistbot at llvm.org
Mon Sep 30 11:54:16 PDT 2024
https://github.com/joviliast created https://github.com/llvm/llvm-project/pull/110553
Since the AMDGPU target supports vectorized atomicrmw operations, allow LLVM_AtomicRMWOp to be constructed with 16-bit floating-point values. This patch enables building LLVM_AtomicRMWOp with fixed-length vectors of 16-bit floating-point values as operands (currently only for the fadd operation).
See also: #94845, #95393, #95394
>From a169d00acf243f3f2b01997ddec414c2dd52f182 Mon Sep 17 00:00:00 2001
From: Ilya Veselov <iveselov.nn at gmail.com>
Date: Wed, 25 Sep 2024 19:38:29 +0200
Subject: [PATCH] Allow 16 bit floating point operand for LLVM_AtomicRMWOp
Since the AMDGPU target supports vectorized atomicrmw operations, allow
LLVM_AtomicRMWOp to be constructed with 16-bit floating-point values.
This patch enables building LLVM_AtomicRMWOp with fixed-length vectors
of 16-bit floating-point values as operands (currently only for fadd).
See also: #94845, #95393, #95394
Signed-off-by: Ilya Veselov <iveselov.nn at gmail.com>
---
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 3 ++-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 16 +++++++++++++---
mlir/test/Dialect/LLVMIR/invalid.mlir | 16 ++++++++++++++++
3 files changed, 31 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 030160821bd823..615c0a39f3acd0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1737,7 +1737,8 @@ def LLVM_ConstantOp
// Atomic operations.
//
-def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>;
+def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyPointer, AnySignlessInteger,
+                                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>]>;
def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
TypesMatchWith<"result #0 and operand #1 have the same type",
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0561c364c7d591..99b3dc79fda664 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3008,9 +3008,19 @@ void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
LogicalResult AtomicRMWOp::verify() {
auto valType = getVal().getType();
- if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
- getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
- if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
+ if (getBinOp() == AtomicBinOp::fadd && isCompatibleVectorType(valType)) {
+ // Currently, only fadd operation supports fixed vector operands.
+ if (isScalableVectorType(valType))
+ return emitOpError("expected LLVM IR fixed vector type");
+ Type elemType = getVectorElementType(valType);
+ if (!(isCompatibleFloatingPointType(elemType) &&
+ elemType.getIntOrFloatBitWidth() == 16))
+ return emitOpError("unexpected LLVM IR type for vector element");
+ } else if (getBinOp() == AtomicBinOp::fadd ||
+ getBinOp() == AtomicBinOp::fsub ||
+ getBinOp() == AtomicBinOp::fmin ||
+ getBinOp() == AtomicBinOp::fmax) {
+ if (!isCompatibleFloatingPointType(valType))
return emitOpError("expected LLVM IR floating point type");
} else if (getBinOp() == AtomicBinOp::xchg) {
DataLayout dataLayout = DataLayout::closest(*this);
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 9388d7ef24936e..978572a2b3cca2 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -643,6 +643,22 @@ func.func @atomicrmw_expected_float(%i32_ptr : !llvm.ptr, %i32 : i32) {
// -----
+func.func @atomicrmw_unexpected_scalable_vector(%ptr : !llvm.ptr, %f16_vec : vector<[3]xf16>) {
+  // expected-error@+1 {{expected LLVM IR fixed vector type}}
+  %0 = llvm.atomicrmw fadd %ptr, %f16_vec unordered : !llvm.ptr, vector<[3]xf16>
+  llvm.return
+}
+
+// -----
+
+func.func @atomicrmw_unexpected_vector_element(%ptr : !llvm.ptr, %f32_vec : vector<3xf32>) {
+  // expected-error@+1 {{unexpected LLVM IR type for vector element}}
+  %0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<3xf32>
+  llvm.return
+}
+
+// -----
+
func.func @atomicrmw_unexpected_xchg_type(%i1_ptr : !llvm.ptr, %i1 : i1) {
// expected-error at +1 {{unexpected LLVM IR type for 'xchg' bin_op}}
%0 = llvm.atomicrmw xchg %i1_ptr, %i1 unordered : !llvm.ptr, i1
More information about the Mlir-commits
mailing list