[llvm-branch-commits] [mlir] 4086072 - Reland "[mlir][linalg] Support parsing attributes in named op spec"

Lei Zhang via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 12 08:05:05 PST 2021


Author: Lei Zhang
Date: 2021-01-12T10:57:46-05:00
New Revision: 4086072f8a9200216088c435c9aa90a2d8ed74a5

URL: https://github.com/llvm/llvm-project/commit/4086072f8a9200216088c435c9aa90a2d8ed74a5
DIFF: https://github.com/llvm/llvm-project/commit/4086072f8a9200216088c435c9aa90a2d8ed74a5.diff

LOG: Reland "[mlir][linalg] Support parsing attributes in named op spec"

With this change, a list of attributes can be specified on named ops
generated from the spec. The format is defined as:

```
attr-id ::= bare-id (`?`)?
dim-list ::= (integer-literal `x`)+
attr-typedef ::= dim-list? type (`[` `]`)?
attr-def ::= attr-id `:` attr-typedef

tc-attr-def ::= `attr` `(` attr-def-list `)`
tc-def ::= `def` bare-id
  `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
  (tc-attr-def)?
```

For example,

```
ods_def<SomeCppOp>
def some_op(...) -> (...)
attr(
  f32_attr: f32,
  i32_attr: i32,
  array_attr : f32[],
  optional_attr? : f32
)
```

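where `?` marks an optional attribute and `[]` denotes an array type. An
attribute type may also carry a leading dimension list (e.g. `4xf32` or
`5x6xi32`) to declare a ranked elements attribute; a dimension list cannot
be combined with `[]`.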

Reviewed By: hanchung, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D94240

Added: 
    

Modified: 
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index f81380f02bb3..1ef128760637 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -72,3 +72,25 @@ ods_def<Test3Op> :
 def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
 }
+
+// Test attribute definitions
+// ODS-LABEL: def Test4Op
+// ODS: F32ArrayAttr:$array_attr,
+// ODS: F32:$f32_attr,
+// ODS: RankedF32ElementsAttr<[4]>:$fvec_attr,
+// ODS: I32:$i32_attr,
+// ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr,
+// ODS: OptionalAttr<F32>:$optional_attr
+//
+ods_def<Test4Op> :
+def test4(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N))
+attr(
+  f32_attr: f32,
+  i32_attr: i32,
+  fvec_attr: 4xf32,
+  ivec_attr: 5x6xi32,
+  array_attr : f32[],
+  optional_attr? : f32
+) {
+  C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
+}

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 592e6cb774fb..138c5a4e904e 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -20,11 +20,17 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/Twine.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/ToolOutputFile.h"
 
+#include <map>
+
 #define DEBUG_TYPE "linalg-ods-gen"
 
 static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen");
@@ -79,11 +85,14 @@ class Token {
     gt,
     l_brace,
     l_paren,
+    l_square,
     lt,
     minus,
     plus,
+    question,
     r_brace,
     r_paren,
+    r_square,
     semicolon,
     star,
 
@@ -91,6 +100,7 @@ class Token {
     kw_def,
     FIRST_KEYWORD = kw_def,
     kw_ods_def,
+    kw_attr_def,
     kw_floordiv,
     kw_ceildiv,
     kw_mod,
@@ -151,6 +161,10 @@ class Lexer {
   Token emitError(llvm::SMLoc loc, const Twine &msg);
   Token emitError(const char *loc, const Twine &msg);
 
+  /// Change the position of the lexer cursor. The next token we lex will start
+  /// at the designated point in the input.
+  void resetPointer(const char *newPtr) { curPtr = newPtr; }
+
 private:
   Token formToken(Token::Kind kind, const char *tokStart) {
     return Token(kind, StringRef(tokStart, curPtr - tokStart));
@@ -247,10 +261,14 @@ Token Lexer::lexToken() {
       return formToken(Token::Kind::l_brace, tokStart);
     case '(':
       return formToken(Token::Kind::l_paren, tokStart);
+    case '[':
+      return formToken(Token::Kind::l_square, tokStart);
     case '}':
       return formToken(Token::Kind::r_brace, tokStart);
     case ')':
       return formToken(Token::Kind::r_paren, tokStart);
+    case ']':
+      return formToken(Token::Kind::r_square, tokStart);
     case '<':
       return formToken(Token::Kind::lt, tokStart);
     case '>':
@@ -263,6 +281,8 @@ Token Lexer::lexToken() {
       return formToken(Token::Kind::semicolon, tokStart);
     case '*':
       return formToken(Token::Kind::star, tokStart);
+    case '?':
+      return formToken(Token::Kind::question, tokStart);
     case '/':
       if (*curPtr == '/') {
         skipComment();
@@ -289,6 +309,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
   // Check to see if this identifier is a keyword.
   StringRef str(tokStart, curPtr - tokStart);
   Token::Kind kind = StringSwitch<Token::Kind>(str)
+                         .Case("attr", Token::Kind::kw_attr_def)
                          .Case("def", Token::Kind::kw_def)
                          .Case("ods_def", Token::Kind::kw_ods_def)
                          .Case("floordiv", Token::Kind::kw_floordiv)
@@ -352,29 +373,40 @@ class Parser {
            "shouldn't advance past EOF or errors");
     curToken = lexer.lexToken();
   }
+
   void consumeToken(Token::Kind kind) {
     assert(curToken.getKind() == kind && "unexpected token");
     curToken = lexer.lexToken();
   }
+
   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
     if (curToken.getKind() != kind)
       return emitError(curToken.getLoc(), msg);
     consumeToken();
     return success();
   }
+
+  /// Consumes `kind` if it is next; else returns failure without a diagnostic.
+  LogicalResult parseOptionalToken(Token::Kind kind) {
+    return success(consumeIf(kind));
+  }
+
   LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
     lexer.emitError(loc, msg);
     return failure();
   }
+
   LogicalResult emitError(const Twine &msg) {
     return emitError(curToken.getLoc(), msg);
   }
+
   bool consumeIf(Token::Kind kind) {
     if (curToken.isNot(kind))
       return false;
     consumeToken(kind);
     return true;
   }
+
   LogicalResult
   parseCommaSeparatedList(llvm::function_ref<ParseResult()> parseElement) {
     // Non-empty case starts with an element.
@@ -388,6 +420,7 @@ class Parser {
     }
     return success();
   }
+
   LogicalResult
   parseCommaSeparatedListUntil(Token::Kind rightToken,
                                llvm::function_ref<ParseResult()> parseElement,
@@ -961,6 +994,8 @@ class TCParser {
   LogicalResult parseTensorUse(TensorUse &result,
                                ComprehensionParsingState &state);
 
+  LogicalResult parseAttrDef();
+
   /// Parses a tensor expression.
   LogicalResult parseExpression(TensorUse currentDefinition,
                                 std::unique_ptr<Expression> &result,
@@ -1010,15 +1045,29 @@ class TCParser {
     unsigned index;
   };
 
+  //===--------------------------------------------------------------------===//
+  // Internal bookkeeping of attributes.
+  //===--------------------------------------------------------------------===//
+  struct RegisteredAttr {
+    StringRef elementType;
+    SmallVector<uint64_t, 4> vectorDims;
+    bool isArray;
+    bool isOptional;
+  };
+
   //===--------------------------------------------------------------------===//
   // Per-TC def state.
   //===--------------------------------------------------------------------===//
   /// Symbols are per TC def.
   AffineSymbolList symbols;
+
   /// Tensors are per TC def.
   llvm::StringMap<RegisteredTensor> registeredTensors;
   unsigned nextRegisteredTensorIndex;
 
+  /// Attributes are per TC def.
+  std::map<std::string, RegisteredAttr> registeredAttrs;
+
   Parser &parser;
 };
 } // namespace
@@ -1170,6 +1219,73 @@ LogicalResult TCParser::parseTensorUse(TensorUse &result,
   return success();
 }
 
+/// Parse the information for an attribute def of the form:
+///
+///   attr-id ::= bare-id (`?`)?
+///   dim-list ::= (integer-literal `x`)+
+///   attr-typedef ::= dim-list? type (`[` `]`)?
+///   attr-def ::= attr-id `:` attr-typedef
+/// A dim-list cannot be combined with a trailing `[]` (vector arrays).
+LogicalResult TCParser::parseAttrDef() {
+  auto attrLoc = parser.curToken.getLoc();
+  StringRef attrName = parser.curToken.getSpelling();
+  if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
+    return failure();
+  bool isOptional = succeeded(parser.parseOptionalToken(Token::Kind::question));
+  if (failed(parser.parseToken(Token::Kind::colon, "expected colon")))
+    return failure();
+
+  // Parse the attribute's type. We don't expect the type to be arbitrarily
+  // complex, so just use this ad-hoc handling here.
+  // Parse potential dimension list, e.g. `4x` in `4xf32`.
+  SmallVector<uint64_t, 4> vectorDims;
+  while (parser.curToken.is(Token::Kind::integer)) {
+    // Diagnose literals that do not fit in uint64_t instead of asserting.
+    auto dim = parser.curToken.getUInt64IntegerValue();
+    if (!dim)
+      return parser.emitError("invalid dimension in attribute type");
+    vectorDims.push_back(*dim);
+    parser.consumeToken();
+    StringRef spelling = parser.curToken.getSpelling();
+    if (!spelling.startswith("x"))
+      return parser.emitError(parser.curToken.getLoc(),
+                              "expected 'x' in dimension list");
+    // If we had a prefix of 'x', lex the next token immediately after the 'x'.
+    if (spelling.size() != 1)
+      parser.lexer.resetPointer(spelling.data() + 1);
+    parser.consumeToken();
+  }
+
+  StringRef elementType = parser.curToken.getSpelling();
+  if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
+    return failure();
+
+  // A trailing `[]` turns the element type into an array attribute.
+  bool isArray = false;
+  auto arrayLoc = parser.curToken.getLoc();
+  if (succeeded(parser.parseOptionalToken(Token::Kind::l_square))) {
+    isArray = true;
+    if (failed(parser.parseToken(Token::Kind::r_square, "expected ']'")))
+      return failure();
+  }
+
+  if (!vectorDims.empty() && isArray)
+    return parser.emitError(arrayLoc, "unsupported vector array attribute");
+
+  // Attribute names must be unique within a single TC def.
+  auto iterBoolPair = registeredAttrs.emplace(
+      attrName.str(),
+      RegisteredAttr{elementType, std::move(vectorDims), isArray, isOptional});
+  if (!iterBoolPair.second)
+    return parser.emitError(attrLoc, "duplicate attribute '" + attrName + "'");
+
+  LLVM_DEBUG(llvm::dbgs() << "Recorded: " << (isOptional ? "[optional] " : "")
+                          << attrName << " with type: " << elementType
+                          << (isArray ? "[]" : "") << "\n");
+
+  return success();
+}
+
 /// Parses a tensor expression of the form:
 ///
 ///   op-spec ::= bare-id `<` reduction-dims-list `>`
@@ -1341,10 +1457,13 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
 /// Parse and print the information for a ODS def.
 ///
 ///   tensor-def-list ::= tensor-def (`,` tensor-def )*
+///   attr-def-list ::= attr-def (`,` attr-def )*
 ///
 ///   comprehension-list ::= comprehension comprehension*
 ///
+///   tc-attr-def ::= `attr` `(` attr-def-list `)`
 ///   tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
+///     (tc-attr-def)?
 ///     `{` comprehension-list `}`
 ///
 ///   ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def
@@ -1353,6 +1472,7 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
 /// contain only expressions involving symbols and constants), but can
 /// otherwise contain arbitrary affine expressions.
 LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
+  // Parse def header (including C++ op name)
   if (failed(parser.parseToken(Token::Kind::kw_ods_def,
                                "expected 'ods_def' to define a TC ODS")) ||
       failed(parser.parseToken(Token::Kind::lt, "expected '<'")))
@@ -1364,12 +1484,15 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
       failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
       failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
     return failure();
+
   if (failed(parser.parseToken(Token::Kind::kw_def,
                                "expected 'def' to define a TC")))
     return failure();
 
   StringRef tcName = parser.curToken.getSpelling();
   LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n");
+
+  // Parse input/output tensor definitions
   if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
       failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
     return failure();
@@ -1392,6 +1515,16 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
           Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false)))
     return failure();
 
+  // Parse optional attribute definitions
+  if (succeeded(parser.parseOptionalToken(Token::Kind::kw_attr_def))) {
+    if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
+      return failure();
+    if (failed(parser.parseCommaSeparatedListUntil(
+            Token::Kind::r_paren, std::bind(&TCParser::parseAttrDef, this),
+            /*allowEmptyList=*/false)))
+      return failure();
+  }
+
   // Since we don't declare symbols separately, we discover them eagerly: each
   // newly encountered id in a tensor shape expression is treated as a new
   // symbolic. At this point, all tensors have been parsed and all the symbols
@@ -1450,12 +1583,52 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
 void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
                         StringRef linalgOpName,
                         ComprehensionParsingState &state) {
+  SmallVector<std::string, 4> attributes;
+  for (const auto &attr : registeredAttrs) {
+    llvm::StringRef name = attr.first;
+
+    llvm::StringRef elementType = attr.second.elementType;
+    std::string odsType = llvm::StringSwitch<std::string>(elementType)
+                              .Case("f32", "F32")
+                              .Case("i32", "I32")
+                              .Default("");
+    if (odsType.empty()) {
+      parser.emitError("unimplemented support for attribute element type: " +
+                       elementType);
+      return;
+    }
+
+    const auto &dims = attr.second.vectorDims;
+    if (!dims.empty()) {
+      SmallVector<std::string, 4> dimStrs;
+      for (uint64_t dim : dims)
+        dimStrs.push_back(std::to_string(dim));
+      odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType,
+                              llvm::join(dimStrs, ", "));
+    }
+
+    assert(dims.empty() || !attr.second.isArray);
+    if (attr.second.isArray)
+      odsType = llvm::formatv("{0}ArrayAttr", odsType);
+
+    if (attr.second.isOptional)
+      odsType = llvm::formatv("OptionalAttr<{0}>", odsType);
+
+    attributes.push_back(llvm::formatv("{0}:${1}", odsType, name));
+  }
+
+  std::string attrList = llvm::join(attributes, ",\n");
+  if (!attrList.empty())
+    attrList = ",\n" + attrList;
+
   const char *header = R"FMT(  def {0} : LinalgStructuredBase_Op<"{1}", [
     AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     SingleBlockImplicitTerminator<"YieldOp">]> {
-      let arguments = (ins Variadic<AnyShaped>:$inputs,
-                           Variadic<AnyShaped>:$outputs);
+      let arguments = (ins
+        Variadic<AnyShaped>:$inputs,
+        Variadic<AnyShaped>:$outputs{4}
+      );
       let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
       let regions = (region AnyRegion:$region);
 
@@ -1515,7 +1688,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
 
         // Generic methods.
-        static unsigned getNumRegionArgs() {{ return {4}; }
+        static unsigned getNumRegionArgs() {{ return {5}; }
         std::string getLibraryCallName() {{
           return generateLibraryCallName(getOperation());
         }
@@ -1531,7 +1704,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
   }
 
   os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
-                      state.orderedTensorArgs.size());
+                      attrList, state.orderedTensorArgs.size());
 }
 
 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.


        


More information about the llvm-branch-commits mailing list