[flang-commits] [flang] [flang][cuda][NFC] Use NVVM op for match all (PR #134303)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Apr 3 13:28:51 PDT 2025


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/134303

None

From 4b7fb352f69106ac2ceb628e426499e3c9677fe1 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 3 Apr 2025 13:27:35 -0700
Subject: [PATCH] [flang][cuda][NFC] Use NVVM op for match all

---
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 38 ++++++++-----------
 flang/test/Lower/CUDA/cuda-device-proc.cuf    | 10 ++---
 2 files changed, 19 insertions(+), 29 deletions(-)

diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index a562d9b7e461c..349345c1a2ca0 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -6478,31 +6478,23 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
   assert(args.size() == 3);
   bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
 
-  llvm::StringRef funcName =
-      is32 ? "llvm.nvvm.match.all.sync.i32p" : "llvm.nvvm.match.all.sync.i64p";
-  mlir::MLIRContext *context = builder.getContext();
-  mlir::Type i32Ty = builder.getI32Type();
-  mlir::Type i64Ty = builder.getI64Type();
   mlir::Type i1Ty = builder.getI1Type();
-  mlir::Type retTy = mlir::TupleType::get(context, {resultType, i1Ty});
-  mlir::Type valTy = is32 ? i32Ty : i64Ty;
+  mlir::MLIRContext *context = builder.getContext();
 
-  mlir::FunctionType ftype =
-      mlir::FunctionType::get(context, {i32Ty, valTy}, {retTy});
-  auto funcOp = builder.createFunction(loc, funcName, ftype);
-  llvm::SmallVector<mlir::Value> filteredArgs;
-  filteredArgs.push_back(args[0]);
-  if (args[1].getType().isF32() || args[1].getType().isF64())
-    filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
-  else
-    filteredArgs.push_back(args[1]);
-  auto call = builder.create<fir::CallOp>(loc, funcOp, filteredArgs);
-  auto zero = builder.getIntegerAttr(builder.getIndexType(), 0);
-  auto value = builder.create<fir::ExtractValueOp>(
-      loc, resultType, call.getResult(0), builder.getArrayAttr(zero));
-  auto one = builder.getIntegerAttr(builder.getIndexType(), 1);
-  auto pred = builder.create<fir::ExtractValueOp>(loc, i1Ty, call.getResult(0),
-                                                  builder.getArrayAttr(one));
+  mlir::Value arg1 = args[1];
+  if (arg1.getType().isF32() || arg1.getType().isF64())
+    arg1 = builder.create<fir::ConvertOp>(
+        loc, is32 ? builder.getI32Type() : builder.getI64Type(), arg1);
+
+  mlir::Type retTy =
+      mlir::LLVM::LLVMStructType::getLiteral(context, {resultType, i1Ty});
+  auto match =
+      builder
+          .create<mlir::NVVM::MatchSyncOp>(loc, retTy, args[0], arg1,
+                                           mlir::NVVM::MatchSyncKind::all)
+          .getResult();
+  auto value = builder.create<mlir::LLVM::ExtractValueOp>(loc, match, 0);
+  auto pred = builder.create<mlir::LLVM::ExtractValueOp>(loc, match, 1);
   auto conv = builder.create<mlir::LLVM::ZExtOp>(loc, resultType, pred);
   builder.create<fir::StoreOp>(loc, conv, args[2]);
   return value;
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index dbce4a5fa47dd..016d3bd1f1511 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -124,12 +124,10 @@ attributes(device) subroutine testMatch()
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtestmatch()
-! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
-! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
-! CHECK: fir.convert %{{.*}} : (f32) -> i32
-! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
-! CHECK: fir.convert %{{.*}} : (f64) -> i64
-! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
+! CHECK: %{{.*}} = nvvm.match.sync  all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
+! CHECK: %{{.*}} = nvvm.match.sync  all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
+! CHECK: %{{.*}} = nvvm.match.sync  all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
+! CHECK: %{{.*}} = nvvm.match.sync  all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
 
 attributes(device) subroutine testMatchAny()
   integer :: a, mask, v32



More information about the flang-commits mailing list