[flang-commits] [flang] [flang][cuda] Lower match_any_sync functions to nvvm intrinsics (PR #127942)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Wed Feb 19 18:26:27 PST 2025
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/127942
None
>From ce2a65a6ec2070c0e49e554cf8d26d287086f368 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 19 Feb 2025 18:08:07 -0800
Subject: [PATCH] [flang][cuda] Lower match_any_sync functions to nvvm
intrinsics
---
.../flang/Optimizer/Builder/IntrinsicCall.h | 1 +
flang/include/flang/Semantics/tools.h | 1 +
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 41 +++++++++++++++++++
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 7 ++++
flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp | 3 +-
flang/module/cudadevice.f90 | 23 +++++++++++
flang/test/Lower/CUDA/cuda-device-proc.cuf | 21 ++++++++++
7 files changed, 96 insertions(+), 1 deletion(-)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 65732ce7f3224..27783fac26845 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -335,6 +335,7 @@ struct IntrinsicLibrary {
mlir::Value genMalloc(mlir::Type, llvm::ArrayRef<mlir::Value>);
template <typename Shift>
mlir::Value genMask(mlir::Type, llvm::ArrayRef<mlir::Value>);
+ mlir::Value genMatchAnySync(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genMatmulTranspose(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index e82446a2ba884..56dcfa88ad92d 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -231,6 +231,7 @@ inline bool NeedCUDAAlloc(const Symbol &sym) {
(*details->cudaDataAttr() == common::CUDADataAttr::Device ||
*details->cudaDataAttr() == common::CUDADataAttr::Managed ||
*details->cudaDataAttr() == common::CUDADataAttr::Unified ||
+ *details->cudaDataAttr() == common::CUDADataAttr::Shared ||
*details->cudaDataAttr() == common::CUDADataAttr::Pinned)) {
return true;
}
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 93744fa58ebc0..215ce327303da 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -469,6 +469,22 @@ static constexpr IntrinsicHandler handlers[]{
{"malloc", &I::genMalloc},
{"maskl", &I::genMask<mlir::arith::ShLIOp>},
{"maskr", &I::genMask<mlir::arith::ShRUIOp>},
+ {"match_any_syncjd",
+ &I::genMatchAnySync,
+ {{{"mask", asValue}, {"value", asValue}}},
+ /*isElemental=*/false},
+ {"match_any_syncjf",
+ &I::genMatchAnySync,
+ {{{"mask", asValue}, {"value", asValue}}},
+ /*isElemental=*/false},
+ {"match_any_syncjj",
+ &I::genMatchAnySync,
+ {{{"mask", asValue}, {"value", asValue}}},
+ /*isElemental=*/false},
+ {"match_any_syncjx",
+ &I::genMatchAnySync,
+ {{{"mask", asValue}, {"value", asValue}}},
+ /*isElemental=*/false},
{"matmul",
&I::genMatmul,
{{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
@@ -6044,6 +6060,31 @@ mlir::Value IntrinsicLibrary::genMask(mlir::Type resultType,
return result;
}
+// MATCH_ANY_SYNC
+// Lowers a match_any_sync(mask, value) call to a call of the matching
+// llvm.nvvm.match.any.sync function. args[0] is the mask and args[1] is the
+// value; the result is the single (i32) result of the generated call.
+mlir::Value
+IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
+                                  llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2);
+  mlir::Value maskArg = args[0];
+  mlir::Value valueArg = args[1];
+  mlir::Type valueArgTy = valueArg.getType();
+
+  // i32 and f32 values select the 32-bit callee; any other value type
+  // selects the 64-bit one.
+  const bool use32BitVariant = valueArgTy.isInteger(32) || valueArgTy.isF32();
+  mlir::Type maskTy = builder.getI32Type();
+  mlir::Type payloadTy = maskTy;
+  llvm::StringRef calleeName = "llvm.nvvm.match.any.sync.i32p";
+  if (!use32BitVariant) {
+    payloadTy = builder.getI64Type();
+    calleeName = "llvm.nvvm.match.any.sync.i64p";
+  }
+
+  // Declare (or look up) the callee with signature (i32, payload) -> i32.
+  mlir::FunctionType calleeTy = mlir::FunctionType::get(
+      builder.getContext(), {maskTy, payloadTy}, {maskTy});
+  auto callee = builder.createFunction(loc, calleeName, calleeTy);
+
+  // Real values go through fir.convert to the integer payload type; integer
+  // values are forwarded untouched.
+  // NOTE(review): confirm fir.convert yields the intended semantics for real
+  // operands (numeric conversion vs. reinterpreting the bit pattern).
+  mlir::Value payload = valueArg;
+  if (valueArgTy.isF32() || valueArgTy.isF64())
+    payload = builder.create<fir::ConvertOp>(loc, payloadTy, valueArg);
+
+  llvm::SmallVector<mlir::Value> callOperands{maskArg, payload};
+  return builder.create<fir::CallOp>(loc, callee, callOperands).getResult(0);
+}
+
// MATMUL
fir::ExtendedValue
IntrinsicLibrary::genMatmul(mlir::Type resultType,
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index c76b7cde55bdd..439cc7a856236 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -292,6 +292,12 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
rewriter.setInsertionPointAfter(size.getDefiningOp());
}
+ if (auto dataAttr = alloc->getAttrOfType<cuf::DataAttributeAttr>(
+ cuf::getDataAttrName())) {
+ if (dataAttr.getValue() == cuf::DataAttribute::Shared)
+ allocaAs = 3;
+ }
+
// NOTE: we used to pass alloc->getAttrs() in the builder for non opaque
// pointers! Only propagate pinned and bindc_name to help debugging, but
// this should have no functional purpose (and passing the operand segment
@@ -316,6 +322,7 @@ struct AllocaOpConversion : public fir::FIROpConversion<fir::AllocaOp> {
rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(
alloc, ::getLlvmPtrType(alloc.getContext(), programAs), llvmAlloc);
}
+
return mlir::success();
}
};
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index b05991a29a321..fa82f3916a57e 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -57,7 +57,8 @@ static llvm::LogicalResult checkCudaAttr(Op op) {
if (op.getDataAttr() == cuf::DataAttribute::Device ||
op.getDataAttr() == cuf::DataAttribute::Managed ||
op.getDataAttr() == cuf::DataAttribute::Unified ||
- op.getDataAttr() == cuf::DataAttribute::Pinned)
+ op.getDataAttr() == cuf::DataAttribute::Pinned ||
+ op.getDataAttr() == cuf::DataAttribute::Shared)
return mlir::success();
return op.emitOpError()
<< "expect device, managed, pinned or unified cuda attribute";
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index e473590a7d78f..03bd10ce9245e 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -562,4 +562,27 @@ attributes(device) integer(8) function clock64()
end function
end interface
+! Generic interface match_any_sync: four device specifics that all take a
+! 4-byte integer mask by value and return a default integer; they differ only
+! in the type of the by-value val argument.
+! NOTE(review): the ignore_tkr(d) directives presumably relax CUDA data
+! attribute matching for both dummies - confirm against the directive docs.
+interface match_any_sync
+ ! Specific for an integer(4) val.
+ attributes(device) integer function match_any_syncjj(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+ integer(4), value :: mask
+ integer(4), value :: val
+ end function
+ ! Specific for an integer(8) val.
+ attributes(device) integer function match_any_syncjx(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+ integer(4), value :: mask
+ integer(8), value :: val
+ end function
+ ! Specific for a real(4) val.
+ attributes(device) integer function match_any_syncjf(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+ integer(4), value :: mask
+ real(4), value :: val
+ end function
+ ! Specific for a real(8) val.
+ attributes(device) integer function match_any_syncjd(mask, val)
+!dir$ ignore_tkr(d) mask, (d) val
+ integer(4), value :: mask
+ real(8), value :: val
+ end function
+end interface
+
end module
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 6a5524102c0ea..88461171a90de 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -112,6 +112,25 @@ end
! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
+! Calls the match_any_sync generic once per value type, in the order
+! default integer, integer(8), real(4), real(8). The ordered expectations
+! that follow this subroutine rely on that call order.
+attributes(device) subroutine testMatchAny()
+ integer :: a, mask, v32
+ integer(8) :: v64
+ real(4) :: r4
+ real(8) :: r8
+ a = match_any_sync(mask, v32)
+ a = match_any_sync(mask, v64)
+ a = match_any_sync(mask, r4)
+ a = match_any_sync(mask, r8)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtestmatchany()
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
+! CHECK: fir.convert %{{.*}} : (f32) -> i32
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i32p
+! CHECK: fir.convert %{{.*}} : (f64) -> i64
+! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
+
! CHECK: func.func private @llvm.nvvm.barrier0()
! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
! CHECK: func.func private @llvm.nvvm.membar.gl()
@@ -120,3 +139,5 @@ end
! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
+! CHECK: func.func private @llvm.nvvm.match.any.sync.i32p(i32, i32) -> i32
+! CHECK: func.func private @llvm.nvvm.match.any.sync.i64p(i32, i64) -> i32
More information about the flang-commits
mailing list