[Mlir-commits] [mlir] d7461b3 - [MLIR][SPIRV] Added optional name to SPIR-V module
George Mitenkov
llvmlistbot at llvm.org
Wed Aug 26 21:45:10 PDT 2020
Author: George Mitenkov
Date: 2020-08-27T07:32:31+03:00
New Revision: d7461b31e7ef46b7a57ff4d68f8f47e5b804a25a
URL: https://github.com/llvm/llvm-project/commit/d7461b31e7ef46b7a57ff4d68f8f47e5b804a25a
DIFF: https://github.com/llvm/llvm-project/commit/d7461b31e7ef46b7a57ff4d68f8f47e5b804a25a.diff
LOG: [MLIR][SPIRV] Added optional name to SPIR-V module
This patch adds an optional name to the SPIR-V module op.
This will help with lowering from the GPU dialect (so that we
can pass the kernel module name) and will align `spv.module`
more naturally with `GPUModuleOp`/`ModuleOp`.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D86386
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/structure-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index 034b7d1b09c7e..84e59b6be8bc3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -361,7 +361,7 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope, Symbol]> {
def SPV_ModuleOp : SPV_Op<"module",
[IsolatedFromAbove,
SingleBlockImplicitTerminator<"ModuleEndOp">,
- SymbolTable]> {
+ SymbolTable, Symbol]> {
let summary = "The top-level op that defines a SPIR-V module";
let description = [{
@@ -409,7 +409,8 @@ def SPV_ModuleOp : SPV_Op<"module",
let arguments = (ins
SPV_AddressingModelAttr:$addressing_model,
SPV_MemoryModelAttr:$memory_model,
- OptionalAttr<SPV_VerCapExtAttr>:$vce_triple
+ OptionalAttr<SPV_VerCapExtAttr>:$vce_triple,
+ OptionalAttr<StrAttr>:$sym_name
);
let results = (outs);
@@ -417,10 +418,12 @@ def SPV_ModuleOp : SPV_Op<"module",
let regions = (region SizedRegion<1>:$body);
let builders = [
- OpBuilder<[{OpBuilder &, OperationState &state}]>,
+ OpBuilder<[{OpBuilder &, OperationState &state,
+ Optional<StringRef> name = llvm::None}]>,
OpBuilder<[{OpBuilder &, OperationState &state,
spirv::AddressingModel addressing_model,
- spirv::MemoryModel memory_model}]>
+ spirv::MemoryModel memory_model,
+ Optional<StringRef> name = llvm::None}]>
];
// We need to ensure the block inside the region is properly terminated;
@@ -432,6 +435,11 @@ def SPV_ModuleOp : SPV_Op<"module",
let autogenSerialization = 0;
let extraClassDeclaration = [{
+
+ bool isOptionalSymbol() { return true; }
+
+ Optional<StringRef> getName() { return sym_name(); }
+
static StringRef getVCETripleAttrName() { return "vce_triple"; }
Block& getBlock() {
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 88ca71ac18acd..7aeecdbbac354 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2282,24 +2282,39 @@ static LogicalResult verify(spirv::MergeOp mergeOp) {
// spv.module
//===----------------------------------------------------------------------===//
-void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state) {
+void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
+ Optional<StringRef> name) {
ensureTerminator(*state.addRegion(), builder, state.location);
+ if (name) {
+ state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
+ builder.getStringAttr(*name));
+ }
}
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
- spirv::AddressingModel addressing_model,
- spirv::MemoryModel memory_model) {
+ spirv::AddressingModel addressingModel,
+ spirv::MemoryModel memoryModel,
+ Optional<StringRef> name) {
state.addAttribute(
"addressing_model",
- builder.getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
+ builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
state.addAttribute("memory_model", builder.getI32IntegerAttr(
- static_cast<int32_t>(memory_model)));
+ static_cast<int32_t>(memoryModel)));
ensureTerminator(*state.addRegion(), builder, state.location);
+ if (name) {
+ state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
+ builder.getStringAttr(*name));
+ }
}
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
Region *body = state.addRegion();
+ // If the name is present, parse it.
+ StringAttr nameAttr;
+ parser.parseOptionalSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+ state.attributes);
+
// Parse attributes
spirv::AddressingModel addrModel;
spirv::MemoryModel memoryModel;
@@ -2328,13 +2343,19 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
printer << spirv::ModuleOp::getOperationName();
+ if (Optional<StringRef> name = moduleOp.getName()) {
+ printer << ' ';
+ printer.printSymbolName(*name);
+ }
+
SmallVector<StringRef, 2> elidedAttrs;
printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
<< " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
- elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
+ elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
+ SymbolTable::getSymbolAttrName()});
if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
printer << " requires " << *triple;
diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir
index e20da2e4e6c97..98da480b83ff1 100644
--- a/mlir/test/Dialect/SPIRV/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir
@@ -372,6 +372,9 @@ spv.module Logical GLSL450 {
// CHECK: spv.module Logical GLSL450
spv.module Logical GLSL450 { }
+// Module with a name
+// CHECK: spv.module @{{.*}} Logical GLSL450
+spv.module @name Logical GLSL450 { }
// Module with (version, capabilities, extensions) triple
// CHECK: spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]>
More information about the Mlir-commits mailing list