[Mlir-commits] [mlir] 56f60a1 - [mlir][spirv] Use SingleBlock + NoTerminator for spv.module

Lei Zhang llvmlistbot at llvm.org
Wed Jun 9 11:04:13 PDT 2021


Author: Lei Zhang
Date: 2021-06-09T14:00:06-04:00
New Revision: 56f60a1ce7656654d4b2f0cc42b2c5a15653db83

URL: https://github.com/llvm/llvm-project/commit/56f60a1ce7656654d4b2f0cc42b2c5a15653db83
DIFF: https://github.com/llvm/llvm-project/commit/56f60a1ce7656654d4b2f0cc42b2c5a15653db83.diff

LOG: [mlir][spirv] Use SingleBlock + NoTerminator for spv.module

This allows us to remove the `spv.mlir.endmodule` op and
all the code associated with it.

Along the way, tightened the APIs for `spv.module` a bit
by removing some aliases. Now we use `getRegion` to get
the only region, and `getBody` to get the region's only
block.

Reviewed By: mravishankar, hanchung

Differential Revision: https://reviews.llvm.org/D103265

Added: 
    

Modified: 
    mlir/docs/Dialects/SPIR-V.md
    mlir/docs/SPIRVToLLVMDialectConversion.md
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
    mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
    mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
    mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir
    mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
    mlir/unittests/Dialect/SPIRV/SerializationTest.cpp

Removed: 
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.td


################################################################################
diff  --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index 15cc080d50842..9e9106b3f5382 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -92,8 +92,8 @@ The SPIR-V dialect adopts the following conventions for IR:
     (de)serialization.
 *   Ops with `mlir.snake_case` names are those that have no corresponding
     instructions (or concepts) in the binary format. They are introduced to
-    satisfy MLIR structural requirements. For example, `spv.mlir.endmodule` and
-    `spv.mlir.merge`. They map to no instructions during (de)serialization.
+    satisfy MLIR structural requirements. For example, `spv.mlir.merge`. They
+    map to no instructions during (de)serialization.
 
 (TODO: consider merging the last two cases and adopting `spv.mlir.` prefix for
 them.)

diff  --git a/mlir/docs/SPIRVToLLVMDialectConversion.md b/mlir/docs/SPIRVToLLVMDialectConversion.md
index 0b978f59d517e..ff6bb9629e5e8 100644
--- a/mlir/docs/SPIRVToLLVMDialectConversion.md
+++ b/mlir/docs/SPIRVToLLVMDialectConversion.md
@@ -810,8 +810,6 @@ Module in SPIR-V has one region that contains one block. It is defined via
 `spv.module` is converted into `ModuleOp`. This plays a role of enclosing scope
 to LLVM ops. At the moment, SPIR-V module attributes are ignored.
 
-`spv.mlir.endmodule` is mapped to an equivalent terminator `ModuleTerminatorOp`.
-
 ## `mlir-spirv-cpu-runner`
 
 `mlir-spirv-cpu-runner` allows to execute `gpu` dialect kernel on the CPU via

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index 95daa9887de01..683c6dab183fb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -407,9 +407,8 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
 // -----
 
 def SPV_ModuleOp : SPV_Op<"module",
-                          [IsolatedFromAbove,
-                           SingleBlockImplicitTerminator<"ModuleEndOp">,
-                           SymbolTable, Symbol]> {
+    [IsolatedFromAbove, NoRegionArguments, NoTerminator,
+     SingleBlock, SymbolTable, Symbol]> {
   let summary = "The top-level op that defines a SPIR-V module";
 
   let description = [{
@@ -426,7 +425,7 @@ def SPV_ModuleOp : SPV_Op<"module",
     implicitly capture values from the enclosing environment.
 
     This op has only one region, which only contains one block. The block
-    must be terminated via the `spv.mlir.endmodule` op.
+    has no terminator.
 
     <!-- End of AutoGen section -->
 
@@ -463,7 +462,7 @@ def SPV_ModuleOp : SPV_Op<"module",
 
   let results = (outs);
 
-  let regions = (region SizedRegion<1>:$body);
+  let regions = (region AnyRegion);
 
   let builders = [
     OpBuilder<(ins CArg<"Optional<StringRef>", "llvm::None">:$name)>,
@@ -487,40 +486,11 @@ def SPV_ModuleOp : SPV_Op<"module",
     Optional<StringRef> getName() { return sym_name(); }
 
     static StringRef getVCETripleAttrName() { return "vce_triple"; }
-
-    Block& getBlock() {
-      return this->getOperation()->getRegion(0).front();
-    }
   }];
 }
 
 // -----
 
-def SPV_ModuleEndOp : SPV_Op<"mlir.endmodule", [InModuleScope, Terminator]> {
-  let summary = "The pseudo op that ends a SPIR-V module";
-
-  let description = [{
-    This op terminates the only block inside a `spv.module`'s only region.
-    This op does not have a corresponding SPIR-V instruction and thus will
-    not be serialized into the binary format; it is used solely to satisfy
-    the structual requirement that an block must be ended with a terminator.
-  }];
-
-  let arguments = (ins);
-
-  let results = (outs);
-
-  let assemblyFormat = "attr-dict";
-
-  let verifier = [{ return success(); }];
-
-  let hasOpcode = 0;
-
-  let autogenSerialization = 0;
-}
-
-// -----
-
 def SPV_ReferenceOfOp : SPV_Op<"mlir.referenceof", [NoSideEffect]> {
   let summary = "Reference a specialization constant.";
 

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
index 01fa3e1448437..3dc53c2a845a5 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
@@ -1,14 +1,9 @@
-set(LLVM_TARGET_DEFINITIONS GPUToSPIRV.td)
-mlir_tablegen(GPUToSPIRV.cpp.inc -gen-rewriters)
-add_public_tablegen_target(MLIRGPUToSPIRVIncGen)
-
 add_mlir_conversion_library(MLIRGPUToSPIRV
   GPUToSPIRV.cpp
   GPUToSPIRVPass.cpp
 
   DEPENDS
   MLIRConversionPassIncGen
-  MLIRGPUToSPIRVIncGen
 
   LINK_LIBS PUBLIC
   MLIRGPU

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index fa4bbff5bb327..5aaec815a83b1 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -85,6 +85,19 @@ class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+class GPUModuleEndConversion final
+    : public OpConversionPattern<gpu::ModuleEndOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::ModuleEndOp endOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.eraseOp(endOp);
+    return success();
+  }
+};
+
 /// Pattern to convert a gpu.return into a SPIR-V return.
 // TODO: This can go to DRR when GPU return has operands.
 class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
@@ -301,12 +314,10 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
       StringRef(spvModuleName));
 
   // Move the region from the module op into the SPIR-V module.
-  Region &spvModuleRegion = spvModule.body();
+  Region &spvModuleRegion = spvModule.getRegion();
   rewriter.inlineRegionBefore(moduleOp.body(), spvModuleRegion,
                               spvModuleRegion.begin());
-  // The spv.module build method adds a block with a terminator. Remove that
-  // block. The terminator of the module op in the remaining block will be
-  // legalized later.
+  // The spv.module build method adds a block. Remove that.
   rewriter.eraseBlock(&spvModuleRegion.back());
   rewriter.eraseOp(moduleOp);
   return success();
@@ -330,15 +341,11 @@ LogicalResult GPUReturnOpConversion::matchAndRewrite(
 // GPU To SPIRV Patterns.
 //===----------------------------------------------------------------------===//
 
-namespace {
-#include "GPUToSPIRV.cpp.inc"
-}
-
 void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
-  populateWithGenerated(patterns);
   patterns.add<
-      GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
+      GPUFuncOpConversion, GPUModuleConversion, GPUModuleEndConversion,
+      GPUReturnOpConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::ThreadIdOp,

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.td b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.td
deleted file mode 100644
index 9615582557483..0000000000000
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.td
+++ /dev/null
@@ -1,22 +0,0 @@
-//===-- GPUToSPIRV.td - GPU to SPIR-V Dialect Lowerings ----*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file contains patterns to lower GPU dialect ops to to SPIR-V ops.
-//
-//===----------------------------------------------------------------------===//
-
-
-#ifndef MLIR_CONVERSION_GPU_TO_SPIRV
-#define MLIR_CONVERSION_GPU_TO_SPIRV
-
-include "mlir/Dialect/GPU/GPUOps.td"
-include "mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td"
-
-def : Pat<(GPU_ModuleEndOp), (SPV_ModuleEndOp)>;
-
-#endif // MLIR_CONVERSION_GPU_TO_SPIRV

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 1914812789e42..7913c09923030 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1342,7 +1342,7 @@ class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
 
     auto newModuleOp =
         rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
-    rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
+    rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
 
     // Remove the terminator block that was automatically added by builder
     rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
@@ -1351,20 +1351,6 @@ class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
   }
 };
 
-class ModuleEndConversionPattern
-    : public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
-public:
-  using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;
-
-  LogicalResult
-  matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    rewriter.eraseOp(moduleEndOp);
-    return success();
-  }
-};
-
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1507,8 +1493,7 @@ void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
 
 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<ModuleConversionPattern, ModuleEndConversionPattern>(
-      patterns.getContext(), typeConverter);
+  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 374fdadf4b783..471fa2c1b4a35 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2529,7 +2529,8 @@ static LogicalResult verify(spirv::MergeOp mergeOp) {
 
 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                             Optional<StringRef> name) {
-  ensureTerminator(*state.addRegion(), builder, state.location);
+  OpBuilder::InsertionGuard guard(builder);
+  builder.createBlock(state.addRegion());
   if (name) {
     state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
                             builder.getStringAttr(*name));
@@ -2545,7 +2546,8 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
       builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
   state.addAttribute("memory_model", builder.getI32IntegerAttr(
                                          static_cast<int32_t>(memoryModel)));
-  ensureTerminator(*state.addRegion(), builder, state.location);
+  OpBuilder::InsertionGuard guard(builder);
+  builder.createBlock(state.addRegion());
   if (name) {
     state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
                             builder.getStringAttr(*name));
@@ -2581,7 +2583,10 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
     return failure();
 
-  spirv::ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location);
+  // Make sure we have at least one block.
+  if (body->empty())
+    body->push_back(new Block());
+
   return success();
 }
 
@@ -2608,8 +2613,7 @@ static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
   }
 
   printer.printOptionalAttrDictWithKeyword(moduleOp->getAttrs(), elidedAttrs);
-  printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
-                      /*printBlockTerminators=*/false);
+  printer.printRegion(moduleOp.getRegion());
 }
 
 static LogicalResult verify(spirv::ModuleOp moduleOp) {
@@ -2619,7 +2623,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
       entryPoints;
   SymbolTable table(moduleOp);
 
-  for (auto &op : moduleOp.getBlock()) {
+  for (auto &op : *moduleOp.getBody()) {
     if (op.getDialect() != dialect)
       return op.emitError("'spv.module' can only contain spv.* ops");
 

diff  --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
index 61494866850c9..5fc948a070900 100644
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -134,7 +134,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
 
   auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
       modules[0].getLoc(), addressingModel, memoryModel);
-  combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
+  combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
 
   // In some cases, a symbol in the (current state of the) combined module is
   // renamed in order to maintain the conflicting symbol in the input module
@@ -160,7 +160,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
     // for spv.funcs. This way, if the conflicting op in the input module is
     // non-spv.func, we rename that symbol instead and maintain the spv.func in
     // the combined module name as it is.
-    for (auto &op : combinedModule.getBlock().without_terminator()) {
+    for (auto &op : *combinedModule.getBody()) {
       if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
         StringRef oldSymName = symbolOp.getName();
 
@@ -195,7 +195,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
 
     // In the current input module, rename all symbols that conflict with
     // symbols from the combined module. This includes renaming spv.funcs.
-    for (auto &op : moduleClone.getBlock().without_terminator()) {
+    for (auto &op : *moduleClone.getBody()) {
       if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
         StringRef oldSymName = symbolOp.getName();
 
@@ -225,7 +225,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
     }
 
     // Clone all the module's ops to the combined module.
-    for (auto &op : moduleClone.getBlock().without_terminator())
+    for (auto &op : *moduleClone.getBody())
       combinedModuleBuilder.insert(op.clone());
   }
 
@@ -233,7 +233,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
   DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
   SmallVector<SymbolOpInterface, 0> eraseList;
 
-  for (auto &op : combinedModule.getBlock().without_terminator()) {
+  for (auto &op : *combinedModule.getBody()) {
     llvm::hash_code hashCode(0);
     SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
 

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index ed8ae0a53d787..ae607f249c078 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -115,7 +115,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
 
   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
   auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
-  builder.setInsertionPoint(spirvModule.body().front().getTerminator());
+  builder.setInsertionPointToEnd(spirvModule.getBody());
 
   // Adds the spv.EntryPointOp after collecting all the interface variables
   // needed.

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index bbe16717fa022..132e23283704a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -51,7 +51,7 @@ static inline bool isFnEntryBlock(Block *block) {
 spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
                                   MLIRContext *context)
     : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
-      module(createModuleOp()), opBuilder(module->body()) {}
+      module(createModuleOp()), opBuilder(module->getRegion()) {}
 
 LogicalResult spirv::Deserializer::deserialize() {
   LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n");

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 773fa863c0811..64aa8806c0e6b 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -99,7 +99,7 @@ LogicalResult Serializer::serialize() {
 
   // Iterate over the module body to serialize it. Assumptions are that there is
   // only one basic block in the moduleOp
-  for (auto &op : module.getBlock()) {
+  for (auto &op : *module.getBody()) {
     if (failed(processOperation(&op))) {
       return failure();
     }
@@ -1090,7 +1090,6 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
         return processGlobalVariableOp(op);
       })
       .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
-      .Case([&](spirv::ModuleEndOp) { return success(); })
       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
       .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })

diff  --git a/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir
index 03baad77647bd..c75214920b5ac 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir
@@ -13,12 +13,6 @@ spv.module @foo Logical GLSL450 {}
 // CHECK: module
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]> {}
 
-// CHECK: module
-spv.module Logical GLSL450 {
-	// CHECK: }
-  spv.mlir.endmodule
-}
-
 // CHECK: module
 spv.module Logical GLSL450 {
 	// CHECK-LABEL: llvm.func @empty()

diff  --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 311cdea571c78..798b843874a5d 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -425,12 +425,6 @@ spv.module Logical GLSL450
   requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]>
   attributes {foo = "bar"} { }
 
-// Module with explicit spv.mlir.endmodule
-// CHECK: spv.module
-spv.module Logical GLSL450 {
-  spv.mlir.endmodule
-}
-
 // Module with function
 // CHECK: spv.module
 spv.module Logical GLSL450 {
@@ -476,15 +470,6 @@ spv.module Logical GLSL450 {
 
 // -----
 
-// Module with wrong terminator
-// expected-error@+2 {{expects regions to end with 'spv.mlir.endmodule'}}
-// expected-note@+1 {{in custom textual format, the absence of terminator implies 'spv.mlir.endmodule'}}
-"spv.module"() ({
-  %0 = spv.Constant true
-}) {addressing_model = 0 : i32, memory_model = 1 : i32} : () -> ()
-
-// -----
-
 // Use non SPIR-V op inside module
 spv.module Logical GLSL450 {
   // expected-error @+1 {{'spv.module' can only contain spv.* ops}}
@@ -511,17 +496,6 @@ spv.module Logical GLSL450 {
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// spv.mlir.endmodule
-//===----------------------------------------------------------------------===//
-
-func @module_end_not_in_module() -> () {
-  // expected-error @+1 {{op must appear in a module-like op's block}}
-  spv.mlir.endmodule
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // spv.mlir.referenceof
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index d92c52438e8c5..a3ae8bc3a9e75 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -59,7 +59,7 @@ class SerializationTest : public ::testing::Test {
   }
 
   Type getFloatStructType() {
-    OpBuilder opBuilder(module->body());
+    OpBuilder opBuilder(module->getRegion());
     llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
     llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
     auto structType = spirv::StructType::get(elementTypes, offsetInfo);
@@ -67,7 +67,7 @@ class SerializationTest : public ::testing::Test {
   }
 
   void addGlobalVar(Type type, llvm::StringRef name) {
-    OpBuilder opBuilder(module->body());
+    OpBuilder opBuilder(module->getRegion());
     auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
     opBuilder.create<spirv::GlobalVariableOp>(
         UnknownLoc::get(&context), TypeAttr::get(ptrType),


        


More information about the Mlir-commits mailing list