[Mlir-commits] [mlir] [MLIR][GPU] Lower subgroup query ops in gpu-to-llvm-spv (PR #108839)
Victor Perez
llvmlistbot at llvm.org
Tue Sep 17 00:50:25 PDT 2024
================
@@ -316,6 +316,43 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
}
};
+//===----------------------------------------------------------------------===//
+// Subgroup query ops.
+//===----------------------------------------------------------------------===//
+
+template <typename SubgroupOp>
+struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
+ using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ constexpr StringRef funcName = [] {
+ if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
+ return "_Z16get_sub_group_id";
+ } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
+ return "_Z22get_sub_group_local_id";
+ } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
+ return "_Z18get_num_sub_groups";
+ } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
+ return "_Z18get_sub_group_size";
+ }
+ }();
+
+ Operation *moduleOp =
+ op->template getParentWithTrait<OpTrait::SymbolTable>();
+ Type resultType = rewriter.getI32Type();
+ LLVM::LLVMFuncOp func =
+ lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultType,
+ /*isMemNone=*/false, /*isConvergent=*/false);
+
+ Location loc = op->getLoc();
+ Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();
+ rewriter.replaceOp(op, result);
----------------
victor-eds wrote:
We are missing a cast from `i32` to the index type (e.g. `i64`) for when the index bitwidth is not 32.
https://github.com/llvm/llvm-project/pull/108839
More information about the Mlir-commits
mailing list