[Mlir-commits] [mlir] a2ad3ec - [mlir][ods] Support string literals in `custom` directives
Jeff Niu
llvmlistbot at llvm.org
Fri Aug 12 17:55:16 PDT 2022
Author: Jeff Niu
Date: 2022-08-12T20:55:11-04:00
New Revision: a2ad3ec7ac6279370630ec05d6426c97f4cf50d6
URL: https://github.com/llvm/llvm-project/commit/a2ad3ec7ac6279370630ec05d6426c97f4cf50d6
DIFF: https://github.com/llvm/llvm-project/commit/a2ad3ec7ac6279370630ec05d6426c97f4cf50d6.diff
LOG: [mlir][ods] Support string literals in `custom` directives
This patch adds support for string literals as `custom` directive
arguments. This can be useful for re-using custom parsers and printers
when arguments have a known value. For example:
```
ParseResult parseTypedAttr(AsmParser &parser, Attribute &attr, Type type) {
  return parser.parseAttribute(attr, type);
}
void printTypedAttr(AsmPrinter &printer, Attribute attr, Type type) {
  printer.printAttributeWithoutType(attr);
}
```
And in TableGen:
```
def FooOp : ... {
  let arguments = (ins AnyAttr:$a);
  let assemblyFormat = [{ custom<TypedAttr>($a, "$_builder.getI1Type()")
                          attr-dict }];
}
def BarOp : ... {
  let arguments = (ins AnyAttr:$a);
  let assemblyFormat = [{ custom<TypedAttr>($a, "$_builder.getIndexType()")
                          attr-dict }];
}
```
This avoids writing two separate sets of custom parsers and printers.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D131603
Added:
mlir/test/mlir-tblgen/op-format.td
Modified:
mlir/docs/AttributesAndTypes.md
mlir/docs/OpDefinitions.md
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/test/mlir-tblgen/attr-or-type-format.td
mlir/test/mlir-tblgen/op-format-invalid.td
mlir/test/mlir-tblgen/op-format-spec.td
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
mlir/tools/mlir-tblgen/FormatGen.cpp
mlir/tools/mlir-tblgen/FormatGen.h
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md
index caf0078dd2b4b..748b69fd86ee0 100644
--- a/mlir/docs/AttributesAndTypes.md
+++ b/mlir/docs/AttributesAndTypes.md
@@ -895,6 +895,19 @@ void printStringParam(AsmPrinter &printer, StringRef value);
The custom parser is considered to have failed if it returns failure or if any
bound parameters have failure values afterwards.
+A string of C++ code can be used as a `custom` directive argument. When
+generating the custom parser and printer call, the string is pasted as a
+function argument. For example, `parseBar` and `printBar` can be re-used with
+a constant integer:
+
+```tablegen
+let parameters = (ins "int":$bar);
+let assemblyFormat = [{ custom<Bar>($bar, "1") }];
+```
+
+The string is pasted verbatim but with substitutions for `$_builder` and
+`$_ctxt`. String literals can be used to parameterize custom directives.
+
### Verification
If the `genVerifyDecl` field is set, additional verification methods are
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index eea9ad79bf54f..51d1076bbceb1 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -768,9 +768,9 @@ when generating the C++ code for the format. The `UserDirective` is an
identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)`
would result in calls to `parseMyDirective` and `printMyDirective` within the
parser and printer respectively. `Params` may be any combination of variables
-(i.e. Attribute, Operand, Successor, etc.), type directives, and `attr-dict`.
-The type directives must refer to a variable, but that variable need not also be
-a parameter to the custom directive.
+(i.e. Attribute, Operand, Successor, etc.), type directives, `attr-dict`, and
+strings of C++ code. The type directives must refer to a variable, but that
+variable need not also be a parameter to the custom directive.
The arguments to the `parse<UserDirective>` method are firstly a reference to
the `OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters
@@ -837,7 +837,16 @@ declarative parameter to `print` method argument is detailed below:
- VariadicOfVariadic: `TypeRangeRange`
* `attr-dict` Directive: `DictionaryAttr`
-When a variable is optional, the provided value may be null.
+When a variable is optional, the provided value may be null. When a variable is
+referenced in a custom directive parameter using `ref`, it is passed in by
+value. Referenced variables are passed to `print<UserDirective>` in the same way
+as bound variables, and referenced variables are passed to `parse<UserDirective>`
+in the same way as they are passed to the printer.
+
+A custom directive can take a string of C++ code as a parameter. The code is
+pasted verbatim into the calls to the custom parser and printer, with
+substitutions for `$_builder` and `$_ctxt`. String literals can be used to
+parameterize custom directives.
#### Optional Groups
@@ -1462,7 +1471,7 @@ std::string stringifyMyBitEnum(MyBitEnum symbol) {
if (2u == (2u & val)) { strs.push_back("Bit1"); }
if (4u == (4u & val)) { strs.push_back("Bit2"); }
if (8u == (8u & val)) { strs.push_back("Bit3"); }
-
+
return llvm::join(strs, "|");
}
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 6380104473feb..c4b8add825491 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -45,8 +45,9 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
let results = (outs AnyTensor:$result);
let assemblyFormat = [{
- custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) attr-dict
- `:` type($result)
+ custom<DynamicIndexList>($sizes, $static_sizes,
+ "ShapedType::kDynamicSize")
+ attr-dict `:` type($result)
}];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d397d2911c184..1e780121c2428 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1023,11 +1023,14 @@ def MemRef_ReinterpretCastOp
let assemblyFormat = [{
$source `to` `offset` `` `:`
- custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
+ custom<DynamicIndexList>($offsets, $static_offsets,
+ "ShapedType::kDynamicStrideOrOffset")
`` `,` `sizes` `` `:`
- custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) `` `,` `strides`
- `` `:`
- custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+ custom<DynamicIndexList>($sizes, $static_sizes,
+ "ShapedType::kDynamicSize")
+ `` `,` `strides` `` `:`
+ custom<DynamicIndexList>($strides, $static_strides,
+ "ShapedType::kDynamicStrideOrOffset")
attr-dict `:` type($source) `to` type($result)
}];
@@ -1586,9 +1589,12 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
let assemblyFormat = [{
$source ``
- custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
- custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
- custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+ custom<DynamicIndexList>($offsets, $static_offsets,
+ "ShapedType::kDynamicStrideOrOffset")
+ custom<DynamicIndexList>($sizes, $static_sizes,
+ "ShapedType::kDynamicSize")
+ custom<DynamicIndexList>($strides, $static_strides,
+ "ShapedType::kDynamicStrideOrOffset")
attr-dict `:` type($source) `to` type($result)
}];
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index e3f953f8c528c..4095d4e036a5a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -219,11 +219,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
To disambiguate, the inference helpers `inferCanonicalRankReducedResultType`
only drop the first unit dimensions, in order:
e.g. 1x6x1 rank-reduced to 2-D will infer the 6x1 2-D shape, but not 1x6.
-
+
Verification however has access to result type and does not need to infer.
- The verifier calls `isRankReducedType(getSource(), getResult())` to
+ The verifier calls `isRankReducedType(getSource(), getResult())` to
determine whether the result type is rank-reduced from the source type.
- This computes a so-called rank-reduction mask, consisting of dropped unit
+ This computes a so-called rank-reduction mask, consisting of dropped unit
dims, to map the rank-reduced type to the source type by dropping ones:
e.g. 1x6 is a rank-reduced version of 1x6x1 by mask {2}
6x1 is a rank-reduced version of 1x6x1 by mask {0}
@@ -254,9 +254,12 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
let assemblyFormat = [{
$source ``
- custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
- custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
- custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+ custom<DynamicIndexList>($offsets, $static_offsets,
+ "ShapedType::kDynamicStrideOrOffset")
+ custom<DynamicIndexList>($sizes, $static_sizes,
+ "ShapedType::kDynamicSize")
+ custom<DynamicIndexList>($strides, $static_strides,
+ "ShapedType::kDynamicStrideOrOffset")
attr-dict `:` type($source) `to` type($result)
}];
@@ -298,12 +301,12 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
/// tensor type to the result tensor type by dropping unit dims.
llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask() {
- return ::mlir::computeRankReductionMask(getSourceType().getShape(),
+ return ::mlir::computeRankReductionMask(getSourceType().getShape(),
getType().getShape());
};
/// An extract_slice result type can be inferred, when it is not
- /// rank-reduced, from the source type and the static representation of
+ /// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
static RankedTensorType inferResultType(
ShapedType sourceShapedTensorType,
@@ -580,9 +583,12 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
let assemblyFormat = [{
$source `into` $dest ``
- custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
- custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
- custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+ custom<DynamicIndexList>($offsets, $static_offsets,
+ "ShapedType::kDynamicStrideOrOffset")
+ custom<DynamicIndexList>($sizes, $static_sizes,
+ "ShapedType::kDynamicSize")
+ custom<DynamicIndexList>($strides, $static_strides,
+ "ShapedType::kDynamicStrideOrOffset")
attr-dict `:` type($source) `into` type($dest)
}];
@@ -608,7 +614,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
RankedTensorType getType() {
return getResult().getType().cast<RankedTensorType>();
}
-
+
/// The `dest` type is the same as the result type.
RankedTensorType getDestType() {
return getType();
@@ -962,8 +968,10 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
let assemblyFormat = [{
$source
(`nofold` $nofold^)?
- `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
- `high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
+ `low` `` custom<DynamicIndexList>($low, $static_low,
+ "ShapedType::kDynamicSize")
+ `high` `` custom<DynamicIndexList>($high, $static_high,
+ "ShapedType::kDynamicSize")
$region attr-dict `:` type($source) `to` type($result)
}];
@@ -1069,15 +1077,15 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
// HasParent<"ParallelCombiningOpInterface">
]> {
let summary = [{
- Specify the tensor slice update of a single thread of a parent
+ Specify the tensor slice update of a single thread of a parent
ParallelCombiningOpInterface op.
}];
let description = [{
- The `parallel_insert_slice` yields a subset tensor value to its parent
+ The `parallel_insert_slice` yields a subset tensor value to its parent
ParallelCombiningOpInterface. These subset tensor values are aggregated to
- in some unspecified order into a full tensor value returned by the parent
- parallel iterating op.
- The `parallel_insert_slice` is one such op allowed in the
+ in some unspecified order into a full tensor value returned by the parent
+ parallel iterating op.
+ The `parallel_insert_slice` is one such op allowed in the
ParallelCombiningOpInterface op.
Conflicting writes result in undefined semantics, in that the indices written
@@ -1118,12 +1126,12 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
into a memref.subview op.
A parallel_insert_slice operation may additionally specify insertion into a
- tensor of higher rank than the source tensor, along dimensions that are
+ tensor of higher rank than the source tensor, along dimensions that are
statically known to be of size 1.
This rank-altering behavior is not required by the op semantics: this
flexibility allows to progressively drop unit dimensions while lowering
between different flavors of ops on that operate on tensors.
- The rank-altering behavior of tensor.parallel_insert_slice matches the
+ The rank-altering behavior of tensor.parallel_insert_slice matches the
rank-reducing behavior of tensor.insert_slice and tensor.extract_slice.
Verification in the rank-reduced case:
@@ -1144,9 +1152,12 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
);
let assemblyFormat = [{
$source `into` $dest ``
- custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
- custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
- custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+ custom<DynamicIndexList>($offsets, $static_offsets,
+ "ShapedType::kDynamicStrideOrOffset")
+ custom<DynamicIndexList>($sizes, $static_sizes,
+ "ShapedType::kDynamicSize")
+ custom<DynamicIndexList>($strides, $static_strides,
+ "ShapedType::kDynamicStrideOrOffset")
attr-dict `:` type($source) `into` type($dest)
}];
@@ -1194,7 +1205,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
];
-
+
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index e590be2bbc54b..4bb13ce750607 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -84,73 +84,40 @@ namespace mlir {
/// Printer hook for custom directive in assemblyFormat.
///
-/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
+/// custom<DynamicIndexList>($values, $integers)
///
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
-/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
-/// either (1) the static integer value in `integers` if the value is
-/// ShapedType::kDynamicStrideOrOffset or (2) the next value otherwise. This
-/// allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-void printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &printer,
- Operation *op,
- OperandRange values,
- ArrayAttr integers);
-
-/// Printer hook for custom directive in assemblyFormat.
-///
-/// custom<OperandsOrIntegersSizesList>($values, $integers)
-///
-/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
-/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
-/// either (1) the static integer value in `integers` if the value is
-/// ShapedType::kDynamicSize or (2) the next value otherwise. This
-/// allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-void printOperandsOrIntegersSizesList(OpAsmPrinter &printer, Operation *op,
- OperandRange values, ArrayAttr integers);
-
-/// Pasrer hook for custom directive in assemblyFormat.
-///
-/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
-///
-/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
-/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
-/// either (1) static integer values or (2) SSA values. Fill `integers` with
-/// the integer ArrayAttr, where ShapedType::kDynamicStrideOrOffset encodes the
-/// position of SSA values. Add the parsed SSA values to `values` in-order.
-//
-/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
-/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
-/// 2. `ssa` is filled with "[%arg0, %arg1]".
-ParseResult parseOperandsOrIntegersOffsetsOrStridesList(
- OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- ArrayAttr &integers);
+/// type `I64ArrayAttr`. Prints a list with either (1) the static integer value
+/// in `integers` if it is not `dynVal` or (2) the next value otherwise. This
+/// idiomatic printing of mixed value and integer attributes in a list. E.g.
+/// `[%arg0, 7, 42, %arg42]`.
+void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+ OperandRange values, ArrayAttr integers,
+ int64_t dynVal);
/// Parser hook for custom directive in assemblyFormat.
///
-/// custom<OperandsOrIntegersSizesList>($values, $integers)
+/// custom<DynamicIndexList>($values, $integers)
///
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
-/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
-/// either (1) static integer values or (2) SSA values. Fill `integers` with
-/// the integer ArrayAttr, where ShapedType::kDynamicSize encodes the
-/// position of SSA values. Add the parsed SSA values to `values` in-order.
+/// type `I64ArrayAttr`. Parse a mixed list with either (1) static integer
+/// values or (2) SSA values. Fill `integers` with the integer ArrayAttr, where
+/// `dynVal` encodes the position of SSA values. Add the parsed SSA values
+/// to `values` in-order.
//
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
/// 1. `integers` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
/// 2. `values` is filled with "[%arg0, %arg42]".
-ParseResult parseOperandsOrIntegersSizesList(
- OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- ArrayAttr &integers);
+ParseResult
+parseDynamicIndexList(OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ ArrayAttr &integers, int64_t dynVal);
/// Verify that `values` has as many elements as the number of entries in
/// `attr` for which `isDynamic` evaluates to true.
LogicalResult verifyListOfOperandsOrIntegers(
Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
- ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic);
+ ValueRange values, function_ref<bool(int64_t)> isDynamic);
} // namespace mlir
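The printing contract documented above for `printDynamicIndexList` can be modeled without MLIR. The following is a minimal sketch under stated assumptions, not the upstream implementation: `printMixedIndexList` is a hypothetical name, `std::vector<int64_t>` stands in for the `I64ArrayAttr`, and `std::vector<std::string>` stands in for the SSA operands. It illustrates why taking the sentinel as an argument lets one function serve both `ShapedType::kDynamicSize` and `ShapedType::kDynamicStrideOrOffset`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for printDynamicIndexList. `integers` plays the role
// of the I64ArrayAttr and `values` the role of the SSA operand names.
std::string printMixedIndexList(const std::vector<int64_t> &integers,
                                const std::vector<std::string> &values,
                                int64_t dynVal) {
  std::ostringstream os;
  os << '[';
  std::size_t idx = 0;
  for (std::size_t i = 0; i < integers.size(); ++i) {
    if (i != 0)
      os << ", ";
    // A sentinel entry consumes the next SSA value; any other entry is
    // printed as a static integer.
    if (integers[i] == dynVal)
      os << values.at(idx++);
    else
      os << integers[i];
  }
  os << ']';
  return os.str();
}
```

For example, `printMixedIndexList({-1, 7, 42, -1}, {"%arg0", "%arg42"}, -1)` yields `[%arg0, 7, 42, %arg42]`, while with a different sentinel such as `-2` the same `-1` entries would be printed as static integers.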
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 328c4141e50fa..be4f202cc0b39 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -987,7 +987,8 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
auto pdlOperationType = pdl::OperationType::get(parser.getContext());
if (parser.parseOperand(target) ||
parser.resolveOperand(target, pdlOperationType, result.operands) ||
- parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
+ parseDynamicIndexList(parser, dynamicSizes, staticSizes,
+ ShapedType::kDynamicSize) ||
parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
parser.parseOptionalAttrDict(result.attributes))
return ParseResult::failure();
@@ -1001,8 +1002,8 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
void TileOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
- printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
- getStaticSizes());
+ printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
+ ShapedType::kDynamicSize);
p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index a593ec6a2a240..89ebd81271721 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -70,45 +70,29 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
return success();
}
-template <int64_t dynVal>
-static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
- ArrayAttr arrayAttr) {
- p << '[';
- if (arrayAttr.empty()) {
- p << "]";
+void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+ OperandRange values, ArrayAttr integers,
+ int64_t dynVal) {
+ printer << '[';
+ if (integers.empty()) {
+ printer << "]";
return;
}
unsigned idx = 0;
- llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
+ llvm::interleaveComma(integers, printer, [&](Attribute a) {
int64_t val = a.cast<IntegerAttr>().getInt();
if (val == dynVal)
- p << values[idx++];
+ printer << values[idx++];
else
- p << val;
+ printer << val;
});
- p << ']';
+ printer << ']';
}
-void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
- Operation *op,
- OperandRange values,
- ArrayAttr integers) {
- return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
- p, values, integers);
-}
-
-void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
- OperandRange values,
- ArrayAttr integers) {
- return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
- integers);
-}
-
-template <int64_t dynVal>
-static ParseResult parseOperandsOrIntegersImpl(
+ParseResult mlir::parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- ArrayAttr &integers) {
+ ArrayAttr &integers, int64_t dynVal) {
if (failed(parser.parseLSquare()))
return failure();
// 0-D.
@@ -142,22 +126,6 @@ static ParseResult parseOperandsOrIntegersImpl(
return success();
}
-ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
- OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- ArrayAttr &integers) {
- return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
- parser, values, integers);
-}
-
-ParseResult mlir::parseOperandsOrIntegersSizesList(
- OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- ArrayAttr &integers) {
- return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
- integers);
-}
-
bool mlir::detail::sameOffsetsSizesAndStrides(
OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
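The parsing side can be sketched the same way. This is a simplified, hypothetical model of the `parseDynamicIndexList` contract (the struct `MixedIndexList` and the function `parseMixedIndexList` are illustrative names, and only `%name` values and decimal integers are accepted): every SSA value appends the sentinel to the integer list and its name to the value list, which is the inverse of the printer.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of what parseDynamicIndexList fills in.
struct MixedIndexList {
  std::vector<int64_t> integers;   // static entries; `dynVal` marks SSA slots
  std::vector<std::string> values; // SSA value names, in order of appearance
};

// Parse text such as "[%arg0, 7, 42, %arg42]". Returns std::nullopt on
// malformed input. This is a simplification of the real MLIR parser.
std::optional<MixedIndexList> parseMixedIndexList(const std::string &text,
                                                  int64_t dynVal) {
  MixedIndexList result;
  std::size_t pos = 0;
  auto skipSpaces = [&] {
    while (pos < text.size() &&
           std::isspace(static_cast<unsigned char>(text[pos])))
      ++pos;
  };
  auto peek = [&](char c) { return pos < text.size() && text[pos] == c; };

  skipSpaces();
  if (!peek('['))
    return std::nullopt;
  ++pos;
  skipSpaces();
  if (peek(']')) {
    ++pos; // 0-D case: an empty list
  } else {
    for (;;) {
      skipSpaces();
      std::size_t start = pos;
      if (peek('%')) {
        // SSA value: record its name and put the sentinel in the list.
        ++pos;
        while (pos < text.size() &&
               (std::isalnum(static_cast<unsigned char>(text[pos])) ||
                text[pos] == '_'))
          ++pos;
        if (pos == start + 1)
          return std::nullopt;
        result.values.push_back(text.substr(start, pos - start));
        result.integers.push_back(dynVal);
      } else {
        // Static integer, optionally negative.
        if (peek('-'))
          ++pos;
        std::size_t digits = pos;
        while (pos < text.size() &&
               std::isdigit(static_cast<unsigned char>(text[pos])))
          ++pos;
        if (pos == digits)
          return std::nullopt;
        result.integers.push_back(std::stoll(text.substr(start, pos - start)));
      }
      skipSpaces();
      if (peek(',')) {
        ++pos;
        continue;
      }
      if (!peek(']'))
        return std::nullopt;
      ++pos;
      break;
    }
  }
  skipSpaces();
  if (pos != text.size())
    return std::nullopt;
  return result;
}
```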
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 960c3870a9c0d..937c35d503769 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -555,3 +555,19 @@ def TypeK : TestType<"TestM"> {
let mnemonic = "type_k";
let assemblyFormat = "$a";
}
+
+// TYPE-LABEL: ::mlir::Type TestNType::parse
+// TYPE: parseFoo(
+// TYPE-NEXT: _result_a,
+// TYPE-NEXT: 1);
+
+// TYPE-LABEL: void TestNType::print
+// TYPE: printFoo(
+// TYPE-NEXT: getA(),
+// TYPE-NEXT: 1);
+
+def TypeL : TestType<"TestN"> {
+ let parameters = (ins "int":$a);
+ let mnemonic = "type_l";
+ let assemblyFormat = [{ custom<Foo>($a, "1") }];
+}
diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td
index d165587e079ca..aae9ed45b92d1 100644
--- a/mlir/test/mlir-tblgen/op-format-invalid.td
+++ b/mlir/test/mlir-tblgen/op-format-invalid.td
@@ -403,6 +403,13 @@ def OptionalInvalidP : TestFormat_Op<[{
($arg^):(`test`)
}]>, Arguments<(ins Variadic<I64>:$arg)>;
+//===----------------------------------------------------------------------===//
+// Strings
+//===----------------------------------------------------------------------===//
+
+// CHECK: error: strings may only be used as 'custom' directive arguments
+def StringInvalidA : TestFormat_Op<[{ "foo" }]>;
+
//===----------------------------------------------------------------------===//
// Variables
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 553c996190a79..6d21afdfe6de7 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -135,6 +135,13 @@ def OptionalValidA : TestFormat_Op<[{
(` ` `` $arg^)? attr-dict
}]>, Arguments<(ins Optional<I32>:$arg)>;
+//===----------------------------------------------------------------------===//
+// Strings
+//===----------------------------------------------------------------------===//
+
+// CHECK-NOT: error
+def StringInvalidA : TestFormat_Op<[{ custom<Foo>("foo") attr-dict }]>;
+
//===----------------------------------------------------------------------===//
// Variables
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td
new file mode 100644
index 0000000000000..1fdb485e5e56c
--- /dev/null
+++ b/mlir/test/mlir-tblgen/op-format.td
@@ -0,0 +1,42 @@
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def TestDialect : Dialect {
+ let name = "test";
+}
+class TestFormat_Op<string fmt, list<Trait> traits = []>
+ : Op<TestDialect, "format_op", traits> {
+ let assemblyFormat = fmt;
+}
+
+//===----------------------------------------------------------------------===//
+// Directives
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// custom
+
+// CHECK-LABEL: CustomStringLiteralA::parse
+// CHECK: parseFoo({{.*}}, parser.getBuilder().getI1Type())
+// CHECK-LABEL: CustomStringLiteralA::print
+// CHECK: printFoo({{.*}}, parser.getBuilder().getI1Type())
+def CustomStringLiteralA : TestFormat_Op<[{
+ custom<Foo>("$_builder.getI1Type()") attr-dict
+}]>;
+
+// CHECK-LABEL: CustomStringLiteralB::parse
+// CHECK: parseFoo({{.*}}, IndexType::get(parser.getContext()))
+// CHECK-LABEL: CustomStringLiteralB::print
+// CHECK: printFoo({{.*}}, IndexType::get(parser.getContext()))
+def CustomStringLiteralB : TestFormat_Op<[{
+ custom<Foo>("IndexType::get($_ctxt)") attr-dict
+}]>;
+
+// CHECK-LABEL: CustomStringLiteralC::parse
+// CHECK: parseFoo({{.*}}, parser.getBuilder().getStringAttr("foo"))
+// CHECK-LABEL: CustomStringLiteralC::print
+// CHECK: printFoo({{.*}}, parser.getBuilder().getStringAttr("foo"))
+def CustomStringLiteralC : TestFormat_Op<[{
+ custom<Foo>("$_builder.getStringAttr(\"foo\")") attr-dict
+}]>;
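Both generators paste the string through `tgfmt`, with `$_builder` and `$_ctxt` bound to parser-side expressions in OpFormatGen.cpp. The sketch below is a hypothetical, much simpler stand-in for that substitution step (`substitutePlaceholders` is an illustrative name, not the `tgfmt` API): it replaces each `$name` occurrence with its binding, which is enough to reproduce the expansions expected by the CHECK lines above. Note that, as committed, the printer path also binds `$_builder` to `parser.getBuilder()`, and these CHECK lines expect that text inside the generated `print` methods, where no `parser` variable is obviously in scope; that part may be worth a second look.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for tgfmt-style substitution: replace every
// `$<name>` placeholder with its replacement text. Bindings are applied in
// order, so placeholder names should not be prefixes of one another.
std::string substitutePlaceholders(
    std::string code,
    const std::vector<std::pair<std::string, std::string>> &substs) {
  for (const auto &[name, replacement] : substs) {
    const std::string key = "$" + name;
    std::size_t pos = 0;
    while ((pos = code.find(key, pos)) != std::string::npos) {
      code.replace(pos, key.size(), replacement);
      pos += replacement.size(); // continue after the inserted text
    }
  }
  return code;
}
```

With the parser-side bindings, `"$_builder.getI1Type()"` becomes `parser.getBuilder().getI1Type()` and `"IndexType::get($_ctxt)"` becomes `IndexType::get(parser.getContext())`.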
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 752556fc129cb..c249e23e531e5 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -629,14 +629,12 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
os.indent();
for (FormatElement *arg : el->getArguments()) {
os << ",\n";
- FormatElement *param;
- if (auto *ref = dyn_cast<RefDirective>(arg)) {
- os << "*";
- param = ref->getArg();
- } else {
- param = arg;
- }
- os << "_result_" << cast<ParameterElement>(param)->getName();
+ if (auto *param = dyn_cast<ParameterElement>(arg))
+ os << "_result_" << param->getName();
+ else if (auto *ref = dyn_cast<RefDirective>(arg))
+ os << "*_result_" << cast<ParameterElement>(ref->getArg())->getName();
+ else
+ os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
}
os.unindent() << ");\n";
os << "if (::mlir::failed(odsCustomResult)) return {};\n";
@@ -845,11 +843,15 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
os << tgfmt("print$0($_printer", &ctx, el->getName());
os.indent();
for (FormatElement *arg : el->getArguments()) {
- FormatElement *param = arg;
- if (auto *ref = dyn_cast<RefDirective>(arg))
- param = ref->getArg();
- os << ",\n"
- << cast<ParameterElement>(param)->getParam().getAccessorName() << "()";
+ os << ",\n";
+ if (auto *param = dyn_cast<ParameterElement>(arg)) {
+ os << param->getParam().getAccessorName() << "()";
+ } else if (auto *ref = dyn_cast<RefDirective>(arg)) {
+ os << cast<ParameterElement>(ref->getArg())->getParam().getAccessorName()
+ << "()";
+ } else {
+ os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
+ }
}
os.unindent() << ");\n";
}
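The argument emission in `DefFormat::genCustomParser` now dispatches on three kinds of element. The sketch below is a hypothetical, string-based model of that dispatch (`CustomArg`, `emitCustomParserCall`, and the `parser` placeholder for the generated parser variable are assumptions, and `$_builder`/`$_ctxt` substitution is omitted): bound parameters become `_result_<name>`, `ref` parameters are dereferenced, and string elements are pasted as-is.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified model of a custom directive argument.
struct CustomArg {
  enum Kind { Param, Ref, String } kind;
  std::string text; // parameter name, or raw C++ code for a string element
};

// Build the generated call text. `parser` is only a placeholder for the
// generated parser variable.
std::string emitCustomParserCall(const std::string &directive,
                                 const std::vector<CustomArg> &args) {
  std::string out = "parse" + directive + "(parser";
  for (const CustomArg &arg : args) {
    out += ", ";
    switch (arg.kind) {
    case CustomArg::Param: // bound parameter: pass the result slot itself
      out += "_result_" + arg.text;
      break;
    case CustomArg::Ref: // `ref` parameter: pass the already-parsed value
      out += "*_result_" + arg.text;
      break;
    case CustomArg::String: // C++ code: pasted verbatim
      out += arg.text;
      break;
    }
  }
  return out + ");";
}
```

With a parameter `a` and the string `"1"`, it produces `parseFoo(parser, _result_a, 1);`, mirroring the `TYPE-NEXT` lines in the attr-or-type-format.td test.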
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index 8d08340800c91..029293265a9eb 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -129,6 +129,8 @@ FormatToken FormatLexer::lexToken() {
return lexLiteral(tokStart);
case '$':
return lexVariable(tokStart);
+ case '"':
+ return lexString(tokStart);
}
}
@@ -153,6 +155,17 @@ FormatToken FormatLexer::lexVariable(const char *tokStart) {
return formToken(FormatToken::variable, tokStart);
}
+FormatToken FormatLexer::lexString(const char *tokStart) {
+ // Lex until another quote, respecting escapes.
+ bool escape = false;
+ while (const char curChar = *curPtr++) {
+ if (!escape && curChar == '"')
+ return formToken(FormatToken::string, tokStart);
+ escape = curChar == '\\';
+ }
+ return emitError(curPtr - 1, "unexpected end of file in string");
+}
+
FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
// Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
@@ -212,6 +225,8 @@ FailureOr<std::vector<FormatElement *>> FormatParser::parse() {
FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) {
if (curToken.is(FormatToken::literal))
return parseLiteral(ctx);
+ if (curToken.is(FormatToken::string))
+ return parseString(ctx);
if (curToken.is(FormatToken::variable))
return parseVariable(ctx);
if (curToken.isKeyword())
@@ -253,6 +268,28 @@ FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) {
return create<LiteralElement>(value);
}
+FailureOr<FormatElement *> FormatParser::parseString(Context ctx) {
+ FormatToken tok = curToken;
+ SMLoc loc = tok.getLoc();
+ consumeToken();
+
+ if (ctx != CustomDirectiveContext) {
+ return emitError(
+ loc, "strings may only be used as 'custom' directive arguments");
+ }
+ // Escape the string.
+ std::string value;
+ StringRef contents = tok.getSpelling().drop_front().drop_back();
+ value.reserve(contents.size());
+ bool escape = false;
+ for (char c : contents) {
+ escape = c == '\\';
+ if (!escape)
+ value.push_back(c);
+ }
+ return create<StringElement>(std::move(value));
+}
+
FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) {
FormatToken tok = curToken;
SMLoc loc = tok.getLoc();
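`lexString` scans to the next unescaped quote, and `parseString` strips the surrounding quotes and the escapes. The sketch below models both steps in one hypothetical function, `lexQuoted`, under the assumption that a backslash escapes exactly the next character, so `\"` yields `"` and `\\` yields a single backslash. As written in the diff, the lexer recomputes `escape = curChar == '\\'` on every character, so an escaped backslash still escapes a following quote, and `parseString` drops every backslash; the sketch deliberately differs on those two points.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>

// Hypothetical sketch: `src[pos]` must be the opening quote. Returns the
// unescaped contents, or std::nullopt if the closing quote is missing.
std::optional<std::string> lexQuoted(const std::string &src, std::size_t pos) {
  if (pos >= src.size() || src[pos] != '"')
    return std::nullopt;
  std::string value;
  bool escape = false;
  for (std::size_t i = pos + 1; i < src.size(); ++i) {
    char c = src[i];
    if (escape) {
      value.push_back(c); // keep the escaped character verbatim
      escape = false;
    } else if (c == '\\') {
      escape = true;      // drop the backslash itself
    } else if (c == '"') {
      return value;       // an unescaped quote terminates the string
    } else {
      value.push_back(c);
    }
  }
  return std::nullopt;    // unterminated string
}
```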
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index f180f2da48e8d..cc57ff9ee8719 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -78,6 +78,7 @@ class FormatToken {
identifier,
literal,
variable,
+ string,
};
FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
@@ -130,10 +131,11 @@ class FormatLexer {
/// Return the next character in the stream.
int getNextChar();
- /// Lex an identifier, literal, or variable.
+ /// Lex an identifier, literal, variable, or string.
FormatToken lexIdentifier(const char *tokStart);
FormatToken lexLiteral(const char *tokStart);
FormatToken lexVariable(const char *tokStart);
+ FormatToken lexString(const char *tokStart);
/// Create a token with the current pointer and a start pointer.
FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
@@ -163,7 +165,7 @@ class FormatElement {
virtual ~FormatElement();
// The top-level kinds of format elements.
- enum Kind { Literal, Variable, Whitespace, Directive, Optional };
+ enum Kind { Literal, String, Variable, Whitespace, Directive, Optional };
/// Support LLVM-style RTTI.
static bool classof(const FormatElement *el) { return true; }
@@ -212,6 +214,20 @@ class LiteralElement : public FormatElementBase<FormatElement::Literal> {
StringRef spelling;
};
+/// This class represents a raw string that can contain arbitrary C++ code.
+class StringElement : public FormatElementBase<FormatElement::String> {
+public:
+ /// Create a string element with the given contents.
+ explicit StringElement(std::string value) : value(std::move(value)) {}
+
+ /// Get the value of the string element.
+ StringRef getValue() const { return value; }
+
+private:
+ /// The contents of the string.
+ std::string value;
+};
+
/// This class represents a variable element. A variable refers to some part of
/// the object being parsed, e.g. an attribute or operand on an operation or a
/// parameter on an attribute.
@@ -447,6 +463,8 @@ class FormatParser {
FailureOr<FormatElement *> parseElement(Context ctx);
/// Parse a literal.
FailureOr<FormatElement *> parseLiteral(Context ctx);
+ /// Parse a string.
+ FailureOr<FormatElement *> parseString(Context ctx);
/// Parse a variable.
FailureOr<FormatElement *> parseVariable(Context ctx);
/// Parse a directive.
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 55a4a8fec3672..1e03bad9bd470 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -916,6 +916,13 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
body << llvm::formatv("{0}Type", listName);
else
body << formatv("{0}RawTypes[0]", listName);
+
+ } else if (auto *string = dyn_cast<StringElement>(param)) {
+ FmtContext ctx;
+ ctx.withBuilder("parser.getBuilder()");
+ ctx.addSubst("_ctxt", "parser.getContext()");
+ body << tgfmt(string->getValue(), &ctx);
+
} else {
llvm_unreachable("unknown custom directive parameter");
}
@@ -1715,6 +1722,13 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
body << llvm::formatv("({0}() ? {0}().getType() : Type())", name);
else
body << name << "().getType()";
+
+ } else if (auto *string = dyn_cast<StringElement>(element)) {
+ FmtContext ctx;
+ ctx.withBuilder("parser.getBuilder()");
+ ctx.addSubst("_ctxt", "parser.getContext()");
+ body << tgfmt(string->getValue(), &ctx);
+
} else {
llvm_unreachable("unknown custom directive parameter");
}
@@ -2826,8 +2840,9 @@ OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
SMLoc loc, ArrayRef<FormatElement *> arguments) {
for (FormatElement *argument : arguments) {
- if (!isa<RefDirective, TypeDirective, AttrDictDirective, AttributeVariable,
- OperandVariable, RegionVariable, SuccessorVariable>(argument)) {
+ if (!isa<StringElement, RefDirective, TypeDirective, AttrDictDirective,
+ AttributeVariable, OperandVariable, RegionVariable,
+ SuccessorVariable>(argument)) {
// TODO: FormatElement should have location info attached.
return emitError(loc, "only variables and types may be used as "
"parameters to a custom directive");