[Mlir-commits] [mlir] 5756bc4 - [mlir][DeclarativeParser] Add support for formatting enum attributes in the string form.

River Riddle llvmlistbot at llvm.org
Thu Feb 13 17:17:46 PST 2020


Author: River Riddle
Date: 2020-02-13T17:11:48-08:00
New Revision: 5756bc4382a6023c8dcc25f39243a49ac413f9bf

URL: https://github.com/llvm/llvm-project/commit/5756bc4382a6023c8dcc25f39243a49ac413f9bf
DIFF: https://github.com/llvm/llvm-project/commit/5756bc4382a6023c8dcc25f39243a49ac413f9bf.diff

LOG: [mlir][DeclarativeParser] Add support for formatting enum attributes in the string form.

Summary: This revision adds support to the declarative parser for formatting enum attributes in their symbolized (quoted string) form. It uses this new functionality to port several SPIRV ops' custom parsers and printers over to the declarative assembly format.

Differential Revision: https://reviews.llvm.org/D74525

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/ops.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
index b1478d048b5f..56c2b59b1e7f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
@@ -97,6 +97,10 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
   let results = (outs
     SPV_IntVec4:$result
   );
+
+  let assemblyFormat = [{
+    $execution_scope $predicate attr-dict `:` type($result)
+  }];
 }
 
 // -----
@@ -145,6 +149,8 @@ def SPV_GroupNonUniformElectOp : SPV_Op<"GroupNonUniformElect", []> {
   let builders = [
     OpBuilder<[{Builder *builder, OperationState &state, spirv::Scope}]>
   ];
+
+  let assemblyFormat = "$execution_scope attr-dict `:` type($result)";
 }
 
 // -----

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index 9a578968a59f..b78ecc3cd8fe 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -162,6 +162,10 @@ def SPV_ControlBarrierOp : SPV_Op<"ControlBarrier", []> {
   let verifier = [{ return verifyMemorySemantics(*this); }];
 
   let autogenSerialization = 0;
+
+  let assemblyFormat = [{
+    $execution_scope `,` $memory_scope `,` $memory_semantics attr-dict
+  }];
 }
 
 // -----
@@ -319,6 +323,8 @@ def SPV_MemoryBarrierOp : SPV_Op<"MemoryBarrier", []> {
   let verifier = [{ return verifyMemorySemantics(*this); }];
 
   let autogenSerialization = 0;
+
+  let assemblyFormat = "$memory_scope `,` $memory_semantics attr-dict";
 }
 
 // -----

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index c8514d4e618e..6e97c3f58a66 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1467,32 +1467,6 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
   llvm_unreachable("unimplemented types for ConstantOp::getOne()");
 }
 
-//===----------------------------------------------------------------------===//
-// spv.ControlBarrier
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseControlBarrierOp(OpAsmParser &parser,
-                                         OperationState &state) {
-  spirv::Scope executionScope;
-  spirv::Scope memoryScope;
-  spirv::MemorySemantics memorySemantics;
-
-  return failure(
-      parseEnumAttribute(executionScope, parser, state,
-                         kExecutionScopeAttrName) ||
-      parser.parseComma() ||
-      parseEnumAttribute(memoryScope, parser, state, kMemoryScopeAttrName) ||
-      parser.parseComma() ||
-      parseEnumAttribute(memorySemantics, parser, state));
-}
-
-static void print(spirv::ControlBarrierOp op, OpAsmPrinter &printer) {
-  printer << spirv::ControlBarrierOp::getOperationName() << " \""
-          << stringifyScope(op.execution_scope()) << "\", \""
-          << stringifyScope(op.memory_scope()) << "\", \""
-          << stringifyMemorySemantics(op.memory_semantics()) << "\"";
-}
-
 //===----------------------------------------------------------------------===//
 // spv.EntryPoint
 //===----------------------------------------------------------------------===//
@@ -1916,28 +1890,6 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
 // spv.GroupNonUniformBallotOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
-                                                OperationState &state) {
-  spirv::Scope executionScope;
-  OpAsmParser::OperandType operandInfo;
-  Type resultType;
-  IntegerType i1Type = parser.getBuilder().getI1Type();
-  if (parseEnumAttribute(executionScope, parser, state,
-                         kExecutionScopeAttrName) ||
-      parser.parseOperand(operandInfo) || parser.parseColonType(resultType) ||
-      parser.resolveOperand(operandInfo, i1Type, state.operands))
-    return failure();
-
-  return parser.addTypeToList(resultType, state.types);
-}
-
-static void print(spirv::GroupNonUniformBallotOp ballotOp,
-                  OpAsmPrinter &printer) {
-  printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
-          << stringifyScope(ballotOp.execution_scope()) << "\" "
-          << ballotOp.predicate() << " : " << ballotOp.getType();
-}
-
 static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
   // TODO(antiagainst): check the result integer type's signedness bit is 0.
 
@@ -1959,25 +1911,6 @@ void spirv::GroupNonUniformElectOp::build(Builder *builder,
   build(builder, state, builder->getI1Type(), scope);
 }
 
-static ParseResult parseGroupNonUniformElectOp(OpAsmParser &parser,
-                                               OperationState &state) {
-  spirv::Scope executionScope;
-  Type resultType;
-  if (parseEnumAttribute(executionScope, parser, state,
-                         kExecutionScopeAttrName) ||
-      parser.parseColonType(resultType))
-    return failure();
-
-  return parser.addTypeToList(resultType, state.types);
-}
-
-static void print(spirv::GroupNonUniformElectOp groupOp,
-                  OpAsmPrinter &printer) {
-  printer << spirv::GroupNonUniformElectOp::getOperationName() << " \""
-          << stringifyScope(groupOp.execution_scope())
-          << "\" : " << groupOp.getType();
-}
-
 static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
   spirv::Scope scope = groupOp.execution_scope();
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
@@ -1987,8 +1920,6 @@ static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
   return success();
 }
 
-
-
 //===----------------------------------------------------------------------===//
 // spv.IAdd
 //===----------------------------------------------------------------------===//
@@ -2296,27 +2227,6 @@ static LogicalResult verify(spirv::MergeOp mergeOp) {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spv.MemoryBarrier
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseMemoryBarrierOp(OpAsmParser &parser,
-                                        OperationState &state) {
-  spirv::Scope memoryScope;
-  spirv::MemorySemantics memorySemantics;
-
-  return failure(
-      parseEnumAttribute(memoryScope, parser, state, kMemoryScopeAttrName) ||
-      parser.parseComma() ||
-      parseEnumAttribute(memorySemantics, parser, state));
-}
-
-static void print(spirv::MemoryBarrierOp op, OpAsmPrinter &printer) {
-  printer << spirv::MemoryBarrierOp::getOperationName() << " \""
-          << stringifyScope(op.memory_scope()) << "\", \""
-          << stringifyMemorySemantics(op.memory_semantics()) << "\"";
-}
-
 //===----------------------------------------------------------------------===//
 // spv.module
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir
index 09a41da2d340..cf7ab6096689 100644
--- a/mlir/test/Dialect/SPIRV/ops.mlir
+++ b/mlir/test/Dialect/SPIRV/ops.mlir
@@ -289,7 +289,7 @@ func @control_barrier_0() -> () {
 // -----
 
 func @control_barrier_1() -> () {
-  // expected-error @+1 {{invalid scope attribute specification: "Something"}}
+  // expected-error @+1 {{invalid execution_scope attribute specification: "Something"}}
   spv.ControlBarrier "Something", "Device", "Acquire|UniformMemory"
   return
 }

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 91918e099c81..b8aeb904e187 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -264,6 +264,18 @@ struct OperationFormat {
 //===----------------------------------------------------------------------===//
 // Parser Gen
 
+/// Returns true if the given attribute can be formatted as an EnumAttr, i.e.
+/// parsed from and printed as the quoted string name of its enum case.
+static bool canFormatEnumAttr(const NamedAttribute *attr) {
+  const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&attr->attr);
+  if (!enumAttr)
+    return false;
+
+  // The attribute must have a non-empty underlying type and const builder.
+  return !enumAttr->getUnderlyingType().empty() &&
+         !enumAttr->getConstBuilderTemplate().empty();
+}
+
 /// The code snippet used to generate a parser call for an attribute.
 ///
 /// {0}: The storage type of the attribute.
@@ -275,6 +287,30 @@ const char *const attrParserCode = R"(
     return failure();
 )";
 
+/// The code snippet used to parse an enum attribute from its string form.
+///
+/// {0}: The name of the attribute.
+/// {1}: The C++ namespace containing the enum's symbolize function.
+/// {2}: The function converting a string to an optional enum value.
+/// {3}: The constant builder call that creates the attribute from the value.
+const char *const enumAttrParserCode = R"(
+  {
+    StringAttr attrVal;
+    SmallVector<NamedAttribute, 1> attrStorage;
+    auto loc = parser.getCurrentLocation();
+    if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
+                              "{0}", attrStorage))
+      return failure();
+
+    auto attrOptional = {1}::{2}(attrVal.getValue());
+    if (!attrOptional)
+      return parser.emitError(loc, "invalid ")
+             << "{0} attribute specification: " << attrVal;
+
+    result.addAttribute("{0}", {3});
+  }
+)";
+
 /// The code snippet used to generate a parser call for an operand.
 ///
 /// {0}: The name of the operand.
@@ -383,6 +419,24 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
     } else if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
       const NamedAttribute *var = attr->getVar();
 
+      // Check to see if we can parse this as an enum attribute.
+      if (canFormatEnumAttr(var)) {
+        const EnumAttr &enumAttr = cast<EnumAttr>(var->attr);
+
+        // Generate the code for building an attribute for this enum.
+        std::string attrBuilderStr;
+        {
+          llvm::raw_string_ostream os(attrBuilderStr);
+          os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
+                      "attrOptional.getValue()");
+        }
+
+        body << formatv(enumAttrParserCode, var->name,
+                        enumAttr.getCppNamespace(),
+                        enumAttr.getStringToSymbolFnName(), attrBuilderStr);
+        continue;
+      }
+
       // If this attribute has a buildable type, use that when parsing the
       // attribute.
       std::string attrTypeStr;
@@ -637,7 +691,15 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
     if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
       const NamedAttribute *var = attr->getVar();
 
-      // Elide the attribute type if it is buildable..
+      // If we are formatting as an enum, print the stringified enum in quotes.
+      if (canFormatEnumAttr(var)) {
+        const EnumAttr &enumAttr = cast<EnumAttr>(var->attr);
+        body << "  p << \"\\\"\" << " << enumAttr.getSymbolToStringFnName()
+             << "(" << var->name << "()) << \"\\\"\";\n";
+        continue;
+      }
+
+      // Elide the attribute type if it is buildable.
       Optional<Type> attrType = var->attr.getValueType();
       if (attrType && attrType->getBuilderCall())
         body << "  p.printAttributeWithoutType(" << var->name << "Attr());\n";


        


More information about the Mlir-commits mailing list