[flang-commits] [flang] [flang][cuda] Lower __LDCA, __LDCS, __LDLU, __LDCV, __LDCG with arrays (PR #130357)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Mar 7 14:18:54 PST 2025


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/130357

>From 004803649463ff67c990548f601e2e6fa384e6a7 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 7 Mar 2025 10:53:32 -0800
Subject: [PATCH 1/2] [flang][cuda] Lower __LDCA, __LDCS, __LDLU, __LDCV,
 __LDCG with arrays

---
 .../flang/Optimizer/Builder/IntrinsicCall.h   |   3 +
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 151 ++++++++++++++++++
 flang/test/Lower/CUDA/cuda-device-proc.cuf    |  45 +++---
 3 files changed, 179 insertions(+), 20 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index c82e5265970c5..3301b7195d7de 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -223,6 +223,12 @@ struct IntrinsicLibrary {
   fir::ExtendedValue genCount(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   void genCpuTime(llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genCshift(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+  /// Lowering shared by the CUDA Fortran __LDCA/__LDCG/__LDCS/__LDLU/__LDCV
+  /// intrinsics: emits a call to the external function named \p fctName and
+  /// returns a rank-1 array of \p extent elements of the result element type.
+  template <const char *fctName, int extent>
+  fir::ExtendedValue genCUDALDXXFunc(mlir::Type,
+                                     llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genCAssociatedCFunPtr(mlir::Type,
                                            llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genCAssociatedCPtr(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index ede3be074a820..bc3c6fcdd853d 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -106,6 +106,34 @@ using I = IntrinsicLibrary;
 /// argument is an optional variable in the current scope).
 static constexpr bool handleDynamicOptional = true;
 
+/// TODO: Move all CUDA Fortran intrinsic hanlders into its own file similar to
+/// PPC.
+static const char __ldca_i4x4[] = "__ldca_i4x4_";
+static const char __ldca_i8x2[] = "__ldca_i8x2_";
+static const char __ldca_r2x2[] = "__ldca_r2x2_";
+static const char __ldca_r4x4[] = "__ldca_r4x4_";
+static const char __ldca_r8x2[] = "__ldca_r8x2_";
+static const char __ldcg_i4x4[] = "__ldcg_i4x4_";
+static const char __ldcg_i8x2[] = "__ldcg_i8x2_";
+static const char __ldcg_r2x2[] = "__ldcg_r2x2_";
+static const char __ldcg_r4x4[] = "__ldcg_r4x4_";
+static const char __ldcg_r8x2[] = "__ldcg_r8x2_";
+static const char __ldcs_i4x4[] = "__ldcs_i4x4_";
+static const char __ldcs_i8x2[] = "__ldcs_i8x2_";
+static const char __ldcs_r2x2[] = "__ldcs_r2x2_";
+static const char __ldcs_r4x4[] = "__ldcs_r4x4_";
+static const char __ldcs_r8x2[] = "__ldcs_r8x2_";
+static const char __ldcv_i4x4[] = "__ldcv_i4x4_";
+static const char __ldcv_i8x2[] = "__ldcv_i8x2_";
+static const char __ldcv_r2x2[] = "__ldcv_r2x2_";
+static const char __ldcv_r4x4[] = "__ldcv_r4x4_";
+static const char __ldcv_r8x2[] = "__ldcv_r8x2_";
+static const char __ldlu_i4x4[] = "__ldlu_i4x4_";
+static const char __ldlu_i8x2[] = "__ldlu_i8x2_";
+static const char __ldlu_r2x2[] = "__ldlu_r2x2_";
+static const char __ldlu_r4x4[] = "__ldlu_r4x4_";
+static const char __ldlu_r8x2[] = "__ldlu_r8x2_";
+
 /// Table that drives the fir generation depending on the intrinsic or intrinsic
 /// module procedure one to one mapping with Fortran arguments. If no mapping is
 /// defined here for a generic intrinsic, genRuntimeCall will be called
@@ -114,6 +142,106 @@ static constexpr bool handleDynamicOptional = true;
 /// argument must not be lowered by value. In which case, the lowering rules
 /// should be provided for all the intrinsic arguments for completeness.
 static constexpr IntrinsicHandler handlers[]{
+    {"__ldca_i4x4",
+     &I::genCUDALDXXFunc<__ldca_i4x4, 4>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldca_i8x2",
+     &I::genCUDALDXXFunc<__ldca_i8x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldca_r2x2",
+     &I::genCUDALDXXFunc<__ldca_r2x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldca_r4x4",
+     &I::genCUDALDXXFunc<__ldca_r4x4, 4>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldca_r8x2",
+     &I::genCUDALDXXFunc<__ldca_r8x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcg_i4x4",
+     &I::genCUDALDXXFunc<__ldcg_i4x4, 4>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcg_i8x2",
+     &I::genCUDALDXXFunc<__ldcg_i8x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcg_r2x2",
+     &I::genCUDALDXXFunc<__ldcg_r2x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcg_r4x4",
+     &I::genCUDALDXXFunc<__ldcg_r4x4, 4>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcg_r8x2",
+     &I::genCUDALDXXFunc<__ldcg_r8x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcs_i4x4",
+     &I::genCUDALDXXFunc<__ldcs_i4x4, 4>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcs_i8x2",
+     &I::genCUDALDXXFunc<__ldcs_i8x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcs_r2x2",
+     &I::genCUDALDXXFunc<__ldcs_r2x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcs_r4x4",
+     &I::genCUDALDXXFunc<__ldcs_r4x4, 4>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcs_r8x2",
+     &I::genCUDALDXXFunc<__ldcs_r8x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcv_i4x4",
+     &I::genCUDALDXXFunc<__ldcv_i4x4, 4>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcv_i8x2",
+     &I::genCUDALDXXFunc<__ldcv_i8x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcv_r2x2",
+     &I::genCUDALDXXFunc<__ldcv_r2x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcv_r4x4",
+     &I::genCUDALDXXFunc<__ldcv_r4x4, 4>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldcv_r8x2",
+     &I::genCUDALDXXFunc<__ldcv_r8x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldlu_i4x4",
+     &I::genCUDALDXXFunc<__ldlu_i4x4, 4>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldlu_i8x2",
+     &I::genCUDALDXXFunc<__ldlu_i8x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldlu_r2x2",
+     &I::genCUDALDXXFunc<__ldlu_r2x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldlu_r4x4",
+     &I::genCUDALDXXFunc<__ldlu_r4x4, 4>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
+    {"__ldlu_r8x2",
+     &I::genCUDALDXXFunc<__ldlu_r8x2, 2>,
+     {{{"a", asAddr}}},
+     /*isElemental=*/false},
     {"abort", &I::genAbort},
     {"abs", &I::genAbs},
     {"achar", &I::genChar},
@@ -3544,6 +3672,41 @@ IntrinsicLibrary::genCshift(mlir::Type resultType,
   return readAndAddCleanUp(resultMutableBox, resultType, "CSHIFT");
 }
 
+// __LDCA, __LDCG, __LDCS, __LDLU, __LDCV
+//
+// Lowers an intrinsic such as __ldca_i4x4 to a call to the external function
+// named by `fctName`. A local temporary of `extent` elements of `resultType`
+// is allocated and passed by reference as the first argument; the address of
+// the source argument (unboxed if it is a descriptor) is passed second. The
+// temporary is returned as a rank-1 array value of extent `extent`.
+template <const char *fctName, int extent>
+fir::ExtendedValue
+IntrinsicLibrary::genCUDALDXXFunc(mlir::Type resultType,
+                                  llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 1);
+  mlir::Type resTy = fir::SequenceType::get(extent, resultType);
+  mlir::Value arg = fir::getBase(args[0]);
+  mlir::Value res = builder.create<fir::AllocaOp>(loc, resTy);
+  if (mlir::isa<fir::BaseBoxType>(arg.getType()))
+    arg = builder.create<fir::BoxAddrOp>(loc, arg);
+  // Both operands are memory references (fir.alloca yields a !fir.ref), so the
+  // callee must be declared with reference parameter types. Declaring it with
+  // the array value type would make the declaration disagree with the call.
+  // NOTE(review): the source operand keeps its own type (e.g.
+  // !fir.ref<!fir.array<?xi32>> in the tests); no fir.convert is inserted.
+  mlir::Type refResTy = fir::ReferenceType::get(resTy);
+  mlir::FunctionType ftype =
+      mlir::FunctionType::get(arg.getContext(), {refResTy, refResTy}, {});
+  auto funcOp = builder.createFunction(loc, fctName, ftype);
+  llvm::SmallVector<mlir::Value> funcArgs;
+  funcArgs.push_back(res);
+  funcArgs.push_back(arg);
+  builder.create<fir::CallOp>(loc, funcOp, funcArgs);
+  mlir::Value ext =
+      builder.createIntegerConstant(loc, builder.getIndexType(), extent);
+  return fir::ArrayBoxValue(res, {ext});
+}
+
 // DATE_AND_TIME
 void IntrinsicLibrary::genDateAndTime(llvm::ArrayRef<fir::ExtendedValue> args) {
   assert(args.size() == 4 && "date_and_time has 4 args");
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 5f39f78f8ecae..02c94235a354f 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -210,10 +210,11 @@ attributes(global) subroutine __ldXXi4(b)
 end
 
 ! CHECK-LABEL: func.func @_QP__ldxxi4
-! CHECK: __ldca_i4x4
-! CHECK: __ldcg_i4x4
-! CHECK: __ldcs_i4x4
-! CHECK: __ldlu_i4x4
+! CHECK: fir.call @__ldca_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldcg_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldcs_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldlu_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
+! CHECK: fir.call @__ldcv_i4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<?xi32>>) -> ()
 
 attributes(global) subroutine __ldXXi8(b)
   integer(8), device :: b(*)
@@ -226,10 +227,11 @@ attributes(global) subroutine __ldXXi8(b)
 end
 
 ! CHECK-LABEL: func.func @_QP__ldxxi8
-! CHECK: __ldca_i8x2
-! CHECK: __ldcg_i8x2
-! CHECK: __ldcs_i8x2
-! CHECK: __ldlu_i8x2
+! CHECK: fir.call @__ldca_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldcg_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldcs_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldlu_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
+! CHECK: fir.call @__ldcv_i8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xi64>>, !fir.ref<!fir.array<?xi64>>) -> ()
 
 attributes(global) subroutine __ldXXr4(b)
   real, device :: b(*)
@@ -242,10 +244,11 @@ attributes(global) subroutine __ldXXr4(b)
 end
 
 ! CHECK-LABEL: func.func @_QP__ldxxr4
-! CHECK: __ldca_r4x4
-! CHECK: __ldcg_r4x4
-! CHECK: __ldcs_r4x4
-! CHECK: __ldlu_r4x4
+! CHECK: fir.call @__ldca_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldcg_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldcs_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldlu_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
+! CHECK: fir.call @__ldcv_r4x4_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<?xf32>>) -> ()
 
 attributes(global) subroutine __ldXXr2(b)
   real(2), device :: b(*)
@@ -258,10 +261,11 @@ attributes(global) subroutine __ldXXr2(b)
 end
 
 ! CHECK-LABEL: func.func @_QP__ldxxr2
-! CHECK: __ldca_r2x2
-! CHECK: __ldcg_r2x2
-! CHECK: __ldcs_r2x2
-! CHECK: __ldlu_r2x2
+! CHECK: fir.call @__ldca_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldcg_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldcs_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldlu_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
+! CHECK: fir.call @__ldcv_r2x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf16>>, !fir.ref<!fir.array<?xf16>>) -> ()
 
 attributes(global) subroutine __ldXXr8(b)
   real(8), device :: b(*)
@@ -274,7 +278,8 @@ attributes(global) subroutine __ldXXr8(b)
 end
 
 ! CHECK-LABEL: func.func @_QP__ldxxr8
-! CHECK: __ldca_r8x2
-! CHECK: __ldcg_r8x2
-! CHECK: __ldcs_r8x2
-! CHECK: __ldlu_r8x2
+! CHECK: fir.call @__ldca_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldcg_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldcs_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldlu_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! CHECK: fir.call @__ldcv_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()

>From 74399dbb0ff95d69b67e4a258513c06575df014a Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 7 Mar 2025 14:18:42 -0800
Subject: [PATCH 2/2] Fix typo

---
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index bc3c6fcdd853d..722495e63302e 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -106,7 +106,7 @@ using I = IntrinsicLibrary;
 /// argument is an optional variable in the current scope).
 static constexpr bool handleDynamicOptional = true;
 
-/// TODO: Move all CUDA Fortran intrinsic hanlders into its own file similar to
+/// TODO: Move all CUDA Fortran intrinsic handlers into its own file similar to
 /// PPC.
 static const char __ldca_i4x4[] = "__ldca_i4x4_";
 static const char __ldca_i8x2[] = "__ldca_i8x2_";



More information about the flang-commits mailing list