[Mlir-commits] [mlir] af7aa22 - [MLIR][GPU] Lower subgroup query ops in gpu-to-llvm-spv (#108839)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 26 06:52:17 PDT 2024


Author: Finlay
Date: 2024-09-26T14:52:12+01:00
New Revision: af7aa223d27996b129a2d1a0a4540f270c9a1e03

URL: https://github.com/llvm/llvm-project/commit/af7aa223d27996b129a2d1a0a4540f270c9a1e03
DIFF: https://github.com/llvm/llvm-project/commit/af7aa223d27996b129a2d1a0a4540f270c9a1e03.diff

LOG: [MLIR][GPU] Lower subgroup query ops in gpu-to-llvm-spv (#108839)

These ops are:
* gpu.subgroup_id
* gpu.lane_id
* gpu.num_subgroups
* gpu.subgroup_size

---------

Signed-off-by: Finlay Marno <finlay.marno at codeplay.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
    mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 41a3ac76df4b78..739a34e0aa610e 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -316,6 +316,53 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Subgroup query ops.
+//===----------------------------------------------------------------------===//
+
+/// Lowers a GPU subgroup query op to a call to the matching SPIR-V builtin.
+/// The builtin returns i32, which is zero-extended if the index type differs.
+template <typename SubgroupOp>
+struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
+  using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
+  using ConvertToLLVMPattern::getTypeConverter;
+
+  LogicalResult
+  matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    constexpr StringRef funcName = [] {
+      if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
+        return "_Z16get_sub_group_id";
+      } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
+        return "_Z22get_sub_group_local_id";
+      } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
+        return "_Z18get_num_sub_groups";
+      } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
+        return "_Z18get_sub_group_size";
+      }
+    }();
+
+    // Reject an index type narrower than i32 before touching any IR, so a
+    // failed match does not leave a builtin declaration or call behind.
+    Type resultTy = rewriter.getI32Type();
+    Type indexTy = getTypeConverter()->getIndexType();
+    if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth())
+      return failure();
+
+    Operation *moduleOp =
+        op->template getParentWithTrait<OpTrait::SymbolTable>();
+    LLVM::LLVMFuncOp func =
+        lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultTy,
+                              /*isMemNone=*/false, /*isConvergent=*/false);
+    Location loc = op->getLoc();
+    Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();
+    if (resultTy != indexTy)
+      result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // GPU To LLVM-SPV Pass.
 //===----------------------------------------------------------------------===//
@@ -337,7 +384,9 @@ struct GPUToLLVMSPVConversionPass final
 
     target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
                         gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
-                        gpu::ReturnOp, gpu::ShuffleOp, gpu::ThreadIdOp>();
+                        gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
+                        gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
+                        gpu::ThreadIdOp>();
 
     populateGpuToLLVMSPVConversionPatterns(converter, patterns);
     populateGpuMemorySpaceAttributeConversions(converter);
@@ -366,11 +415,15 @@ gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
 void populateGpuToLLVMSPVConversionPatterns(LLVMTypeConverter &typeConverter,
                                             RewritePatternSet &patterns) {
   patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
+               GPUSubgroupOpConversion<gpu::LaneIdOp>,
+               GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
+               GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
+               GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
+               LaunchConfigOpConversion<gpu::BlockDimOp>,
                LaunchConfigOpConversion<gpu::BlockIdOp>,
+               LaunchConfigOpConversion<gpu::GlobalIdOp>,
                LaunchConfigOpConversion<gpu::GridDimOp>,
-               LaunchConfigOpConversion<gpu::BlockDimOp>,
-               LaunchConfigOpConversion<gpu::ThreadIdOp>,
-               LaunchConfigOpConversion<gpu::GlobalIdOp>>(typeConverter);
+               LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
   MLIRContext *context = &typeConverter.getContext();
   unsigned privateAddressSpace =
       gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);

diff  --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index 860bb60726352d..910105ddf69586 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -563,3 +563,36 @@ gpu.module @kernels {
     gpu.return
   }
 }
+
+// -----
+
+// Lowering of subgroup query operations
+
+// CHECK-DAG: llvm.func spir_funccc @_Z18get_sub_group_size() -> i32 attributes {no_unwind, will_return}
+// CHECK-DAG: llvm.func spir_funccc @_Z18get_num_sub_groups() -> i32 attributes {no_unwind, will_return}
+// CHECK-DAG: llvm.func spir_funccc @_Z22get_sub_group_local_id() -> i32 attributes {no_unwind, will_return}
+// CHECK-DAG: llvm.func spir_funccc @_Z16get_sub_group_id() -> i32 attributes {no_unwind, will_return}
+
+
+gpu.module @subgroup_operations {
+// CHECK-LABEL: @gpu_subgroup
+  func.func @gpu_subgroup() {
+    // CHECK:       %[[SG_ID:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
+    // CHECK-32-NOT:                llvm.zext
+    // CHECK-64:          %{{.*}} = llvm.zext %[[SG_ID]] : i32 to i64
+    %0 = gpu.subgroup_id : index
+    // CHECK: %[[SG_LOCAL_ID:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32
+    // CHECK-32-NOT:                llvm.zext
+    // CHECK-64:          %{{.*}} = llvm.zext %[[SG_LOCAL_ID]] : i32 to i64
+    %1 = gpu.lane_id
+    // CHECK:     %[[NUM_SGS:.*]] = llvm.call spir_funccc @_Z18get_num_sub_groups() {no_unwind, will_return} : () -> i32
+    // CHECK-32-NOT:                llvm.zext
+    // CHECK-64:          %{{.*}} = llvm.zext %[[NUM_SGS]] : i32 to i64
+    %2 = gpu.num_subgroups : index
+    // CHECK:     %[[SG_SIZE:.*]] = llvm.call spir_funccc @_Z18get_sub_group_size() {no_unwind, will_return} : () -> i32
+    // CHECK-32-NOT:                llvm.zext
+    // CHECK-64:          %{{.*}} = llvm.zext %[[SG_SIZE]] : i32 to i64
+    %3 = gpu.subgroup_size : index
+    return
+  }
+}


        


More information about the Mlir-commits mailing list