[flang-commits] [flang] [flang][cuda] Lower match_any_sync functions to nvvm intrinsics (PR #127942)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Feb 20 10:48:27 PST 2025


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/127942

From b94839d2a44701b3e06045faf42470d2c3b227c4 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 19 Feb 2025 18:08:07 -0800
Subject: [PATCH] [flang][cuda] Lower match_any_sync functions to nvvm
 intrinsics

---
 .../flang/Optimizer/Builder/IntrinsicCall.h   |  1 +
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 43 +++++++++++++++++++
 flang/module/cudadevice.f90                   | 23 ++++++++++
 flang/test/Lower/CUDA/cuda-device-proc.cuf    | 21 +++++++++
 4 files changed, 88 insertions(+)

diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index caec6a913293f..b679ef74870b1 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -336,6 +336,7 @@ struct IntrinsicLibrary {
   template <typename Shift>
   mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genMatchAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genMatchAnySync(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMatmulTranspose(mlir::Type,
                                         llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 754496921ca3a..d98ee58ace2bc 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -485,6 +485,22 @@ static constexpr IntrinsicHandler handlers[]{
      &I::genMatchAllSync,
      {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
      /*isElemental=*/false},
+    {"match_any_syncjd",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
+    {"match_any_syncjf",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
+    {"match_any_syncjj",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
+    {"match_any_syncjx",
+     &I::genMatchAnySync,
+     {{{"mask", asValue}, {"value", asValue}}},
+     /*isElemental=*/false},
     {"matmul",
      &I::genMatmul,
      {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
@@ -6060,6 +6076,7 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
   return result;
 }
 
+// MATCH_ALL_SYNC
 mlir::Value
 IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
                                   llvm::ArrayRef<mlir::Value> args) {
@@ -6096,6 +6113,32 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
   return value;
 }
 
+// MATCH_ANY_SYNC: lower to a call to llvm.nvvm.match.any.sync.{i32p,i64p}.
+mlir::Value
+IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
+                                  llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2); // args[0] is the mask, args[1] the value
+  bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
+
+  llvm::StringRef funcName =
+      is32 ? "llvm.nvvm.match.any.sync.i32p" : "llvm.nvvm.match.any.sync.i64p";
+  mlir::MLIRContext *context = builder.getContext();
+  mlir::Type i32Ty = builder.getI32Type();
+  mlir::Type i64Ty = builder.getI64Type();
+  mlir::Type valTy = is32 ? i32Ty : i64Ty; // integer type of the value operand
+
+  mlir::FunctionType ftype =
+      mlir::FunctionType::get(context, {i32Ty, valTy}, {i32Ty});
+  auto funcOp = builder.createFunction(loc, funcName, ftype);
+  llvm::SmallVector<mlir::Value> filteredArgs;
+  filteredArgs.push_back(args[0]); // mask is forwarded unchanged
+  if (args[1].getType().isF32() || args[1].getType().isF64()) // real value
+    filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
+  else // any other type is forwarded as is
+    filteredArgs.push_back(args[1]);
+  return builder.create<fir::CallOp>(loc, funcOp, filteredArgs).getResult(0);
+}
+
 // MATMUL
 fir::ExtendedValue
 IntrinsicLibrary::genMatmul(mlir::Type resultType,
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index c75c5c191ab51..8b31c0c0856fd 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -589,4 +589,27 @@ attributes(device) integer function match_all_syncjd(mask, val, pred)
   end function
 end interface
 
+interface match_any_sync ! generic; one specific per type/kind of val
+  attributes(device) integer function match_any_syncjj(mask, val) ! integer(4) val
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  integer(4), value :: val
+  end function
+  attributes(device) integer function match_any_syncjx(mask, val) ! integer(8) val
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  integer(8), value :: val
+  end function
+  attributes(device) integer function match_any_syncjf(mask, val) ! real(4) val
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  real(4), value    :: val
+  end function
+  attributes(device) integer function match_any_syncjd(mask, val) ! real(8) val
+!dir$ ignore_tkr(d) mask, (d) val
+  integer(4), value :: mask
+  real(8), value    :: val
+  end function
+end interface
+
 end module
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 1210dae8608c8..e7d1dba385bb8 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -131,6 +131,25 @@ end subroutine
 ! CHECK: fir.convert %{{.*}} : (f64) -> i64
 ! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
 
+attributes(device) subroutine testMatchAny() ! one match_any_sync call per value type
+  integer :: a, mask, v32
+  integer(8) :: v64
+  real(4) :: r4
+  real(8) :: r8
+  a = match_any_sync(mask, v32) ! default integer value
+  a = match_any_sync(mask, v64) ! integer(8) value
+  a = match_any_sync(mask, r4) ! real(4) value
+  a = match_any_sync(mask, r8) ! real(8) value
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtestmatchany()
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
+! CHECK: fir.convert %{{.*}} : (f32) -> i32
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
+! CHECK: fir.convert %{{.*}} : (f64) -> i64
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
+
 ! CHECK: func.func private @llvm.nvvm.barrier0()
 ! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
 ! CHECK: func.func private @llvm.nvvm.membar.gl()
@@ -141,3 +160,5 @@ end subroutine
 ! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
 ! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple<i32, i1>
 ! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple<i32, i1>
+! CHECK: func.func private @llvm.nvvm.match.any.sync.i32p(i32, i32) -> i32
+! CHECK: func.func private @llvm.nvvm.match.any.sync.i64p(i32, i64) -> i32



More information about the flang-commits mailing list