[Mlir-commits] [mlir] a953982 - [mlir][GPU] Plumb range information through the NVVM lowerings (#107659)

llvmlistbot at llvm.org
Fri Sep 13 10:07:54 PDT 2024


Author: Krzysztof Drewniak
Date: 2024-09-13T12:07:51-05:00
New Revision: a953982cb7dee0678bb5f7c2febe4c3b8b718c7a

URL: https://github.com/llvm/llvm-project/commit/a953982cb7dee0678bb5f7c2febe4c3b8b718c7a
DIFF: https://github.com/llvm/llvm-project/commit/a953982cb7dee0678bb5f7c2febe4c3b8b718c7a.diff

LOG: [mlir][GPU] Plumb range information through the NVVM lowerings (#107659)

Update the GPU to NVVM lowerings to correctly propagate range
information on IDs and dimension queries, either from
known_{block,grid}_size attributes or from `upperBound` annotations on
the operations themselves.
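
For example, given a kernel annotated with known_block_size, the thread ID
lowering now emits NVVM ops that carry range attributes (a minimal sketch;
the function name is illustrative, the syntax matches the updated tests below):

    gpu.func @example() kernel attributes {known_block_size = array<i32: 32, 4, 2>} {
      // Lowers to: nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
      %tid_x = gpu.thread_id x
      gpu.return
    }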

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
    mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
    mlir/test/Target/LLVMIR/Import/nvvmir.ll
    mlir/test/Target/LLVMIR/nvvmir.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7bbf18fe0106fb..152715f281088e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -123,52 +123,67 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_SpecialRegisterOp<mnemonic, traits> {
+  let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
+  let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
+  let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
+  let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda;
+
+  // Backwards-compatibility builder for an unspecified range.
+  let builders = [
+    OpBuilder<(ins "Type":$resultType), [{
+      build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
+    }]>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Lane index and range
-def NVVM_LaneIdOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.laneid">;
-def NVVM_WarpSizeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.warpsize">;
+def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
+def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
 
 //===----------------------------------------------------------------------===//
 // Thread index and range
-def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">;
-def NVVM_ThreadIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.y">;
-def NVVM_ThreadIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.z">;
-def NVVM_BlockDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.x">;
-def NVVM_BlockDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.y">;
-def NVVM_BlockDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.z">;
+def NVVM_ThreadIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">;
+def NVVM_ThreadIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">;
+def NVVM_ThreadIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">;
+def NVVM_BlockDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">;
+def NVVM_BlockDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">;
+def NVVM_BlockDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">;
 
 //===----------------------------------------------------------------------===//
 // Block index and range
-def NVVM_BlockIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.x">;
-def NVVM_BlockIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.y">;
-def NVVM_BlockIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.z">;
-def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
-def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
-def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
+def NVVM_BlockIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">;
+def NVVM_BlockIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">;
+def NVVM_BlockIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">;
+def NVVM_GridDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">;
+def NVVM_GridDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">;
+def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
 
 //===----------------------------------------------------------------------===//
 // CTA Cluster index and range
-def NVVM_ClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.x">;
-def NVVM_ClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.y">;
-def NVVM_ClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.z">;
-def NVVM_ClusterDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.x">;
-def NVVM_ClusterDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.y">;
-def NVVM_ClusterDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.z">;
+def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
+def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
+def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
+def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
+def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">;
+def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">;
 
 
 //===----------------------------------------------------------------------===//
 // CTA index and range within Cluster
-def NVVM_BlockInClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
-def NVVM_BlockInClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
-def NVVM_BlockInClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
-def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
-def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
-def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
+def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
+def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
+def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
+def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
+def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
+def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
 
 //===----------------------------------------------------------------------===//
 // CTA index and across Cluster dimensions
-def NVVM_ClusterId : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctarank">;
-def NVVM_ClusterDim : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctarank">;
+def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">;
+def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
 
 //===----------------------------------------------------------------------===//
 // Clock registers

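The new NVVM_SpecialRangeableRegisterOp base class above gives each of these
special-register reads an optional `range` attribute in its assembly format.
A short sketch of both forms (the bound 64 is taken from the translation test
further down; any [lower, upper) pair is accepted):

    %tid  = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
    // The pre-existing, rangeless form still round-trips unchanged:
    %lane = nvvm.read.ptx.sreg.laneid : i32
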
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 9b1be198f77a82..164622d77e6b62 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -29,6 +29,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -209,7 +210,15 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
     MLIRContext *context = rewriter.getContext();
-    Value newOp = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
+    LLVM::ConstantRangeAttr bounds = nullptr;
+    if (std::optional<APInt> upperBound = op.getUpperBound())
+      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
+          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
+    else
+      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
+          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
+    Value newOp =
+        rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
     // Truncate or extend the result depending on the index bitwidth specified
     // by the LLVMTypeConverter options.
     const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
@@ -340,27 +349,40 @@ void mlir::populateGpuSubgroupReduceOpLoweringPattern(
 
 void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns) {
+  using gpu::index_lowering::IndexKind;
+  using gpu::index_lowering::IntrType;
   populateWithGenerated(patterns);
   patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
   patterns.add<
       gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
-                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
+                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
+      converter, IndexKind::Block, IntrType::Id);
+  patterns.add<
       gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
-                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
+                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
+      converter, IndexKind::Block, IntrType::Dim);
+  patterns.add<
       gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
-                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
-      gpu::index_lowering::OpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
-                                      NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
-      gpu::index_lowering::OpLowering<
-          gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
-          NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>,
-      gpu::index_lowering::OpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
-                                      NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
-      gpu::index_lowering::OpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
-                                      NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
-      gpu::index_lowering::OpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
-                                      NVVM::GridDimYOp, NVVM::GridDimZOp>,
-      GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
+                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
+      converter, IndexKind::Other, IntrType::Id);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
+      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
+      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
+      converter, IndexKind::Other, IntrType::Id);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
+      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
+      converter, IndexKind::Block, IntrType::Id);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
+      converter, IndexKind::Grid, IntrType::Dim);
+  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
+      converter);
 
   patterns.add<GPUDynamicSharedMemoryOpLowering>(
       converter, NVVM::kSharedMemoryAlignmentBit);

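As a consequence of the GPULaneIdOpToNVVM pattern above, gpu.lane_id picks up
a range even when no explicit bound is given, defaulting to [0, kWarpSize).
Both forms appear in the updated tests:

    %0 = gpu.lane_id                 // -> nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
    %1 = gpu.lane_id upper_bound 16  // -> nvvm.read.ptx.sreg.laneid range <i32, 0, 16> : i32
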
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
index 855abc12a909ef..bc830a77f3c580 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Target/LLVMIR/ModuleImport.h"
 
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 
 using namespace mlir;

diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 8f2ec289c9252c..66ad1e307fc3a5 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -50,7 +50,7 @@ gpu.module @test_module_0 {
     %gDimZ = gpu.grid_dim z
 
 
-    // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+    // CHECK: = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
     // CHECK: = llvm.sext %{{.*}} : i32 to i64
     %laneId = gpu.lane_id
 
@@ -699,9 +699,21 @@ gpu.module @test_module_32 {
 }
 
 gpu.module @test_module_33 {
-// CHECK-LABEL: func @kernel_with_block_size()
-// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>}
-  gpu.func @kernel_with_block_size() kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
+// CHECK-LABEL: func @kernel_with_block_size(
+// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 32, 4, 2>, nvvm.kernel, nvvm.maxntid = array<i32: 32, 4, 2>}
+  gpu.func @kernel_with_block_size(%arg0: !llvm.ptr) kernel attributes {known_block_size = array<i32: 32, 4, 2>} {
+    // CHECK: = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
+    %0 = gpu.thread_id x
+    // CHECK: = nvvm.read.ptx.sreg.tid.y range <i32, 0, 4> : i32
+    %1 = gpu.thread_id y
+    // CHECK: = nvvm.read.ptx.sreg.tid.z range <i32, 0, 2> : i32
+    %2 = gpu.thread_id z
+
+    // Fake usage to prevent dead code elimination
+    %3 = arith.addi %0, %1 : index
+    %4 = arith.addi %3, %2 : index
+    %5 = arith.index_cast %4 : index to i64
+    llvm.store %5, %arg0 : i64, !llvm.ptr
     gpu.return
   }
 }
@@ -917,6 +929,20 @@ gpu.module @test_module_48 {
   }
 }
 
+gpu.module @test_module_49 {
+// CHECK-LABEL: func @explicit_id_bounds()
+  func.func @explicit_id_bounds() -> (index, index, index) {
+    // CHECK: = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
+    %0 = gpu.thread_id x upper_bound 32
+    // CHECK: = nvvm.read.ptx.sreg.ntid.x range <i32, 1, 33> : i32
+    %1 = gpu.block_dim x upper_bound 32
+    // CHECK: = nvvm.read.ptx.sreg.laneid range <i32, 0, 16> : i32
+    %2 = gpu.lane_id upper_bound 16
+
+    return %0, %1, %2 : index, index, index
+  }
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
     %gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module

diff --git a/mlir/test/Target/LLVMIR/Import/nvvmir.ll b/mlir/test/Target/LLVMIR/Import/nvvmir.ll
index e4a8773e2dd806..131e9065b2d883 100644
--- a/mlir/test/Target/LLVMIR/Import/nvvmir.ll
+++ b/mlir/test/Target/LLVMIR/Import/nvvmir.ll
@@ -58,6 +58,9 @@ define i32 @nvvm_special_regs() {
   %27 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctarank()
   ; CHECK: = nvvm.read.ptx.sreg.cluster.nctarank : i32
   %28 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctarank()
+
+  ; CHECK: = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
+  %29 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
   ret i32 %1
 }
 

diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 88ffb1c7bfdf7a..7fd082a5eb3c75 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -62,7 +62,10 @@ llvm.func @nvvm_special_regs() -> i32 {
   %29 = nvvm.read.ptx.sreg.clock : i32
   // CHECK: call i64 @llvm.nvvm.read.ptx.sreg.clock64
   %30 = nvvm.read.ptx.sreg.clock64 : i64
-  
+
+  // CHECK: %31 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %31 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
+
   llvm.return %1 : i32
 }
 
