[flang-commits] [flang] [flang][cuda] Lower atomiccas, atomicxor and atomicexch (PR #128242)

via flang-commits flang-commits at lists.llvm.org
Fri Feb 21 14:41:41 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Lower atomiccas, atomicxor and atomicexch to the corresponding LLVM atomic operations.

---
Full diff: https://github.com/llvm/llvm-project/pull/128242.diff


4 Files Affected:

- (modified) flang/include/flang/Optimizer/Builder/IntrinsicCall.h (+4-1) 
- (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+69) 
- (modified) flang/module/cudadevice.f90 (+107-49) 
- (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+12-12) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index b679ef74870b1..f5971610694f0 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -187,12 +187,15 @@ struct IntrinsicLibrary {
   mlir::Value genAtanpi(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicAdd(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
-  mlir::Value genAtomicOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genAtomicCas(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicDec(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genAtomicExch(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicInc(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicMax(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicMin(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genAtomicOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genAtomicSub(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genAtomicXor(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue
       genCommandArgumentCount(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genAsind(mlir::Type, llvm::ArrayRef<mlir::Value>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index d98ee58ace2bc..28fbe83defb61 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -152,7 +152,39 @@ static constexpr IntrinsicHandler handlers[]{
     {"atomicaddi", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicaddl", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicandi", &I::genAtomicAnd, {{{"a", asAddr}, {"v", asValue}}}, false},
+    {"atomiccasd",
+     &I::genAtomicCas,
+     {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+     false},
+    {"atomiccasf",
+     &I::genAtomicCas,
+     {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+     false},
+    {"atomiccasi",
+     &I::genAtomicCas,
+     {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+     false},
+    {"atomiccasul",
+     &I::genAtomicCas,
+     {{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
+     false},
     {"atomicdeci", &I::genAtomicDec, {{{"a", asAddr}, {"v", asValue}}}, false},
+    {"atomicexchd",
+     &I::genAtomicExch,
+     {{{"a", asAddr}, {"v", asValue}}},
+     false},
+    {"atomicexchf",
+     &I::genAtomicExch,
+     {{{"a", asAddr}, {"v", asValue}}},
+     false},
+    {"atomicexchi",
+     &I::genAtomicExch,
+     {{{"a", asAddr}, {"v", asValue}}},
+     false},
+    {"atomicexchul",
+     &I::genAtomicExch,
+     {{{"a", asAddr}, {"v", asValue}}},
+     false},
     {"atomicinci", &I::genAtomicInc, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicmaxd", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicmaxf", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false},
@@ -167,6 +199,7 @@ static constexpr IntrinsicHandler handlers[]{
     {"atomicsubf", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicsubi", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"atomicsubl", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
+    {"atomicxori", &I::genAtomicXor, {{{"a", asAddr}, {"v", asValue}}}, false},
     {"bessel_jn",
      &I::genBesselJn,
      {{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
@@ -2691,6 +2724,22 @@ mlir::Value IntrinsicLibrary::genAtomicOr(mlir::Type resultType,
   return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
 }
 
+// ATOMICCAS: cmpxchg yields {old value, i1 success}; return the old value.
+mlir::Value IntrinsicLibrary::genAtomicCas(mlir::Type resultType,
+                                           llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 3);
+  assert(args[1].getType() == args[2].getType());
+  auto successOrdering = mlir::LLVM::AtomicOrdering::acq_rel;
+  auto failureOrdering = mlir::LLVM::AtomicOrdering::monotonic;
+  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(resultType.getContext());
+  auto address =
+      builder.create<mlir::UnrealizedConversionCastOp>(loc, llvmPtrTy, args[0])
+          .getResult(0);
+  auto cmpxchg = builder.create<mlir::LLVM::AtomicCmpXchgOp>(
+      loc, address, args[1], args[2], successOrdering, failureOrdering);
+  return builder.create<mlir::LLVM::ExtractValueOp>(loc, cmpxchg, 0);
+}
+
 mlir::Value IntrinsicLibrary::genAtomicDec(mlir::Type resultType,
                                            llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 2);
@@ -2700,6 +2749,16 @@ mlir::Value IntrinsicLibrary::genAtomicDec(mlir::Type resultType,
   return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
 }
 
+// ATOMICEXCH: atomicrmw xchg; atomicexchi/ul pass integers, atomicexchf/d reals.
+mlir::Value IntrinsicLibrary::genAtomicExch(mlir::Type resultType,
+                                            llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  assert((mlir::isa<mlir::IntegerType, mlir::FloatType>(args[1].getType())));
+
+  mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::xchg;
+  return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
+}
+
 mlir::Value IntrinsicLibrary::genAtomicInc(mlir::Type resultType,
                                            llvm::ArrayRef<mlir::Value> args) {
   assert(args.size() == 2);
@@ -2731,6 +2790,16 @@ mlir::Value IntrinsicLibrary::genAtomicMin(mlir::Type resultType,
   return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
 }
 
+// ATOMICXOR: integer-only; emitted as atomicrmw _xor through genAtomBinOp.
+mlir::Value IntrinsicLibrary::genAtomicXor(mlir::Type resultType,
+                                           llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  mlir::Value operand = args[1];
+  assert(mlir::isa<mlir::IntegerType>(operand.getType()));
+  return genAtomBinOp(builder, loc, mlir::LLVM::AtomicBinOp::_xor, args[0],
+                      operand);
+}
+
 // ASSOCIATED
 fir::ExtendedValue
 IntrinsicLibrary::genAssociated(mlir::Type resultType,
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 8b31c0c0856fd..af8ea66618e27 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -557,59 +557,117 @@ attributes(device) pure integer function atomicdeci(address, val)
     end function
   end interface
 
+  interface atomiccas ! specifics lowered by genAtomicCas (llvm.cmpxchg)
+    attributes(device) pure integer function atomiccasi(address, val, val2)
+  !dir$ ignore_tkr (rd) address, (d) val, (d) val2
+    integer, intent(inout) :: address
+    integer, value :: val, val2
+    end function
+    attributes(device) pure integer(8) function atomiccasul(address, val, val2)
+  !dir$ ignore_tkr (rd) address, (dk) val, (dk) val2
+    integer(8), intent(inout) :: address
+    integer(8), value :: val, val2
+    end function
+    attributes(device) pure real function atomiccasf(address, val, val2)
+  !dir$ ignore_tkr (rd) address, (d) val, (d) val2
+    real, intent(inout) :: address
+    real, value :: val, val2
+    end function
+    attributes(device) pure double precision function atomiccasd(address, val, val2)
+  !dir$ ignore_tkr (rd) address, (d) val, (d) val2
+    double precision, intent(inout) :: address
+    double precision, value :: val, val2
+    end function
+  end interface
+
+  interface atomicexch ! specifics lowered by genAtomicExch (llvm.atomicrmw xchg)
+    attributes(device) pure integer function atomicexchi(address, val)
+  !dir$ ignore_tkr (rd) address, (d) val
+    integer, intent(inout) :: address
+    integer, value :: val
+    end function
+    attributes(device) pure integer(8) function atomicexchul(address, val)
+  !dir$ ignore_tkr (rd) address, (dk) val
+    integer(8), intent(inout) :: address
+    integer(8), value :: val
+    end function
+    attributes(device) pure real function atomicexchf(address, val)
+  !dir$ ignore_tkr (rd) address, (d) val
+    real, intent(inout) :: address
+    real, value :: val
+    end function
+    attributes(device) pure double precision function atomicexchd(address, val)
+  !dir$ ignore_tkr (rd) address, (d) val
+    double precision, intent(inout) :: address
+    double precision, value :: val
+    end function
+  end interface
+
+  interface atomicxor ! integer only; lowered by genAtomicXor (llvm.atomicrmw _xor)
+    attributes(device) pure integer function atomicxori(address, val)
+  !dir$ ignore_tkr (rd) address, (d) val
+    integer, intent(inout) :: address
+    integer, value :: val
+    end function
+  end interface
+
+  ! Time function
+
   interface
     attributes(device) integer(8) function clock64()
     end function
   end interface
 
-interface match_all_sync
-  attributes(device) integer function match_all_syncjj(mask, val, pred)
-!dir$ ignore_tkr(d) mask, (d) val, (d) pred
-  integer(4), value :: mask
-  integer(4), value :: val
-  integer(4)        :: pred
-  end function
-  attributes(device) integer function match_all_syncjx(mask, val, pred)
-!dir$ ignore_tkr(d) mask, (d) val, (d) pred
-  integer(4), value :: mask
-  integer(8), value :: val
-  integer(4)        :: pred
-  end function
-  attributes(device) integer function match_all_syncjf(mask, val, pred)
-!dir$ ignore_tkr(d) mask, (d) val, (d) pred
-  integer(4), value :: mask
-  real(4), value    :: val
-  integer(4)        :: pred
-  end function
-  attributes(device) integer function match_all_syncjd(mask, val, pred)
-!dir$ ignore_tkr(d) mask, (d) val, (d) pred
-  integer(4), value :: mask
-  real(8), value    :: val
-  integer(4)        :: pred
-  end function
-end interface
-
-interface match_any_sync
-  attributes(device) integer function match_any_syncjj(mask, val)
-!dir$ ignore_tkr(d) mask, (d) val
-  integer(4), value :: mask
-  integer(4), value :: val
-  end function
-  attributes(device) integer function match_any_syncjx(mask, val)
-!dir$ ignore_tkr(d) mask, (d) val
-  integer(4), value :: mask
-  integer(8), value :: val
-  end function
-  attributes(device) integer function match_any_syncjf(mask, val)
-!dir$ ignore_tkr(d) mask, (d) val
-  integer(4), value :: mask
-  real(4), value    :: val
-  end function
-  attributes(device) integer function match_any_syncjd(mask, val)
-!dir$ ignore_tkr(d) mask, (d) val
-  integer(4), value :: mask
-  real(8), value    :: val
-  end function
-end interface
+  ! Warp Match Functions
+
+  interface match_all_sync ! lowered to llvm.nvvm.match.all.sync.{i32p,i64p}
+    attributes(device) integer function match_all_syncjj(mask, val, pred)
+  !dir$ ignore_tkr(d) mask, (d) val, (d) pred
+    integer(4), value :: mask
+    integer(4), value :: val
+    integer(4)        :: pred
+    end function
+    attributes(device) integer function match_all_syncjx(mask, val, pred)
+  !dir$ ignore_tkr(d) mask, (d) val, (d) pred
+    integer(4), value :: mask
+    integer(8), value :: val
+    integer(4)        :: pred
+    end function
+    attributes(device) integer function match_all_syncjf(mask, val, pred)
+  !dir$ ignore_tkr(d) mask, (d) val, (d) pred
+    integer(4), value :: mask
+    real(4), value    :: val
+    integer(4)        :: pred
+    end function
+    attributes(device) integer function match_all_syncjd(mask, val, pred)
+  !dir$ ignore_tkr(d) mask, (d) val, (d) pred
+    integer(4), value :: mask
+    real(8), value    :: val
+    integer(4)        :: pred
+    end function
+  end interface
+
+  interface match_any_sync ! lowered to llvm.nvvm.match.any.sync.{i32p,i64p}
+    attributes(device) integer function match_any_syncjj(mask, val)
+  !dir$ ignore_tkr(d) mask, (d) val
+    integer(4), value :: mask
+    integer(4), value :: val
+    end function
+    attributes(device) integer function match_any_syncjx(mask, val)
+  !dir$ ignore_tkr(d) mask, (d) val
+    integer(4), value :: mask
+    integer(8), value :: val
+    end function
+    attributes(device) integer function match_any_syncjf(mask, val)
+  !dir$ ignore_tkr(d) mask, (d) val
+    integer(4), value :: mask
+    real(4), value    :: val
+    end function
+    attributes(device) integer function match_any_syncjd(mask, val)
+  !dir$ ignore_tkr(d) mask, (d) val
+    integer(4), value :: mask
+    real(8), value    :: val
+    end function
+  end interface
 
 end module
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index e7d1dba385bb8..fcfcc2e537039 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -150,15 +150,15 @@ end subroutine
 ! CHECK: fir.convert %{{.*}} : (f64) -> i64
 ! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
 
-! CHECK: func.func private @llvm.nvvm.barrier0()
-! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
-! CHECK: func.func private @llvm.nvvm.membar.gl()
-! CHECK: func.func private @llvm.nvvm.membar.cta()
-! CHECK: func.func private @llvm.nvvm.membar.sys()
-! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
-! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple<i32, i1>
-! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple<i32, i1>
-! CHECK: func.func private @llvm.nvvm.match.any.sync.i32p(i32, i32) -> i32
-! CHECK: func.func private @llvm.nvvm.match.any.sync.i64p(i32, i64) -> i32
+attributes(device) subroutine testAtomic()
+  integer :: a, istat, i, j
+  istat = atomicexch(a,0)
+  istat = atomicxor(a, j)
+  istat = atomiccas(a, i, 14)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtestatomic()
+! CHECK: llvm.atomicrmw xchg %{{.*}}, %c0{{.*}} seq_cst : !llvm.ptr, i32
+! CHECK: llvm.atomicrmw _xor %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
+! CHECK: %[[ADDR:.*]] = builtin.unrealized_conversion_cast %{{.*}}#1 : !fir.ref<i32> to !llvm.ptr
+! CHECK: llvm.cmpxchg %[[ADDR]], %{{.*}}, %c14{{.*}} acq_rel monotonic : !llvm.ptr, i32

``````````

</details>


https://github.com/llvm/llvm-project/pull/128242


More information about the flang-commits mailing list