[Mlir-commits] [mlir] [mlir][GPU] Improve `gpu.module` op implementation (PR #102866)

Matthias Springer llvmlistbot at llvm.org
Mon Aug 12 02:17:01 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/102866

>From 5bac6a0aa162f788c6137f4ec4747f66e8d354f5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 12 Aug 2024 11:05:48 +0200
Subject: [PATCH] [mlir][GPU] Improve `gpu.module` op implementation

- Replace hand-written parser/printer with auto-generated assembly format.
- Remove implicit `gpu.module_end` terminator and use the `NoTerminator` trait instead. (Same as `builtin.module`.)
- Turn the region into a graph region. (Same as `builtin.module`.)
---
 mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h |  1 +
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    | 31 ++++----
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        |  2 +-
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      |  2 +-
 mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 15 +---
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 70 +------------------
 .../Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir |  2 +-
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     |  2 +-
 8 files changed, 20 insertions(+), 105 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 57acd72610415f..7b53594a1c8e28 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -22,6 +22,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index c57d291552e606..a024c3018eb8d3 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -21,6 +21,7 @@ include "mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td"
 include "mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td"
 include "mlir/IR/CommonTypeConstraints.td"
 include "mlir/IR/EnumAttr.td"
+include "mlir/IR/RegionKindInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
@@ -1347,10 +1348,7 @@ def GPU_BarrierOp : GPU_Op<"barrier"> {
 
 def GPU_GPUModuleOp : GPU_Op<"module", [
       DataLayoutOpInterface, HasDefaultDLTIDataLayout, IsolatedFromAbove,
-      SymbolTable, Symbol, SingleBlockImplicitTerminator<"ModuleEndOp">
-    ]>, Arguments<(ins SymbolNameAttr:$sym_name,
-          OptionalAttr<GPUNonEmptyTargetArrayAttr>:$targets,
-          OptionalAttr<OffloadingTranslationAttr>:$offloadingHandler)> {
+      NoRegionArguments, SymbolTable, Symbol] # GraphRegionNoTerminator.traits> {
   let summary = "A top level compilation unit containing code to be run on a GPU.";
   let description = [{
     GPU module contains code that is intended to be run on a GPU. A host device
@@ -1379,7 +1377,6 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
     gpu.module @symbol_name {
       gpu.func {}
         ...
-      gpu.module_end
     }
     // Module with offloading handler and target attributes.
     gpu.module @symbol_name2 <#gpu.select_object<1>> [
@@ -1387,7 +1384,6 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
         #rocdl.target<chip = "gfx90a">] {
       gpu.func {}
         ...
-      gpu.module_end
     }
     ```
   }];
@@ -1399,8 +1395,18 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
                    "ArrayRef<Attribute>":$targets,
                    CArg<"Attribute", "{}">:$handler)>
   ];
+
+  let arguments = (ins
+      SymbolNameAttr:$sym_name,
+      OptionalAttr<GPUNonEmptyTargetArrayAttr>:$targets,
+      OptionalAttr<OffloadingTranslationAttr>:$offloadingHandler);
   let regions = (region SizedRegion<1>:$bodyRegion);
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = [{
+    $sym_name
+    (`<` $offloadingHandler^ `>`)?
+    ($targets^)?
+    attr-dict-with-keyword $bodyRegion
+  }];
 
   // We need to ensure the block inside the region is properly terminated;
   // the auto-generated builders do not guarantee that.
@@ -1415,17 +1421,6 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
   }];
 }
 
-def GPU_ModuleEndOp : GPU_Op<"module_end", [
-  Terminator, HasParent<"GPUModuleOp">
-]> {
-  let summary = "A pseudo op that marks the end of a gpu.module.";
-  let description = [{
-    This op terminates the only block inside the only region of a `gpu.module`.
-  }];
-
-  let assemblyFormat = "attr-dict";
-}
-
 def GPU_BinaryOp : GPU_Op<"binary", [Symbol]>, Arguments<(ins
       SymbolNameAttr:$sym_name,
       OptionalAttr<OffloadingTranslationAttr>:$offloadingHandler,
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 060a1e1e82f75e..9957a5804c0b65 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -316,7 +316,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
                       LLVM::SinOp, LLVM::SqrtOp>();
 
   // TODO: Remove once we support replacing non-root ops.
-  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
+  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
 }
 
 template <typename OpTy>
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 564bab1ad92b90..93e8b080a4f672 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -335,7 +335,7 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
                       LLVM::SqrtOp>();
 
   // TODO: Remove once we support replacing non-root ops.
-  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
+  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
 }
 
 template <typename OpTy>
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 98340bf653d61f..b18b6344732eeb 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -90,19 +90,6 @@ class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-class GPUModuleEndConversion final
-    : public OpConversionPattern<gpu::ModuleEndOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.eraseOp(endOp);
-    return success();
-  }
-};
-
 /// Pattern to convert a gpu.return into a SPIR-V return.
 // TODO: This can go to DRR when GPU return has operands.
 class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
@@ -614,7 +601,7 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
   patterns.add<
       GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
-      GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion,
+      GPUReturnOpConversion, GPUShuffleConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index a1f87a637a6141..eeffe829446cf9 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1736,8 +1736,7 @@ LogicalResult gpu::ReturnOp::verify() {
 void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
                         StringRef name, ArrayAttr targets,
                         Attribute offloadingHandler) {
-  ensureTerminator(*result.addRegion(), builder, result.location);
-
+  result.addRegion()->emplaceBlock();
   Properties &props = result.getOrAddProperties<Properties>();
   if (targets)
     props.targets = targets;
@@ -1753,73 +1752,6 @@ void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
         offloadingHandler);
 }
 
-ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
-  StringAttr nameAttr;
-  ArrayAttr targetsAttr;
-
-  if (parser.parseSymbolName(nameAttr))
-    return failure();
-
-  Properties &props = result.getOrAddProperties<Properties>();
-  props.setSymName(nameAttr);
-
-  // Parse the optional offloadingHandler
-  if (succeeded(parser.parseOptionalLess())) {
-    if (parser.parseAttribute(props.offloadingHandler))
-      return failure();
-    if (parser.parseGreater())
-      return failure();
-  }
-
-  // Parse the optional array of target attributes.
-  OptionalParseResult targetsAttrResult =
-      parser.parseOptionalAttribute(targetsAttr, Type{});
-  if (targetsAttrResult.has_value()) {
-    if (failed(*targetsAttrResult)) {
-      return failure();
-    }
-    props.targets = targetsAttr;
-  }
-
-  // If module attributes are present, parse them.
-  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
-    return failure();
-
-  // Parse the module body.
-  auto *body = result.addRegion();
-  if (parser.parseRegion(*body, {}))
-    return failure();
-
-  // Ensure that this module has a valid terminator.
-  GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
-  return success();
-}
-
-void GPUModuleOp::print(OpAsmPrinter &p) {
-  p << ' ';
-  p.printSymbolName(getName());
-
-  if (Attribute attr = getOffloadingHandlerAttr()) {
-    p << " <";
-    p.printAttribute(attr);
-    p << ">";
-  }
-
-  if (Attribute attr = getTargetsAttr()) {
-    p << ' ';
-    p.printAttribute(attr);
-    p << ' ';
-  }
-
-  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
-                                     {mlir::SymbolTable::getSymbolAttrName(),
-                                      getTargetsAttrName(),
-                                      getOffloadingHandlerAttrName()});
-  p << ' ';
-  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/false);
-}
-
 bool GPUModuleOp::hasTarget(Attribute target) {
   if (ArrayAttr targets = getTargetsAttr())
     return llvm::count(targets.getValue(), target);
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
index 7b873463a5f98f..1a22ba662cbf74 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-32b.mlir
@@ -67,7 +67,7 @@ module attributes {transform.with_named_sequence} {
         {index_bitwidth = 32, use_opaque_pointers = true}
     } {
       legal_dialects = ["llvm", "memref", "nvvm"],
-      legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"],
+      legal_ops = ["func.func", "gpu.module", "gpu.yield"],
       illegal_dialects = ["gpu"],
       illegal_ops = ["llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
                     "llvm.ffloor", "llvm.log", "llvm.log10", "llvm.log2", "llvm.pow",
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index c23b11e46b24c7..8f2ec289c9252c 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -942,7 +942,7 @@ module attributes {transform.with_named_sequence} {
         use_bare_ptr_call_conv = false}
     } {
       legal_dialects = ["llvm", "memref", "nvvm", "test"],
-      legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"],
+      legal_ops = ["func.func", "gpu.module", "gpu.yield"],
       illegal_dialects = ["gpu"],
       illegal_ops = ["llvm.copysign", "llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
                     "llvm.ffloor", "llvm.fma", "llvm.frem", "llvm.log", "llvm.log10", "llvm.log2", "llvm.pow",



More information about the Mlir-commits mailing list