[Mlir-commits] [mlir] [MLIR][GPU] Lower subgroup query ops in gpu-to-llvm-spv (PR #108839)

Victor Perez llvmlistbot at llvm.org
Tue Sep 17 00:50:25 PDT 2024


================
@@ -316,6 +316,43 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Subgroup query ops.
+//===----------------------------------------------------------------------===//
+
+template <typename SubgroupOp>
+struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
+  using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    constexpr StringRef funcName = [] {
+      if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
+        return "_Z16get_sub_group_id";
+      } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
+        return "_Z22get_sub_group_local_id";
+      } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
+        return "_Z18get_num_sub_groups";
+      } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
+        return "_Z18get_sub_group_size";
+      }
+    }();
+
+    Operation *moduleOp =
+        op->template getParentWithTrait<OpTrait::SymbolTable>();
+    Type resultType = rewriter.getI32Type();
+    LLVM::LLVMFuncOp func =
+        lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultType,
+                              /*isMemNone=*/false, /*isConvergent=*/false);
+
+    Location loc = op->getLoc();
+    Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();
+    rewriter.replaceOp(op, result);
----------------
victor-eds wrote:

We are missing a cast from `i32` to `i64` (the index type) when the index bitwidth is not 32.

https://github.com/llvm/llvm-project/pull/108839


More information about the Mlir-commits mailing list