[Mlir-commits] [mlir] a2ad3ec - [mlir][ods] Support string literals in `custom` directives

Jeff Niu llvmlistbot at llvm.org
Fri Aug 12 17:55:16 PDT 2022


Author: Jeff Niu
Date: 2022-08-12T20:55:11-04:00
New Revision: a2ad3ec7ac6279370630ec05d6426c97f4cf50d6

URL: https://github.com/llvm/llvm-project/commit/a2ad3ec7ac6279370630ec05d6426c97f4cf50d6
DIFF: https://github.com/llvm/llvm-project/commit/a2ad3ec7ac6279370630ec05d6426c97f4cf50d6.diff

LOG: [mlir][ods] Support string literals in `custom` directives

This patch adds support for string literals as `custom` directive
arguments. This can be useful for re-using custom parsers and printers
when arguments have a known value. For example:

```
ParseResult parseTypedAttr(AsmParser &parser, Attribute &attr, Type type) {
  return parser.parseAttribute(attr, type);
}

void printTypedAttr(AsmPrinter &printer, Attribute attr, Type type) {
  printer.printAttributeWithoutType(attr);
}
```

And in TableGen:

```
def FooOp : ... {
  let arguments = (ins AnyAttr:$a);
  let assemblyFormat = [{ custom<TypedAttr>($a, "$_builder.getI1Type()")
                          attr-dict }];
}

def BarOp : ... {
  let arguments = (ins AnyAttr:$a);
  let assemblyFormat = [{ custom<TypedAttr>($a, "$_builder.getIndexType()")
                          attr-dict }];
}
```

This avoids writing two separate sets of custom parsers and printers.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D131603

Added: 
    mlir/test/mlir-tblgen/op-format.td

Modified: 
    mlir/docs/AttributesAndTypes.md
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Interfaces/ViewLikeInterface.cpp
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/test/mlir-tblgen/op-format-invalid.td
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
    mlir/tools/mlir-tblgen/FormatGen.cpp
    mlir/tools/mlir-tblgen/FormatGen.h
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md
index caf0078dd2b4b..748b69fd86ee0 100644
--- a/mlir/docs/AttributesAndTypes.md
+++ b/mlir/docs/AttributesAndTypes.md
@@ -895,6 +895,19 @@ void printStringParam(AsmPrinter &printer, StringRef value);
 The custom parser is considered to have failed if it returns failure or if any
 bound parameters have failure values afterwards.
 
+A string of C++ code can be used as a `custom` directive argument. When
+generating the custom parser and printer call, the string is pasted as a
+function argument. For example, `parseBar` and `printBar` can be re-used with
+a constant integer:
+
+```tablegen
+let parameters = (ins "int":$bar);
+let assemblyFormat = [{ custom<Bar>($bar, "1") }];
+```
+
+The string is pasted verbatim but with substitutions for `$_builder` and
+`$_ctxt`. String literals can be used to parameterize custom directives.
+
 ### Verification
 
 If the `genVerifyDecl` field is set, additional verification methods are

diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index eea9ad79bf54f..51d1076bbceb1 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -768,9 +768,9 @@ when generating the C++ code for the format. The `UserDirective` is an
 identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)`
 would result in calls to `parseMyDirective` and `printMyDirective` within the
 parser and printer respectively. `Params` may be any combination of variables
-(i.e. Attribute, Operand, Successor, etc.), type directives, and `attr-dict`.
-The type directives must refer to a variable, but that variable need not also be
-a parameter to the custom directive.
+(i.e. Attribute, Operand, Successor, etc.), type directives, `attr-dict`, and
+strings of C++ code. The type directives must refer to a variable, but that
+variable need not also be a parameter to the custom directive.
 
 The arguments to the `parse<UserDirective>` method are firstly a reference to
 the `OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters
@@ -837,7 +837,16 @@ declarative parameter to `print` method argument is detailed below:
     -   VariadicOfVariadic: `TypeRangeRange`
 *   `attr-dict` Directive: `DictionaryAttr`
 
-When a variable is optional, the provided value may be null.
+When a variable is optional, the provided value may be null. When a variable is
+referenced in a custom directive parameter using `ref`, it is passed in by
+value. Referenced variables to `print<UserDirective>` are passed the same way as
+bound variables, but referenced variables to `parse<UserDirective>` are passed
+the same way as they are to the printer.
+
+A custom directive can take a string of C++ code as a parameter. The code is
+pasted verbatim in the calls to the custom parser and printer, with
+substitutions for `$_builder` and `$_ctxt`. String literals can be used to
+parameterize custom directives.
 
 #### Optional Groups
 
@@ -1462,7 +1471,7 @@ std::string stringifyMyBitEnum(MyBitEnum symbol) {
   if (2u == (2u & val)) { strs.push_back("Bit1"); }
   if (4u == (4u & val)) { strs.push_back("Bit2"); }
   if (8u == (8u & val)) { strs.push_back("Bit3"); }
-  
+
   return llvm::join(strs, "|");
 }
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 6380104473feb..c4b8add825491 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -45,8 +45,9 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
   let results = (outs AnyTensor:$result);
 
   let assemblyFormat = [{
-    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) attr-dict
-    `:` type($result)
+    custom<DynamicIndexList>($sizes, $static_sizes,
+                               "ShapedType::kDynamicSize")
+    attr-dict `:` type($result)
   }];
 
   let extraClassDeclaration = [{

diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d397d2911c184..1e780121c2428 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1023,11 +1023,14 @@ def MemRef_ReinterpretCastOp
 
   let assemblyFormat = [{
     $source `to` `offset` `` `:`
-    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
+    custom<DynamicIndexList>($offsets, $static_offsets,
+                               "ShapedType::kDynamicStrideOrOffset")
     `` `,` `sizes` `` `:`
-    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) `` `,` `strides`
-    `` `:`
-    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+    custom<DynamicIndexList>($sizes, $static_sizes,
+                               "ShapedType::kDynamicSize")
+    `` `,` `strides` `` `:`
+    custom<DynamicIndexList>($strides, $static_strides,
+                               "ShapedType::kDynamicStrideOrOffset")
     attr-dict `:` type($source) `to` type($result)
   }];
 
@@ -1586,9 +1589,12 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
 
   let assemblyFormat = [{
     $source ``
-    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
-    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
-    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+    custom<DynamicIndexList>($offsets, $static_offsets,
+                               "ShapedType::kDynamicStrideOrOffset")
+    custom<DynamicIndexList>($sizes, $static_sizes,
+                               "ShapedType::kDynamicSize")
+    custom<DynamicIndexList>($strides, $static_strides,
+                               "ShapedType::kDynamicStrideOrOffset")
     attr-dict `:` type($source) `to` type($result)
   }];
 

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index e3f953f8c528c..4095d4e036a5a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -219,11 +219,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
     To disambiguate, the inference helpers `inferCanonicalRankReducedResultType`
     only drop the first unit dimensions, in order:
       e.g. 1x6x1 rank-reduced to 2-D will infer the 6x1 2-D shape, but not 1x6.
-    
+
     Verification however has access to result type and does not need to infer.
-    The verifier calls `isRankReducedType(getSource(), getResult())` to 
+    The verifier calls `isRankReducedType(getSource(), getResult())` to
     determine whether the result type is rank-reduced from the source type.
-    This computes a so-called rank-reduction mask, consisting of dropped unit 
+    This computes a so-called rank-reduction mask, consisting of dropped unit
     dims, to map the rank-reduced type to the source type by dropping ones:
       e.g. 1x6 is a rank-reduced version of 1x6x1 by mask {2}
            6x1 is a rank-reduced version of 1x6x1 by mask {0}
@@ -254,9 +254,12 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
 
   let assemblyFormat = [{
     $source ``
-    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
-    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
-    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+    custom<DynamicIndexList>($offsets, $static_offsets,
+                               "ShapedType::kDynamicStrideOrOffset")
+    custom<DynamicIndexList>($sizes, $static_sizes,
+                               "ShapedType::kDynamicSize")
+    custom<DynamicIndexList>($strides, $static_strides,
+                               "ShapedType::kDynamicStrideOrOffset")
     attr-dict `:` type($source) `to` type($result)
   }];
 
@@ -298,12 +301,12 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
     /// tensor type to the result tensor type by dropping unit dims.
     llvm::Optional<llvm::SmallDenseSet<unsigned>>
     computeRankReductionMask() {
-      return ::mlir::computeRankReductionMask(getSourceType().getShape(), 
+      return ::mlir::computeRankReductionMask(getSourceType().getShape(),
                                               getType().getShape());
     };
 
     /// An extract_slice result type can be inferred, when it is not
-    /// rank-reduced, from the source type and the static representation of 
+    /// rank-reduced, from the source type and the static representation of
     /// offsets, sizes and strides. Special sentinels encode the dynamic case.
     static RankedTensorType inferResultType(
       ShapedType sourceShapedTensorType,
@@ -580,9 +583,12 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
 
   let assemblyFormat = [{
     $source `into` $dest ``
-    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
-    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
-    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+    custom<DynamicIndexList>($offsets, $static_offsets,
+                               "ShapedType::kDynamicStrideOrOffset")
+    custom<DynamicIndexList>($sizes, $static_sizes,
+                               "ShapedType::kDynamicSize")
+    custom<DynamicIndexList>($strides, $static_strides,
+                               "ShapedType::kDynamicStrideOrOffset")
     attr-dict `:` type($source) `into` type($dest)
   }];
 
@@ -608,7 +614,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     RankedTensorType getType() {
       return getResult().getType().cast<RankedTensorType>();
     }
-    
+
     /// The `dest` type is the same as the result type.
     RankedTensorType getDestType() {
       return getType();
@@ -962,8 +968,10 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
   let assemblyFormat = [{
     $source
     (`nofold` $nofold^)?
-    `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
-    `high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
+    `low` `` custom<DynamicIndexList>($low, $static_low,
+                                        "ShapedType::kDynamicSize")
+    `high` `` custom<DynamicIndexList>($high, $static_high,
+                                         "ShapedType::kDynamicSize")
     $region attr-dict `:` type($source) `to` type($result)
   }];
 
@@ -1069,15 +1077,15 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
        // HasParent<"ParallelCombiningOpInterface">
   ]> {
   let summary = [{
-    Specify the tensor slice update of a single thread of a parent 
+    Specify the tensor slice update of a single thread of a parent
     ParallelCombiningOpInterface op.
   }];
   let description = [{
-    The `parallel_insert_slice` yields a subset tensor value to its parent 
+    The `parallel_insert_slice` yields a subset tensor value to its parent
     ParallelCombiningOpInterface. These subset tensor values are aggregated to
-    in some unspecified order into a full tensor value returned by the parent 
-    parallel iterating op. 
-    The `parallel_insert_slice` is one such op allowed in the 
+    in some unspecified order into a full tensor value returned by the parent
+    parallel iterating op.
+    The `parallel_insert_slice` is one such op allowed in the
     ParallelCombiningOpInterface op.
 
     Conflicting writes result in undefined semantics, in that the indices written
@@ -1118,12 +1126,12 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
     into a memref.subview op.
 
     A parallel_insert_slice operation may additionally specify insertion into a
-    tensor of higher rank than the source tensor, along dimensions that are 
+    tensor of higher rank than the source tensor, along dimensions that are
     statically known to be of size 1.
     This rank-altering behavior is not required by the op semantics: this
     flexibility allows to progressively drop unit dimensions while lowering
     between different flavors of ops on that operate on tensors.
-    The rank-altering behavior of tensor.parallel_insert_slice matches the 
+    The rank-altering behavior of tensor.parallel_insert_slice matches the
     rank-reducing behavior of tensor.insert_slice and tensor.extract_slice.
 
     Verification in the rank-reduced case:
@@ -1144,9 +1152,12 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
   );
   let assemblyFormat = [{
     $source `into` $dest ``
-    custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
-    custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
-    custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+    custom<DynamicIndexList>($offsets, $static_offsets,
+                               "ShapedType::kDynamicStrideOrOffset")
+    custom<DynamicIndexList>($sizes, $static_sizes,
+                               "ShapedType::kDynamicSize")
+    custom<DynamicIndexList>($strides, $static_strides,
+                               "ShapedType::kDynamicStrideOrOffset")
     attr-dict `:` type($source) `into` type($dest)
   }];
 
@@ -1194,7 +1205,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
       "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
-  
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index e590be2bbc54b..4bb13ce750607 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -84,73 +84,40 @@ namespace mlir {
 
 /// Printer hook for custom directive in assemblyFormat.
 ///
-///   custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
+///   custom<DynamicIndexList>($values, $integers)
 ///
 /// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
-/// type `I64ArrayAttr`.  for use in in assemblyFormat. Prints a list with
-/// either (1) the static integer value in `integers` if the value is
-/// ShapedType::kDynamicStrideOrOffset or (2) the next value otherwise.  This
-/// allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-void printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &printer,
-                                                 Operation *op,
-                                                 OperandRange values,
-                                                 ArrayAttr integers);
-
-/// Printer hook for custom directive in assemblyFormat.
-///
-///   custom<OperandsOrIntegersSizesList>($values, $integers)
-///
-/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
-/// type `I64ArrayAttr`.  for use in in assemblyFormat. Prints a list with
-/// either (1) the static integer value in `integers` if the value is
-/// ShapedType::kDynamicSize or (2) the next value otherwise.  This
-/// allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-void printOperandsOrIntegersSizesList(OpAsmPrinter &printer, Operation *op,
-                                      OperandRange values, ArrayAttr integers);
-
-/// Pasrer hook for custom directive in assemblyFormat.
-///
-///   custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
-///
-/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
-/// type `I64ArrayAttr`.  for use in in assemblyFormat. Parse a mixed list with
-/// either (1) static integer values or (2) SSA values.  Fill `integers` with
-/// the integer ArrayAttr, where ShapedType::kDynamicStrideOrOffset encodes the
-/// position of SSA values. Add the parsed SSA values to `values` in-order.
-//
-/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
-///   1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
-///   2. `ssa` is filled with "[%arg0, %arg1]".
-ParseResult parseOperandsOrIntegersOffsetsOrStridesList(
-    OpAsmParser &parser,
-    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    ArrayAttr &integers);
+/// type `I64ArrayAttr`. Prints a list with either (1) the static integer value
+/// in `integers` if it is not `dynVal` or (2) the next value in `values`
+/// otherwise. This allows idiomatic printing of mixed value and integer
+/// attributes in a list. E.g. `[%arg0, 7, 42, %arg42]`.
+void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+                           OperandRange values, ArrayAttr integers,
+                           int64_t dynVal);
 
 /// Parser hook for custom directive in assemblyFormat.
 ///
-///   custom<OperandsOrIntegersSizesList>($values, $integers)
+///   custom<DynamicIndexList>($values, $integers)
 ///
 /// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
-/// type `I64ArrayAttr`.  for use in in assemblyFormat. Parse a mixed list with
-/// either (1) static integer values or (2) SSA values.  Fill `integers` with
-/// the integer ArrayAttr, where ShapedType::kDynamicSize encodes the
-/// position of SSA values. Add the parsed SSA values to `values` in-order.
+/// type `I64ArrayAttr`. Parse a mixed list with either (1) static integer
+/// values or (2) SSA values. Fill `integers` with the integer ArrayAttr, where
+/// `dynVal` encodes the position of SSA values. Add the parsed SSA values
+/// to `values` in-order.
 ///
 /// E.g. after parsing "[%arg0, 7, 42, %arg42]":
 ///   1. `integers` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
 ///   2. `values` is filled with "[%arg0, %arg42]".
-ParseResult parseOperandsOrIntegersSizesList(
-    OpAsmParser &parser,
-    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    ArrayAttr &integers);
+ParseResult
+parseDynamicIndexList(OpAsmParser &parser,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      ArrayAttr &integers, int64_t dynVal);
 
 /// Verify that `values` has as many elements as the number of entries in
 /// `attr` for which `isDynamic` evaluates to true.
 LogicalResult verifyListOfOperandsOrIntegers(
     Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
-    ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic);
+    ValueRange values, function_ref<bool(int64_t)> isDynamic);
 
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 328c4141e50fa..be4f202cc0b39 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -987,7 +987,8 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
   auto pdlOperationType = pdl::OperationType::get(parser.getContext());
   if (parser.parseOperand(target) ||
       parser.resolveOperand(target, pdlOperationType, result.operands) ||
-      parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
+      parseDynamicIndexList(parser, dynamicSizes, staticSizes,
+                            ShapedType::kDynamicSize) ||
       parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
       parser.parseOptionalAttrDict(result.attributes))
     return ParseResult::failure();
@@ -1001,8 +1002,8 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
 
 void TileOp::print(OpAsmPrinter &p) {
   p << ' ' << getTarget();
-  printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
-                                   getStaticSizes());
+  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
+                        ShapedType::kDynamicSize);
   p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
 }
 

diff  --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index a593ec6a2a240..89ebd81271721 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -70,45 +70,29 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
   return success();
 }
 
-template <int64_t dynVal>
-static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
-                                            ArrayAttr arrayAttr) {
-  p << '[';
-  if (arrayAttr.empty()) {
-    p << "]";
+void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+                                 OperandRange values, ArrayAttr integers,
+                                 int64_t dynVal) {
+  printer << '[';
+  if (integers.empty()) {
+    printer << "]";
     return;
   }
   unsigned idx = 0;
-  llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
+  llvm::interleaveComma(integers, printer, [&](Attribute a) {
     int64_t val = a.cast<IntegerAttr>().getInt();
     if (val == dynVal)
-      p << values[idx++];
+      printer << values[idx++];
     else
-      p << val;
+      printer << val;
   });
-  p << ']';
+  printer << ']';
 }
 
-void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
-                                                       Operation *op,
-                                                       OperandRange values,
-                                                       ArrayAttr integers) {
-  return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
-      p, values, integers);
-}
-
-void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
-                                            OperandRange values,
-                                            ArrayAttr integers) {
-  return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
-                                                                   integers);
-}
-
-template <int64_t dynVal>
-static ParseResult parseOperandsOrIntegersImpl(
+ParseResult mlir::parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    ArrayAttr &integers) {
+    ArrayAttr &integers, int64_t dynVal) {
   if (failed(parser.parseLSquare()))
     return failure();
   // 0-D.
@@ -142,22 +126,6 @@ static ParseResult parseOperandsOrIntegersImpl(
   return success();
 }
 
-ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
-    OpAsmParser &parser,
-    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    ArrayAttr &integers) {
-  return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
-      parser, values, integers);
-}
-
-ParseResult mlir::parseOperandsOrIntegersSizesList(
-    OpAsmParser &parser,
-    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    ArrayAttr &integers) {
-  return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
-                                                               integers);
-}
-
 bool mlir::detail::sameOffsetsSizesAndStrides(
     OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
     llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 960c3870a9c0d..937c35d503769 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -555,3 +555,19 @@ def TypeK : TestType<"TestM"> {
   let mnemonic = "type_k";
   let assemblyFormat = "$a";
 }
+
+// TYPE-LABEL: ::mlir::Type TestNType::parse
+// TYPE: parseFoo(
+// TYPE-NEXT: _result_a,
+// TYPE-NEXT: 1);
+
+// TYPE-LABEL: void TestNType::print
+// TYPE: printFoo(
+// TYPE-NEXT: getA(),
+// TYPE-NEXT: 1);
+
+def TypeL : TestType<"TestN"> {
+  let parameters = (ins "int":$a);
+  let mnemonic = "type_l";
+  let assemblyFormat = [{ custom<Foo>($a, "1") }];
+}

diff  --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td
index d165587e079ca..aae9ed45b92d1 100644
--- a/mlir/test/mlir-tblgen/op-format-invalid.td
+++ b/mlir/test/mlir-tblgen/op-format-invalid.td
@@ -403,6 +403,13 @@ def OptionalInvalidP : TestFormat_Op<[{
   ($arg^):(`test`)
 }]>, Arguments<(ins Variadic<I64>:$arg)>;
 
+//===----------------------------------------------------------------------===//
+// Strings
+//===----------------------------------------------------------------------===//
+
+// CHECK: error: strings may only be used as 'custom' directive arguments
+def StringInvalidA : TestFormat_Op<[{ "foo" }]>;
+
 //===----------------------------------------------------------------------===//
 // Variables
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 553c996190a79..6d21afdfe6de7 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -135,6 +135,13 @@ def OptionalValidA : TestFormat_Op<[{
   (` ` `` $arg^)? attr-dict
 }]>, Arguments<(ins Optional<I32>:$arg)>;
 
+//===----------------------------------------------------------------------===//
+// Strings
+//===----------------------------------------------------------------------===//
+
+// CHECK-NOT: error
+def StringInvalidA : TestFormat_Op<[{ custom<Foo>("foo") attr-dict }]>;
+
 //===----------------------------------------------------------------------===//
 // Variables
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td
new file mode 100644
index 0000000000000..1fdb485e5e56c
--- /dev/null
+++ b/mlir/test/mlir-tblgen/op-format.td
@@ -0,0 +1,42 @@
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def TestDialect : Dialect {
+  let name = "test";
+}
+class TestFormat_Op<string fmt, list<Trait> traits = []>
+    : Op<TestDialect, "format_op", traits> {
+  let assemblyFormat = fmt;
+}
+
+//===----------------------------------------------------------------------===//
+// Directives
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// custom
+
+// CHECK-LABEL: CustomStringLiteralA::parse
+// CHECK: parseFoo({{.*}}, parser.getBuilder().getI1Type())
+// CHECK-LABEL: CustomStringLiteralA::print
+// CHECK: printFoo({{.*}}, parser.getBuilder().getI1Type())
+def CustomStringLiteralA : TestFormat_Op<[{
+  custom<Foo>("$_builder.getI1Type()") attr-dict
+}]>;
+
+// CHECK-LABEL: CustomStringLiteralB::parse
+// CHECK: parseFoo({{.*}}, IndexType::get(parser.getContext()))
+// CHECK-LABEL: CustomStringLiteralB::print
+// CHECK: printFoo({{.*}}, IndexType::get(parser.getContext()))
+def CustomStringLiteralB : TestFormat_Op<[{
+  custom<Foo>("IndexType::get($_ctxt)") attr-dict
+}]>;
+
+// CHECK-LABEL: CustomStringLiteralC::parse
+// CHECK: parseFoo({{.*}}, parser.getBuilder().getStringAttr("foo"))
+// CHECK-LABEL: CustomStringLiteralC::print
+// CHECK: printFoo({{.*}}, parser.getBuilder().getStringAttr("foo"))
+def CustomStringLiteralC : TestFormat_Op<[{
+  custom<Foo>("$_builder.getStringAttr(\"foo\")") attr-dict
+}]>;

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 752556fc129cb..c249e23e531e5 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -629,14 +629,12 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
   os.indent();
   for (FormatElement *arg : el->getArguments()) {
     os << ",\n";
-    FormatElement *param;
-    if (auto *ref = dyn_cast<RefDirective>(arg)) {
-      os << "*";
-      param = ref->getArg();
-    } else {
-      param = arg;
-    }
-    os << "_result_" << cast<ParameterElement>(param)->getName();
+    if (auto *param = dyn_cast<ParameterElement>(arg))
+      os << "_result_" << param->getName();
+    else if (auto *ref = dyn_cast<RefDirective>(arg))
+      os << "*_result_" << cast<ParameterElement>(ref->getArg())->getName();
+    else
+      os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
   }
   os.unindent() << ");\n";
   os << "if (::mlir::failed(odsCustomResult)) return {};\n";
@@ -845,11 +843,15 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
   os << tgfmt("print$0($_printer", &ctx, el->getName());
   os.indent();
   for (FormatElement *arg : el->getArguments()) {
-    FormatElement *param = arg;
-    if (auto *ref = dyn_cast<RefDirective>(arg))
-      param = ref->getArg();
-    os << ",\n"
-       << cast<ParameterElement>(param)->getParam().getAccessorName() << "()";
+    os << ",\n";
+    if (auto *param = dyn_cast<ParameterElement>(arg)) {
+      os << param->getParam().getAccessorName() << "()";
+    } else if (auto *ref = dyn_cast<RefDirective>(arg)) {
+      os << cast<ParameterElement>(ref->getArg())->getParam().getAccessorName()
+         << "()";
+    } else {
+      os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
+    }
   }
   os.unindent() << ");\n";
 }

diff  --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index 8d08340800c91..029293265a9eb 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -129,6 +129,8 @@ FormatToken FormatLexer::lexToken() {
     return lexLiteral(tokStart);
   case '$':
     return lexVariable(tokStart);
+  case '"':
+    return lexString(tokStart);
   }
 }
 
@@ -153,6 +155,17 @@ FormatToken FormatLexer::lexVariable(const char *tokStart) {
   return formToken(FormatToken::variable, tokStart);
 }
 
+FormatToken FormatLexer::lexString(const char *tokStart) {
+  // Lex until another quote, respecting escapes.
+  bool escape = false;
+  while (const char curChar = *curPtr++) {
+    if (!escape && curChar == '"')
+      return formToken(FormatToken::string, tokStart);
+    escape = curChar == '\\';
+  }
+  return emitError(curPtr - 1, "unexpected end of file in string");
+}
+
 FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
   while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
@@ -212,6 +225,8 @@ FailureOr<std::vector<FormatElement *>> FormatParser::parse() {
 FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) {
   if (curToken.is(FormatToken::literal))
     return parseLiteral(ctx);
+  if (curToken.is(FormatToken::string))
+    return parseString(ctx);
   if (curToken.is(FormatToken::variable))
     return parseVariable(ctx);
   if (curToken.isKeyword())
@@ -253,6 +268,28 @@ FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) {
   return create<LiteralElement>(value);
 }
 
+FailureOr<FormatElement *> FormatParser::parseString(Context ctx) {
+  FormatToken tok = curToken;
+  SMLoc loc = tok.getLoc();
+  consumeToken();
+
+  if (ctx != CustomDirectiveContext) {
+    return emitError(
+        loc, "strings may only be used as 'custom' directive arguments");
+  }
+  // Escape the string.
+  std::string value;
+  StringRef contents = tok.getSpelling().drop_front().drop_back();
+  value.reserve(contents.size());
+  bool escape = false;
+  for (char c : contents) {
+    escape = c == '\\';
+    if (!escape)
+      value.push_back(c);
+  }
+  return create<StringElement>(std::move(value));
+}
+
 FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) {
   FormatToken tok = curToken;
   SMLoc loc = tok.getLoc();
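`FormatLexer::lexString` scans forward to the next unescaped quote. A standalone version over `std::string_view` is sketched below. `lexStringToken` is a hypothetical name. There is one deliberate difference from the hunk. The hunk sets `escape = curChar == '\\'`, so in `"a\\"` the second backslash re-arms the escape and the closing quote is skipped. The sketch uses `escape = !escape && c == '\\'` so that an escaped backslash does not also escape the quote after it.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string_view>

// Returns the length of the string token starting at src[start] (which must
// be '"'), including both quotes, or std::nullopt if the input ends first.
std::optional<std::size_t> lexStringToken(std::string_view src,
                                          std::size_t start) {
  assert(start < src.size() && src[start] == '"');
  bool escape = false;
  for (std::size_t i = start + 1; i < src.size(); ++i) {
    char c = src[i];
    if (!escape && c == '"')
      return i - start + 1;
    // A backslash escapes the next character unless it is itself escaped.
    escape = !escape && c == '\\';
  }
  return std::nullopt;  // unterminated string
}
```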

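`FormatParser::parseString` strips the surrounding quotes and then drops every backslash (`escape = c == '\\'` followed by `if (!escape) value.push_back(c)`). As a result, an escaped backslash `\\` disappears entirely instead of becoming a single `\`. The sketch below instead keeps the character that follows each escaping backslash, which is the more conventional unescape. `unescapeStringToken` is a hypothetical helper, not part of FormatGen.

```cpp
#include <cassert>
#include <string>
#include <string_view>

// Strips the quotes from a lexed string token and removes escaping
// backslashes, keeping the character that follows each one.
std::string unescapeStringToken(std::string_view spelling) {
  assert(spelling.size() >= 2 && spelling.front() == '"' &&
         spelling.back() == '"');
  std::string_view contents = spelling.substr(1, spelling.size() - 2);
  std::string value;
  value.reserve(contents.size());
  bool escape = false;
  for (char c : contents) {
    if (!escape && c == '\\') {
      escape = true;  // drop the backslash itself
      continue;
    }
    value.push_back(c);  // ordinary or escaped character
    escape = false;
  }
  return value;
}
```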
diff  --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index f180f2da48e8d..cc57ff9ee8719 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -78,6 +78,7 @@ class FormatToken {
     identifier,
     literal,
     variable,
+    string,
   };
 
   FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
@@ -130,10 +131,11 @@ class FormatLexer {
   /// Return the next character in the stream.
   int getNextChar();
 
-  /// Lex an identifier, literal, or variable.
+  /// Lex an identifier, literal, variable, or string.
   FormatToken lexIdentifier(const char *tokStart);
   FormatToken lexLiteral(const char *tokStart);
   FormatToken lexVariable(const char *tokStart);
+  FormatToken lexString(const char *tokStart);
 
   /// Create a token with the current pointer and a start pointer.
   FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
@@ -163,7 +165,7 @@ class FormatElement {
   virtual ~FormatElement();
 
   // The top-level kinds of format elements.
-  enum Kind { Literal, Variable, Whitespace, Directive, Optional };
+  enum Kind { Literal, String, Variable, Whitespace, Directive, Optional };
 
   /// Support LLVM-style RTTI.
   static bool classof(const FormatElement *el) { return true; }
@@ -212,6 +214,20 @@ class LiteralElement : public FormatElementBase<FormatElement::Literal> {
   StringRef spelling;
 };
 
+/// This class represents a raw string that can contain arbitrary C++ code.
+class StringElement : public FormatElementBase<FormatElement::String> {
+public:
+  /// Create a string element with the given contents.
+  explicit StringElement(std::string value) : value(std::move(value)) {}
+
+  /// Get the value of the string element.
+  StringRef getValue() const { return value; }
+
+private:
+  /// The contents of the string.
+  std::string value;
+};
+
 /// This class represents a variable element. A variable refers to some part of
 /// the object being parsed, e.g. an attribute or operand on an operation or a
 /// parameter on an attribute.
@@ -447,6 +463,8 @@ class FormatParser {
   FailureOr<FormatElement *> parseElement(Context ctx);
   /// Parse a literal.
   FailureOr<FormatElement *> parseLiteral(Context ctx);
+  /// Parse a string.
+  FailureOr<FormatElement *> parseString(Context ctx);
   /// Parse a variable.
   FailureOr<FormatElement *> parseVariable(Context ctx);
   /// Parse a directive.
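`StringElement` joins the element hierarchy through a new `String` entry in `FormatElement::Kind` and through `FormatElementBase<FormatElement::String>`. In the usual LLVM style, this is what lets the generators call `dyn_cast<StringElement>`. A dependency-free model of that kind-based RTTI pattern follows. `Element`, `ElementBase`, `StringElem`, `LiteralElem`, and `dynCast` are simplified stand-ins assumed for illustration, not the LLVM or mlir-tblgen types.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Each element records its kind; subclasses check it in classof().
class Element {
public:
  enum Kind { Literal, String, Variable };
  explicit Element(Kind kind) : kind(kind) {}
  virtual ~Element() = default;
  Kind getKind() const { return kind; }

private:
  Kind kind;
};

// Mirrors the FormatElementBase<Kind> pattern: fixes the kind per subclass.
template <Element::Kind K>
class ElementBase : public Element {
public:
  ElementBase() : Element(K) {}
  static bool classof(const Element *el) { return el->getKind() == K; }
};

class StringElem : public ElementBase<Element::String> {
public:
  explicit StringElem(std::string value) : value(std::move(value)) {}
  const std::string &getValue() const { return value; }

private:
  std::string value;
};

class LiteralElem : public ElementBase<Element::Literal> {
public:
  explicit LiteralElem(std::string spelling) : spelling(std::move(spelling)) {}
  const std::string &getSpelling() const { return spelling; }

private:
  std::string spelling;
};

// Minimal dyn_cast: returns nullptr when the kind does not match.
template <typename To>
const To *dynCast(const Element *el) {
  assert(el && "dynCast on a null element");
  return To::classof(el) ? static_cast<const To *>(el) : nullptr;
}
```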

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 55a4a8fec3672..1e03bad9bd470 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -916,6 +916,13 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
       body << llvm::formatv("{0}Type", listName);
     else
       body << formatv("{0}RawTypes[0]", listName);
+
+  } else if (auto *string = dyn_cast<StringElement>(param)) {
+    FmtContext ctx;
+    ctx.withBuilder("parser.getBuilder()");
+    ctx.addSubst("_ctxt", "parser.getContext()");
+    body << tgfmt(string->getValue(), &ctx);
+
   } else {
     llvm_unreachable("unknown custom directive parameter");
   }
@@ -1715,6 +1722,13 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
       body << llvm::formatv("({0}() ? {0}().getType() : Type())", name);
     else
       body << name << "().getType()";
+
+  } else if (auto *string = dyn_cast<StringElement>(element)) {
+    FmtContext ctx;
+    ctx.withBuilder("parser.getBuilder()");
+    ctx.addSubst("_ctxt", "parser.getContext()");
+    body << tgfmt(string->getValue(), &ctx);
+
   } else {
     llvm_unreachable("unknown custom directive parameter");
   }
@@ -2826,8 +2840,9 @@ OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
 LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
     SMLoc loc, ArrayRef<FormatElement *> arguments) {
   for (FormatElement *argument : arguments) {
-    if (!isa<RefDirective, TypeDirective, AttrDictDirective, AttributeVariable,
-             OperandVariable, RegionVariable, SuccessorVariable>(argument)) {
+    if (!isa<StringElement, RefDirective, TypeDirective, AttrDictDirective,
+             AttributeVariable, OperandVariable, RegionVariable,
+             SuccessorVariable>(argument)) {
       // TODO: FormatElement should have location info attached.
       return emitError(loc, "only variables and types may be used as "
                             "parameters to a custom directive");
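In the OpFormatGen hunks, a string argument is expanded with `tgfmt`, which binds `$_builder` to `parser.getBuilder()` and `$_ctxt` to `parser.getContext()`. The printer hunk reuses those parser bindings, although a generated `print` method has no `parser` in scope. As written, a string that references either placeholder therefore appears unusable on the print side. The sketch below models only the named-placeholder substitution. `substitutePlaceholders` is hypothetical and much simpler than `tgfmt`; for example, it does not handle positional `$0` arguments.

```cpp
#include <cassert>
#include <cctype>
#include <map>
#include <string>

// Replaces `$name` placeholders using `subst`; unknown placeholders (and a
// lone '$') are copied through unchanged.
std::string substitutePlaceholders(
    const std::string &fmt, const std::map<std::string, std::string> &subst) {
  std::string out;
  std::string::size_type i = 0;
  while (i < fmt.size()) {
    if (fmt[i] != '$') {
      out += fmt[i++];
      continue;
    }
    // Collect the identifier after '$' (letters, digits, underscores).
    std::string::size_type j = i + 1;
    while (j < fmt.size() &&
           (std::isalnum(static_cast<unsigned char>(fmt[j])) || fmt[j] == '_'))
      ++j;
    auto it = subst.find(fmt.substr(i + 1, j - i - 1));
    out += it != subst.end() ? it->second : fmt.substr(i, j - i);
    i = j;
  }
  return out;
}
```

On the print side, the map would need values that are valid inside the op's `print` method, such as expressions built from the op's `getContext()`.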
