[flang-commits] [flang] [flang][cuda][NFC] Use NVVM op for match all (PR #134303)
via flang-commits
flang-commits at lists.llvm.org
Thu Apr 3 13:29:27 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/134303.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+15-23)
- (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+4-6)
``````````diff
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index a562d9b7e461c..349345c1a2ca0 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -6478,31 +6478,23 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
assert(args.size() == 3);
bool is32 = args[1].getType().isInteger(32) || args[1].getType().isF32();
- llvm::StringRef funcName =
- is32 ? "llvm.nvvm.match.all.sync.i32p" : "llvm.nvvm.match.all.sync.i64p";
- mlir::MLIRContext *context = builder.getContext();
- mlir::Type i32Ty = builder.getI32Type();
- mlir::Type i64Ty = builder.getI64Type();
mlir::Type i1Ty = builder.getI1Type();
- mlir::Type retTy = mlir::TupleType::get(context, {resultType, i1Ty});
- mlir::Type valTy = is32 ? i32Ty : i64Ty;
+ mlir::MLIRContext *context = builder.getContext();
- mlir::FunctionType ftype =
- mlir::FunctionType::get(context, {i32Ty, valTy}, {retTy});
- auto funcOp = builder.createFunction(loc, funcName, ftype);
- llvm::SmallVector<mlir::Value> filteredArgs;
- filteredArgs.push_back(args[0]);
- if (args[1].getType().isF32() || args[1].getType().isF64())
- filteredArgs.push_back(builder.create<fir::ConvertOp>(loc, valTy, args[1]));
- else
- filteredArgs.push_back(args[1]);
- auto call = builder.create<fir::CallOp>(loc, funcOp, filteredArgs);
- auto zero = builder.getIntegerAttr(builder.getIndexType(), 0);
- auto value = builder.create<fir::ExtractValueOp>(
- loc, resultType, call.getResult(0), builder.getArrayAttr(zero));
- auto one = builder.getIntegerAttr(builder.getIndexType(), 1);
- auto pred = builder.create<fir::ExtractValueOp>(loc, i1Ty, call.getResult(0),
- builder.getArrayAttr(one));
+ mlir::Value arg1 = args[1];
+ if (arg1.getType().isF32() || arg1.getType().isF64())
+ arg1 = builder.create<fir::ConvertOp>(
+ loc, is32 ? builder.getI32Type() : builder.getI64Type(), arg1);
+
+ mlir::Type retTy =
+ mlir::LLVM::LLVMStructType::getLiteral(context, {resultType, i1Ty});
+ auto match =
+ builder
+ .create<mlir::NVVM::MatchSyncOp>(loc, retTy, args[0], arg1,
+ mlir::NVVM::MatchSyncKind::all)
+ .getResult();
+ auto value = builder.create<mlir::LLVM::ExtractValueOp>(loc, match, 0);
+ auto pred = builder.create<mlir::LLVM::ExtractValueOp>(loc, match, 1);
auto conv = builder.create<mlir::LLVM::ZExtOp>(loc, resultType, pred);
builder.create<fir::StoreOp>(loc, conv, args[2]);
return value;
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index dbce4a5fa47dd..016d3bd1f1511 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -124,12 +124,10 @@ attributes(device) subroutine testMatch()
end subroutine
! CHECK-LABEL: func.func @_QPtestmatch()
-! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
-! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
-! CHECK: fir.convert %{{.*}} : (f32) -> i32
-! CHECK: fir.call @llvm.nvvm.match.all.sync.i32p
-! CHECK: fir.convert %{{.*}} : (f64) -> i64
-! CHECK: fir.call @llvm.nvvm.match.all.sync.i64p
+! CHECK: %{{.*}} = nvvm.match.sync all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
+! CHECK: %{{.*}} = nvvm.match.sync all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
+! CHECK: %{{.*}} = nvvm.match.sync all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
+! CHECK: %{{.*}} = nvvm.match.sync all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
attributes(device) subroutine testMatchAny()
integer :: a, mask, v32
``````````
</details>
https://github.com/llvm/llvm-project/pull/134303
More information about the flang-commits
mailing list