[flang-commits] [flang] [flang][cuda] Lower __LDCA, __LDCS, __LDLU, __LDCV, __LDCG with arrays (PR #130357)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Fri Mar 7 14:05:10 PST 2025
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/130357
Some forms of __LDCA, __LDCS, __LDLU, __LDCV and __LDCG take an array argument and return an array. The runtime implements these functions with the result array passed as the first argument, so add custom lowering that matches the C function signatures.
>From 004803649463ff67c990548f601e2e6fa384e6a7 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 7 Mar 2025 10:53:32 -0800
Subject: [PATCH] [flang][cuda] Lower __LDCA, __LDCS, __LDLU, __LDCV, __LDCG
with arrays
---
.../flang/Optimizer/Builder/IntrinsicCall.h | 3 +
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 151 ++++++++++++++++++
flang/test/Lower/CUDA/cuda-device-proc.cuf | 45 +++---
3 files changed, 179 insertions(+), 20 deletions(-)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index c82e5265970c5..3301b7195d7de 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -223,6 +223,9 @@ struct IntrinsicLibrary {
fir::ExtendedValue genCount(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
void genCpuTime(llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genCshift(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+ template <const char *fctName, int extent> // CUDA __ldXX array loads
+ fir::ExtendedValue genCUDALDXXFunc(mlir::Type,
+ llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genCAssociatedCFunPtr(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genCAssociatedCPtr(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index ede3be074a820..bc3c6fcdd853d 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -106,6 +106,34 @@ using I = IntrinsicLibrary;
/// argument is an optional variable in the current scope).
static constexpr bool handleDynamicOptional = true;
+/// TODO: Move all CUDA Fortran intrinsic handlers into their own file, like
+/// PPC. NOTE(review): `__`-prefixed names are reserved in C++; rename these.
+static const char __ldca_i4x4[] = "__ldca_i4x4_";
+static const char __ldca_i8x2[] = "__ldca_i8x2_";
+static const char __ldca_r2x2[] = "__ldca_r2x2_";
+static const char __ldca_r4x4[] = "__ldca_r4x4_";
+static const char __ldca_r8x2[] = "__ldca_r8x2_";
+static const char __ldcg_i4x4[] = "__ldcg_i4x4_";
+static const char __ldcg_i8x2[] = "__ldcg_i8x2_";
+static const char __ldcg_r2x2[] = "__ldcg_r2x2_";
+static const char __ldcg_r4x4[] = "__ldcg_r4x4_";
+static const char __ldcg_r8x2[] = "__ldcg_r8x2_";
+static const char __ldcs_i4x4[] = "__ldcs_i4x4_";
+static const char __ldcs_i8x2[] = "__ldcs_i8x2_";
+static const char __ldcs_r2x2[] = "__ldcs_r2x2_";
+static const char __ldcs_r4x4[] = "__ldcs_r4x4_";
+static const char __ldcs_r8x2[] = "__ldcs_r8x2_";
+static const char __ldcv_i4x4[] = "__ldcv_i4x4_";
+static const char __ldcv_i8x2[] = "__ldcv_i8x2_";
+static const char __ldcv_r2x2[] = "__ldcv_r2x2_";
+static const char __ldcv_r4x4[] = "__ldcv_r4x4_";
+static const char __ldcv_r8x2[] = "__ldcv_r8x2_";
+static const char __ldlu_i4x4[] = "__ldlu_i4x4_";
+static const char __ldlu_i8x2[] = "__ldlu_i8x2_";
+static const char __ldlu_r2x2[] = "__ldlu_r2x2_";
+static const char __ldlu_r4x4[] = "__ldlu_r4x4_";
+static const char __ldlu_r8x2[] = "__ldlu_r8x2_";
+
/// Table that drives the fir generation depending on the intrinsic or intrinsic
/// module procedure one to one mapping with Fortran arguments. If no mapping is
/// defined here for a generic intrinsic, genRuntimeCall will be called
@@ -114,6 +142,106 @@ static constexpr bool handleDynamicOptional = true;
/// argument must not be lowered by value. In which case, the lowering rules
/// should be provided for all the intrinsic arguments for completeness.
static constexpr IntrinsicHandler handlers[]{
+ {"__ldca_i4x4",
+ &I::genCUDALDXXFunc<__ldca_i4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_i8x2",
+ &I::genCUDALDXXFunc<__ldca_i8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_r2x2",
+ &I::genCUDALDXXFunc<__ldca_r2x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_r4x4",
+ &I::genCUDALDXXFunc<__ldca_r4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldca_r8x2",
+ &I::genCUDALDXXFunc<__ldca_r8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_i4x4",
+ &I::genCUDALDXXFunc<__ldcg_i4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_i8x2",
+ &I::genCUDALDXXFunc<__ldcg_i8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_r2x2",
+ &I::genCUDALDXXFunc<__ldcg_r2x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_r4x4",
+ &I::genCUDALDXXFunc<__ldcg_r4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcg_r8x2",
+ &I::genCUDALDXXFunc<__ldcg_r8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_i4x4",
+ &I::genCUDALDXXFunc<__ldcs_i4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_i8x2",
+ &I::genCUDALDXXFunc<__ldcs_i8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_r2x2",
+ &I::genCUDALDXXFunc<__ldcs_r2x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_r4x4",
+ &I::genCUDALDXXFunc<__ldcs_r4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcs_r8x2",
+ &I::genCUDALDXXFunc<__ldcs_r8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_i4x4",
+ &I::genCUDALDXXFunc<__ldcv_i4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_i8x2",
+ &I::genCUDALDXXFunc<__ldcv_i8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_r2x2",
+ &I::genCUDALDXXFunc<__ldcv_r2x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_r4x4",
+ &I::genCUDALDXXFunc<__ldcv_r4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldcv_r8x2",
+ &I::genCUDALDXXFunc<__ldcv_r8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_i4x4",
+ &I::genCUDALDXXFunc<__ldlu_i4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_i8x2",
+ &I::genCUDALDXXFunc<__ldlu_i8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_r2x2",
+ &I::genCUDALDXXFunc<__ldlu_r2x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_r4x4",
+ &I::genCUDALDXXFunc<__ldlu_r4x4, 4>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
+ {"__ldlu_r8x2",
+ &I::genCUDALDXXFunc<__ldlu_r8x2, 2>,
+ {{{"a", asAddr}}},
+ /*isElemental=*/false},
{"abort", &I::genAbort},
{"abs", &I::genAbs},
{"achar", &I::genChar},
@@ -3544,6 +3672,29 @@ IntrinsicLibrary::genCshift(mlir::Type resultType,
return readAndAddCleanUp(resultMutableBox, resultType, "CSHIFT");
}
+// __LDCA, __LDCG, __LDCS, __LDLU, __LDCV: call `fctName(result, source)`,
+// passing both arrays by reference; the result has `extent` elements.
+template <const char *fctName, int extent>
+fir::ExtendedValue
+IntrinsicLibrary::genCUDALDXXFunc(mlir::Type resultType,
+                                  llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 1);
+  mlir::Type resTy = fir::SequenceType::get(extent, resultType);
+  mlir::Value arg = fir::getBase(args[0]);
+  mlir::Value res = builder.create<fir::AllocaOp>(loc, resTy);
+  if (mlir::isa<fir::BaseBoxType>(arg.getType()))
+    arg = builder.create<fir::BoxAddrOp>(loc, arg);
+  mlir::Type refResTy = fir::ReferenceType::get(resTy);
+  mlir::FunctionType ftype =
+      mlir::FunctionType::get(arg.getContext(), {refResTy, refResTy}, {});
+  auto funcOp = builder.createFunction(loc, fctName, ftype);
+  llvm::SmallVector<mlir::Value> funcArgs{res, arg};
+  builder.create<fir::CallOp>(loc, funcOp, funcArgs);
+  mlir::Value ext =
+      builder.createIntegerConstant(loc, builder.getIndexType(), extent);
+  return fir::ArrayBoxValue(res, {ext});
+}
+
// DATE_AND_TIME
void IntrinsicLibrary::genDateAndTime(llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 4 && "date_and_time has 4 args");
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 5f39f78f8ecae..02c94235a354f 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -210,10 +210,11 @@ attributes(global) subroutine __ldXXi4(b)
end
! CHECK-LABEL: func.func @_QP__ldxxi4
-! CHECK: __ldca_i4x4
-! CHECK: __ldcg_i4x4
-! CHECK: __ldcs_i4x4
-! CHECK: __ldlu_i4x4
+! CHECK: fir.call @__ldca_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldcg_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldcs_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldlu_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldcv_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
attributes(global) subroutine __ldXXi8(b)
integer(8), device :: b(*)
@@ -226,10 +227,11 @@ attributes(global) subroutine __ldXXi8(b)
end
! CHECK-LABEL: func.func @_QP__ldxxi8
-! CHECK: __ldca_i8x2
-! CHECK: __ldcg_i8x2
-! CHECK: __ldcs_i8x2
-! CHECK: __ldlu_i8x2
+! CHECK: fir.call @__ldca_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldcg_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldcs_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldlu_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldcv_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
attributes(global) subroutine __ldXXr4(b)
real, device :: b(*)
@@ -242,10 +244,11 @@ attributes(global) subroutine __ldXXr4(b)
end
! CHECK-LABEL: func.func @_QP__ldxxr4
-! CHECK: __ldca_r4x4
-! CHECK: __ldcg_r4x4
-! CHECK: __ldcs_r4x4
-! CHECK: __ldlu_r4x4
+! CHECK: fir.call @__ldca_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldcg_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldcs_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldlu_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldcv_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
attributes(global) subroutine __ldXXr2(b)
real(2), device :: b(*)
@@ -258,10 +261,11 @@ attributes(global) subroutine __ldXXr2(b)
end
! CHECK-LABEL: func.func @_QP__ldxxr2
-! CHECK: __ldca_r2x2
-! CHECK: __ldcg_r2x2
-! CHECK: __ldcs_r2x2
-! CHECK: __ldlu_r2x2
+! CHECK: fir.call @__ldca_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldcg_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldcs_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldlu_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldcv_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
attributes(global) subroutine __ldXXr8(b)
real(8), device :: b(*)
@@ -274,7 +278,8 @@ attributes(global) subroutine __ldXXr8(b)
end
! CHECK-LABEL: func.func @_QP__ldxxr8
-! CHECK: __ldca_r8x2
-! CHECK: __ldcg_r8x2
-! CHECK: __ldcs_r8x2
-! CHECK: __ldlu_r8x2
+! CHECK: fir.call @__ldca_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldcg_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldcs_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldlu_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldcv_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
More information about the flang-commits
mailing list