[flang-commits] [flang] [flang][cuda] Lower match_all_sync functions to nvvm intrinsics (PR #127940)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Feb 19 18:09:13 PST 2025


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/127940

None

>From 12d3dd9c6c68b6f5745796ca09b508f59bbb98fd Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 19 Feb 2025 18:08:07 -0800
Subject: [PATCH] [flang][cuda] Lower match_all_sync functions to nvvm
 intrinsics

---
 .../flang/Optimizer/Builder/IntrinsicCall.h   |  1 +
 flang/include/flang/Semantics/tools.h         |  1 +
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 52 +++++++++++++++++++
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       |  7 +++
 flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp    |  3 +-
 flang/module/cudadevice.f90                   | 27 ++++++++++
 flang/test/Lower/CUDA/cuda-device-proc.cuf    | 21 ++++++++
 7 files changed, 111 insertions(+), 1 deletion(-)

diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 65732ce7f3224..caec6a913293f 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -335,6 +335,7 @@ struct IntrinsicLibrary {
   mlir::Value genMalloc(mlir::Type, llvm::ArrayRef<mlir::Value>);
   template <typename Shift>
   mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genMatchAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genMatmulTranspose(mlir::Type,
                                         llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index e82446a2ba884..56dcfa88ad92d 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -231,6 +231,7 @@ inline bool NeedCUDAAlloc(const Symbol &sym) {
         (*details->cudaDataAttr() == common::CUDADataAttr::Device ||
             *details->cudaDataAttr() == common::CUDADataAttr::Managed ||
             *details->cudaDataAttr() == common::CUDADataAttr::Unified ||
+            *details->cudaDataAttr() == common::CUDADataAttr::Shared ||
             *details->cudaDataAttr() == common::CUDADataAttr::Pinned)) {
       return true;
     }
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 93744fa58ebc0..754496921ca3a 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -469,6 +469,22 @@ static constexpr IntrinsicHandler handlers[]{
     {"malloc", &I::genMalloc},
     {"maskl", &I::genMask<mlir::arith::ShLIOp>},
     {"maskr", &I::genMask<mlir::arith::ShRUIOp>},
+    {"match_all_syncjd",
+     &I::genMatchAllSync,
+     {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
+     /*isElemental=*/false},
+    {"match_all_syncjf",
+     &I::genMatchAllSync,
+     {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
+     /*isElemental=*/false},
+    {"match_all_syncjj",
+     &I::genMatchAllSync,
+     {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
+     /*isElemental=*/false},
+    {"match_all_syncjx",
+     &I::genMatchAllSync,
+     {{{"mask", asValue}, {"value", asValue}, {"pred", asAddr}}},
+     /*isElemental=*/false},
     {"matmul",
      &I::genMatmul,
      {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
@@ -6044,6 +6060,42 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
   return result;
 }
 
+mlir::Value // Lowers match_all_sync to a call to llvm.nvvm.match.all.sync.*
+IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
+                                  llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 3); // args: mask, value, pred (address)
+  bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
+
+  llvm::StringRef funcName = // i32p callee for i32/f32 values, i64p otherwise
+      is32 ? "llvm.nvvm.match.all.sync.i32p" : "llvm.nvvm.match.all.sync.i64p";
+  mlir::MLIRContext *context = builder.getContext();
+  mlir::Type i32Ty = builder.getI32Type();
+  mlir::Type i64Ty = builder.getI64Type();
+  mlir::Type i1Ty = builder.getI1Type();
+  mlir::Type retTy = mlir::TupleType::get(context, {resultType, i1Ty});
+  mlir::Type valTy = is32 ? i32Ty : i64Ty; // type of the callee's value operand
+
+  mlir::FunctionType ftype = // (i32 mask, valTy value) -> (resultType, i1)
+      mlir::FunctionType::get(context, {i32Ty, valTy}, {retTy});
+  auto funcOp = builder.createFunction(loc, funcName, ftype);
+  llvm::SmallVector<mlir::Value> filteredArgs;
+  filteredArgs.push_back(args[0]); // mask is forwarded unchanged
+  if (args[1].getType().isF32() || args[1].getType().isF64()) // NOTE(review): confirm fir.convert keeps the bit pattern
+    filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
+  else // integer value is forwarded unchanged
+    filteredArgs.push_back(args[1]);
+  auto call = builder.create<fir::CallOp>(loc, funcOp, filteredArgs);
+  auto zero = builder.getIntegerAttr(builder.getIndexType(), 0); // field 0
+  auto value = builder.create<fir::ExtractValueOp>(
+      loc, resultType, call.getResult(0), builder.getArrayAttr(zero));
+  auto one = builder.getIntegerAttr(builder.getIndexType(), 1); // field 1
+  auto pred = builder.create<fir::ExtractValueOp>(loc, i1Ty, call.getResult(0),
+                                                  builder.getArrayAttr(one));
+  auto conv = builder.create<mlir::LLVM::ZExtOp>(loc, resultType, pred); // i1 -> resultType
+  builder.create<fir::StoreOp>(loc, conv, args[2]); // write the flag to pred
+  return value; // field 0 of the call result
+}
+
 // MATMUL
 fir::ExtendedValue
 IntrinsicLibrary::genMatmul(mlir::Type resultType,
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index c76b7cde55bdd..439cc7a856236 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -292,6 +292,12 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
       rewriter.setInsertionPointAfter(size.getDefiningOp());
     }
 
+    if (auto dataAttr = alloc->getAttrOfType<cuf::DataAttributeAttr>(
+            cuf::getDataAttrName())) {
+      if (dataAttr.getValue() == cuf::DataAttribute::Shared)
+        allocaAs = 3;
+    }
+
     // NOTE: we used to pass alloc->getAttrs() in the builder for non opaque
     // pointers! Only propagate pinned and bindc_name to help debugging, but
     // this should have no functional purpose (and passing the operand segment
@@ -316,6 +322,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
       rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
           alloc, ::getLlvmPtrType(alloc.getContext(), programAs), llvmAlloc);
     }
+
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index b05991a29a321..fa82f3916a57e 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -57,7 +57,8 @@ static llvm::LogicalResult checkCudaAttr(Op op) {
   if (op.getDataAttr() == cuf::DataAttribute::Device ||
       op.getDataAttr() == cuf::DataAttribute::Managed ||
       op.getDataAttr() == cuf::DataAttribute::Unified ||
-      op.getDataAttr() == cuf::DataAttribute::Pinned)
+      op.getDataAttr() == cuf::DataAttribute::Pinned ||
+      op.getDataAttr() == cuf::DataAttribute::Shared)
     return mlir::success();
   return op.emitOpError()
          << "expect device, managed, pinned or unified cuda attribute";
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index e473590a7d78f..c75c5c191ab51 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -562,4 +562,31 @@ attributes(device) integer(8) function clock64()
     end function
   end interface
 
+interface match_all_sync ! generic; the four specifics differ only in the type of val
+  attributes(device) integer function match_all_syncjj(mask, val, pred) ! val: integer(4)
+!dir$ ignore_tkr(d) mask, (d) val, (d) pred
+  integer(4), value :: mask
+  integer(4), value :: val
+  integer(4)        :: pred ! no VALUE attribute, unlike mask and val
+  end function
+  attributes(device) integer function match_all_syncjx(mask, val, pred) ! val: integer(8)
+!dir$ ignore_tkr(d) mask, (d) val, (d) pred
+  integer(4), value :: mask
+  integer(8), value :: val
+  integer(4)        :: pred
+  end function
+  attributes(device) integer function match_all_syncjf(mask, val, pred) ! val: real(4)
+!dir$ ignore_tkr(d) mask, (d) val, (d) pred
+  integer(4), value :: mask
+  real(4), value    :: val
+  integer(4)        :: pred
+  end function
+  attributes(device) integer function match_all_syncjd(mask, val, pred) ! val: real(8)
+!dir$ ignore_tkr(d) mask, (d) val, (d) pred
+  integer(4), value :: mask
+  real(8), value    :: val
+  integer(4)        :: pred
+  end function
+end interface
+
 end module
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 6a5524102c0ea..1210dae8608c8 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -112,6 +112,25 @@ end
 ! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 ! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 
+attributes(device) subroutine testMatch() ! one call per value type of the generic
+  integer :: a, ipred, mask, v32
+  integer(8) :: v64
+  real(4) :: r4
+  real(8) :: r8
+  a = match_all_sync(mask, v32, ipred) ! default integer value: expects the i32p callee
+  a = match_all_sync(mask, v64, ipred) ! integer(8) value: expects the i64p callee
+  a = match_all_sync(mask, r4, ipred) ! real(4) value: expects f32 -> i32 convert, then i32p
+  a = match_all_sync(mask, r8, ipred) ! real(8) value: expects f64 -> i64 convert, then i64p
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtestmatch()
+! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
+! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
+! CHECK: fir.convert %{{.*}} : (f32) -> i32
+! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
+! CHECK: fir.convert %{{.*}} : (f64) -> i64
+! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
+
 ! CHECK: func.func private @llvm.nvvm.barrier0()
 ! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
 ! CHECK: func.func private @llvm.nvvm.membar.gl()
@@ -120,3 +139,5 @@ end
 ! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
 ! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
 ! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
+! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple<i32, i1>
+! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple<i32, i1>



More information about the flang-commits mailing list