[Mlir-commits] [mlir] Allow fixed vector operand for LLVM_AtomicRMWOp (PR #110553)

Ilya V llvmlistbot at llvm.org
Thu Oct 3 08:04:59 PDT 2024


https://github.com/joviliast updated https://github.com/llvm/llvm-project/pull/110553

>From cb11888e4c732e61e3adeb41f1ca0e2a708b8f01 Mon Sep 17 00:00:00 2001
From: Ilya Veselov <iveselov.nn at gmail.com>
Date: Wed, 25 Sep 2024 19:38:29 +0200
Subject: [PATCH] Allow fixed vector operand for `LLVM_AtomicRMWOp`

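This PR extends the allowed operand types and the verifier logic of
`LLVM_AtomicRMWOp` so that the op can be built with fixed-length vectors of
compatible floating-point values as operands of the floating-point RMW
operations (`fadd`, `fsub`, `fmin`, `fmax`). Scalable vectors are still rejected.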

See also the LangRef description of the `atomicrmw` instruction: https://llvm.org/docs/LangRef.html#id231

Signed-off-by: Ilya Veselov <iveselov.nn at gmail.com>
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td |  3 ++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp  | 10 +++++++++-
 mlir/test/Dialect/LLVMIR/invalid.mlir       | 15 +++++++++++++++
 mlir/test/Dialect/LLVMIR/roundtrip.mlir     |  8 +++++---
 mlir/test/Target/LLVMIR/llvmir.mlir         | 13 +++++++++++--
 5 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 030160821bd823..88e82ce48959b0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1737,7 +1737,8 @@ def LLVM_ConstantOp
 // Atomic operations.
 //
 
-def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>;
+def LLVM_AtomicRMWType
+    : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger, LLVM_AnyFixedVector]>;
 
 def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
       TypesMatchWith<"result #0 and operand #1 have the same type",
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0561c364c7d591..fb7024a14f8d4e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3010,8 +3010,16 @@ LogicalResult AtomicRMWOp::verify() {
   auto valType = getVal().getType();
   if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
       getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
-    if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
+    if (isCompatibleVectorType(valType)) {
+      if (isScalableVectorType(valType))
+        return emitOpError("expected LLVM IR fixed vector type");
+      Type elemType = getVectorElementType(valType);
+      if (!isCompatibleFloatingPointType(elemType))
+        return emitOpError(
+            "expected LLVM IR floating point type for vector element");
+    } else if (!isCompatibleFloatingPointType(valType)) {
       return emitOpError("expected LLVM IR floating point type");
+    }
   } else if (getBinOp() == AtomicBinOp::xchg) {
     DataLayout dataLayout = DataLayout::closest(*this);
     if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 9388d7ef24936e..5677d7ff41202f 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -643,6 +643,21 @@ func.func @atomicrmw_expected_float(%i32_ptr : !llvm.ptr, %i32 : i32) {
 
 // -----
 
+func.func @atomicrmw_scalable_vector(%ptr : !llvm.ptr, %f32_vec : vector<[2]xf32>) {
+  // expected-error@+1 {{'val' must be floating point LLVM type or LLVM pointer type or signless integer or LLVM dialect-compatible fixed-length vector type}}
+  %0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<[2]xf32>
+  llvm.return
+}
+// -----
+
+func.func @atomicrmw_vector_expected_float(%ptr : !llvm.ptr, %i32_vec : vector<3xi32>) {
+  // expected-error@+1 {{expected LLVM IR floating point type for vector element}}
+  %0 = llvm.atomicrmw fadd %ptr, %i32_vec unordered : !llvm.ptr, vector<3xi32>
+  llvm.return
+}
+
+// -----
+
 func.func @atomicrmw_unexpected_xchg_type(%i1_ptr : !llvm.ptr, %i1 : i1) {
   // expected-error at +1 {{unexpected LLVM IR type for 'xchg' bin_op}}
   %0 = llvm.atomicrmw xchg %i1_ptr, %i1 unordered : !llvm.ptr, i1
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 62f1de2b7fe7d4..3062cdc38c0abb 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -420,11 +420,13 @@ func.func @atomic_store(%val : f32, %large_val : i256, %ptr : !llvm.ptr) {
 }
 
 // CHECK-LABEL: @atomicrmw
-func.func @atomicrmw(%ptr : !llvm.ptr, %val : f32) {
+func.func @atomicrmw(%ptr : !llvm.ptr, %f32 : f32, %f16_vec : vector<2xf16>) {
   // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr, f32
-  %0 = llvm.atomicrmw fadd %ptr, %val monotonic : !llvm.ptr, f32
+  %0 = llvm.atomicrmw fadd %ptr, %f32 monotonic : !llvm.ptr, f32
   // CHECK: llvm.atomicrmw volatile fsub %{{.*}}, %{{.*}} syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
-  %1 = llvm.atomicrmw volatile fsub %ptr, %val syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
+  %1 = llvm.atomicrmw volatile fsub %ptr, %f32 syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
+  // CHECK: llvm.atomicrmw fmin %{{.*}}, %{{.*}} monotonic : !llvm.ptr, vector<2xf16>
+  %2 = llvm.atomicrmw fmin %ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
   llvm.return
 }
 
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 007284d0ca4435..327c9f05f4c72c 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1496,7 +1496,8 @@ llvm.func @elements_constant_3d_array() -> !llvm.array<2 x array<2 x array<2 x i
 // CHECK-LABEL: @atomicrmw
 llvm.func @atomicrmw(
     %f32_ptr : !llvm.ptr, %f32 : f32,
-    %i32_ptr : !llvm.ptr, %i32 : i32) {
+    %i32_ptr : !llvm.ptr, %i32 : i32,
+    %f16_vec_ptr : !llvm.ptr, %f16_vec : vector<2xf16>) {
   // CHECK: atomicrmw fadd ptr %{{.*}}, float %{{.*}} monotonic
   %0 = llvm.atomicrmw fadd %f32_ptr, %f32 monotonic : !llvm.ptr, f32
   // CHECK: atomicrmw fsub ptr %{{.*}}, float %{{.*}} monotonic
@@ -1535,11 +1536,19 @@ llvm.func @atomicrmw(
   %17 = llvm.atomicrmw usub_cond %i32_ptr, %i32 monotonic : !llvm.ptr, i32
   // CHECK: atomicrmw usub_sat ptr %{{.*}}, i32 %{{.*}} monotonic
   %18 = llvm.atomicrmw usub_sat %i32_ptr, %i32 monotonic : !llvm.ptr, i32
+  // CHECK: atomicrmw fadd ptr %{{.*}}, <2 x half> %{{.*}} monotonic
+  %19 = llvm.atomicrmw fadd %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
+  // CHECK: atomicrmw fsub ptr %{{.*}}, <2 x half> %{{.*}} monotonic
+  %20 = llvm.atomicrmw fsub %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
+  // CHECK: atomicrmw fmax ptr %{{.*}}, <2 x half> %{{.*}} monotonic
+  %21 = llvm.atomicrmw fmax %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
+  // CHECK: atomicrmw fmin ptr %{{.*}}, <2 x half> %{{.*}} monotonic
+  %22 = llvm.atomicrmw fmin %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
 
   // CHECK: atomicrmw volatile
   // CHECK-SAME:  syncscope("singlethread")
   // CHECK-SAME:  align 8
-  %19 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
+  %23 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
   llvm.return
 }
 



More information about the Mlir-commits mailing list