[Mlir-commits] [mlir] d7461b3 - [MLIR][SPIRV] Added optional name to SPIR-V module

George Mitenkov llvmlistbot at llvm.org
Wed Aug 26 21:45:10 PDT 2020


Author: George Mitenkov
Date: 2020-08-27T07:32:31+03:00
New Revision: d7461b31e7ef46b7a57ff4d68f8f47e5b804a25a

URL: https://github.com/llvm/llvm-project/commit/d7461b31e7ef46b7a57ff4d68f8f47e5b804a25a
DIFF: https://github.com/llvm/llvm-project/commit/d7461b31e7ef46b7a57ff4d68f8f47e5b804a25a.diff

LOG: [MLIR][SPIRV] Added optional name to SPIR-V module

This patch adds an optional name to the SPIR-V module op.
This helps with lowering from the GPU dialect (the kernel
module name can now be passed through) and aligns the op
more naturally with `GPUModuleOp`/`ModuleOp`.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D86386

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/structure-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index 034b7d1b09c7e..84e59b6be8bc3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -361,7 +361,7 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope, Symbol]> {
 def SPV_ModuleOp : SPV_Op<"module",
                           [IsolatedFromAbove,
                            SingleBlockImplicitTerminator<"ModuleEndOp">,
-                           SymbolTable]> {
+                           SymbolTable, Symbol]> {
   let summary = "The top-level op that defines a SPIR-V module";
 
   let description = [{
@@ -409,7 +409,8 @@ def SPV_ModuleOp : SPV_Op<"module",
   let arguments = (ins
     SPV_AddressingModelAttr:$addressing_model,
     SPV_MemoryModelAttr:$memory_model,
-    OptionalAttr<SPV_VerCapExtAttr>:$vce_triple
+    OptionalAttr<SPV_VerCapExtAttr>:$vce_triple,
+    OptionalAttr<StrAttr>:$sym_name
   );
 
   let results = (outs);
@@ -417,10 +418,12 @@ def SPV_ModuleOp : SPV_Op<"module",
   let regions = (region SizedRegion<1>:$body);
 
   let builders = [
-    OpBuilder<[{OpBuilder &, OperationState &state}]>,
+    OpBuilder<[{OpBuilder &, OperationState &state,
+                Optional<StringRef> name = llvm::None}]>,
     OpBuilder<[{OpBuilder &, OperationState &state,
                 spirv::AddressingModel addressing_model,
-                spirv::MemoryModel memory_model}]>
+                spirv::MemoryModel memory_model,
+                Optional<StringRef> name = llvm::None}]>
   ];
 
   // We need to ensure the block inside the region is properly terminated;
@@ -432,6 +435,11 @@ def SPV_ModuleOp : SPV_Op<"module",
   let autogenSerialization = 0;
 
   let extraClassDeclaration = [{
+
+    bool isOptionalSymbol() { return true; }
+
+    Optional<StringRef> getName() { return sym_name(); }
+
     static StringRef getVCETripleAttrName() { return "vce_triple"; }
 
     Block& getBlock() {

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 88ca71ac18acd..7aeecdbbac354 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2282,24 +2282,39 @@ static LogicalResult verify(spirv::MergeOp mergeOp) {
 // spv.module
 //===----------------------------------------------------------------------===//
 
-void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state) {
+void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
+                            Optional<StringRef> name) {
   ensureTerminator(*state.addRegion(), builder, state.location);
+  if (name) { // Only a named module gets the symbol-name (sym_name) attribute.
+    state.addAttribute(SymbolTable::getSymbolAttrName(),
+                       builder.getStringAttr(*name));
+  }
 }
 
 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
-                            spirv::AddressingModel addressing_model,
-                            spirv::MemoryModel memory_model) {
+                            spirv::AddressingModel addressingModel,
+                            spirv::MemoryModel memoryModel,
+                            Optional<StringRef> name) {
   state.addAttribute(
       "addressing_model",
-      builder.getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
+      builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
   state.addAttribute("memory_model", builder.getI32IntegerAttr(
-                                         static_cast<int32_t>(memory_model)));
+                                         static_cast<int32_t>(memoryModel)));
   ensureTerminator(*state.addRegion(), builder, state.location);
+  if (name) { // Only a named module gets the symbol-name (sym_name) attribute.
+    state.addAttribute(SymbolTable::getSymbolAttrName(),
+                       builder.getStringAttr(*name));
+  }
 }
 
 static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
   Region *body = state.addRegion();
 
+  // If the name is present, parse it.
+  StringAttr nameAttr;
+  parser.parseOptionalSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+                                 state.attributes);
+
   // Parse attributes
   spirv::AddressingModel addrModel;
   spirv::MemoryModel memoryModel;
@@ -2328,13 +2343,19 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
 static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
   printer << spirv::ModuleOp::getOperationName();
 
+  if (Optional<StringRef> name = moduleOp.getName()) {
+    printer << ' ';
+    printer.printSymbolName(*name);
+  }
+
   SmallVector<StringRef, 2> elidedAttrs;
 
   printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
           << " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
   auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
   auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
-  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
+  elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
+                      SymbolTable::getSymbolAttrName()});
 
   if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
     printer << " requires " << *triple;

diff  --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
index e20da2e4e6c97..98da480b83ff1 100644
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -372,6 +372,9 @@ spv.module Logical GLSL450 {
 // CHECK: spv.module Logical GLSL450
 spv.module Logical GLSL450 { }
 
+// Module with a name; the symbol name must round-trip through parse/print.
+// CHECK: spv.module @name Logical GLSL450
+spv.module @name Logical GLSL450 { }
 
 // Module with (version, capabilities, extensions) triple
 // CHECK: spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]>


        


More information about the Mlir-commits mailing list