[Mlir-commits] [mlir] 342d466 - [mlir] Add custom directive hooks for printing mixed integer or value operands.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 1 19:04:11 PST 2021
Author: MaheshRavishankar
Date: 2021-02-01T19:03:49-08:00
New Revision: 342d4662e1c930bd0a856c6c88d0cde5f106cc81
URL: https://github.com/llvm/llvm-project/commit/342d4662e1c930bd0a856c6c88d0cde5f106cc81
DIFF: https://github.com/llvm/llvm-project/commit/342d4662e1c930bd0a856c6c88d0cde5f106cc81.diff
LOG: [mlir] Add custom directive hooks for printing mixed integer or value operands.
Add printer and parser hooks for a custom directive that allows
parsing and printing of idioms that can represent a list of values
each of which is either an integer or an SSA value. For example in
`subview %source[%offset_0, 1] [4, %size_1] [%stride_0, 3]`
each of the lists (which represent offsets, sizes and strides) is a mix
of either statically known integer values or dynamically computed SSA
values. Since this is used in many places, adding a custom directive to
parse/print this idiom allows using assembly format on operations
which use this idiom.
Differential Revision: https://reviews.llvm.org/D95773
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 4f2a8afcdbc8..3185502b1879 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -45,6 +45,11 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
let results = (outs AnyTensor:$result);
+ let assemblyFormat = [{
+ custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) attr-dict
+ `:` type($result)
+ }];
+
let verifier = [{ return ::verify(*this); }];
let extraClassDeclaration = [{
@@ -118,7 +123,7 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
}
def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
- [AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> {
+ [AttrSizedOperandSegments]> {
let summary = "tensor pad operation";
let description = [{
`linalg.pad_tensor` is an operation that pads the `source` tensor
@@ -181,10 +186,16 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
I64ArrayAttr:$static_low,
I64ArrayAttr:$static_high);
- let regions = (region AnyRegion:$region);
+ let regions = (region SizedRegion<1>:$region);
let results = (outs AnyTensor:$result);
+ let assemblyFormat = [{
+ $source `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
+ `high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
+ $region attr-dict `:` type($source) `to` type($result)
+ }];
+
let extraClassDeclaration = [{
static StringRef getStaticLowAttrName() {
return "static_low";
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 180cec0c4091..516618502357 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1956,6 +1956,19 @@ def MemRefReinterpretCastOp:
);
let results = (outs AnyMemRef:$result);
+ let assemblyFormat = [{
+ $source `to` `offset` `` `:`
+ custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
+ `` `,` `sizes` `` `:`
+ custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) `` `,` `strides`
+ `` `:`
+ custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+ attr-dict `:` type($source) `to` type($result)
+ }];
+
+ let parser=?;
+ let printer=?;
+
let builders = [
// Build a ReinterpretCastOp with mixed static and dynamic entries.
OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source,
@@ -2931,6 +2944,14 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
);
let results = (outs AnyMemRef:$result);
+ let assemblyFormat = [{
+ $source ``
+ custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
+ custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
+ custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+ attr-dict `:` type($source) `to` type($result)
+ }];
+
let builders = [
// Build a SubViewOp with mixed static and dynamic entries and custom
// result type. If the type passed is nullptr, it is inferred.
@@ -3053,6 +3074,14 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
);
let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ $source ``
+ custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
+ custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
+ custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+ attr-dict `:` type($source) `to` type($result)
+ }];
+
let builders = [
// Build a SubTensorOp with mixed static and dynamic entries and inferred
// result type.
@@ -3115,7 +3144,10 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
//===----------------------------------------------------------------------===//
def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
- "subtensor_insert", [OffsetSizeAndStrideOpInterface]> {
+ "subtensor_insert",
+ [OffsetSizeAndStrideOpInterface,
+ TypesMatchWith<"expected result type to match dest type",
+ "dest", "result", "$_self">]> {
let summary = "subtensor_insert operation";
let description = [{
The "subtensor_insert" operation insert a tensor `source` into another
@@ -3159,6 +3191,16 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
);
let results = (outs AnyRankedTensor:$result);
+ let assemblyFormat = [{
+ $source `into` $dest ``
+ custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
+ custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
+ custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
+ attr-dict `:` type($source) `into` type($dest)
+ }];
+
+ let verifier = ?;
+
let builders = [
// Build a SubTensorInsertOp with mixed static and dynamic entries.
OpBuilderDAG<(ins "Value":$source, "Value":$dest,
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 2b3a054338ab..434717a6ba88 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -36,78 +36,68 @@ LogicalResult verify(OffsetSizeAndStrideOpInterface op);
#include "mlir/Interfaces/ViewLikeInterface.h.inc"
namespace mlir {
-/// Print a list with either (1) the static integer value in `arrayAttr` if
-/// `isDynamic` evaluates to false or (2) the next value otherwise.
-/// This allows idiomatic printing of mixed value and integer attributes in a
+
+/// Printer hook for custom directive in assemblyFormat.
+///
+/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
+///
+/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
+/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
+/// either (1) the static integer value in `integers` if the value is
+/// ShapedType::kDynamicStrideOrOffset or (2) the next value otherwise. This
+/// allows idiomatic printing of mixed value and integer attributes in a
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-void printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
- ArrayAttr arrayAttr,
- llvm::function_ref<bool(int64_t)> isDynamic);
+void printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &printer,
+ Operation *op,
+ OperandRange values,
+ ArrayAttr integers);
-/// Print part of an op of the form:
-/// ```
-/// <optional-offset-prefix>`[` offset-list `]`
-/// <optional-size-prefix>`[` size-list `]`
-/// <optional-stride-prefix>[` stride-list `]`
-/// ```
-void printOffsetsSizesAndStrides(
- OpAsmPrinter &p, OffsetSizeAndStrideOpInterface op,
- StringRef offsetPrefix = "", StringRef sizePrefix = " ",
- StringRef stridePrefix = " ",
- ArrayRef<StringRef> elidedAttrs =
- OffsetSizeAndStrideOpInterface::getSpecialAttrNames());
+/// Printer hook for custom directive in assemblyFormat.
+///
+/// custom<OperandsOrIntegersSizesList>($values, $integers)
+///
+/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
+/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
+/// either (1) the static integer value in `integers` if the value is
+/// ShapedType::kDynamicSize or (2) the next value otherwise. This
+/// allows idiomatic printing of mixed value and integer attributes in a
+/// list. E.g. `[%arg0, 7, 42, %arg42]`.
+void printOperandsOrIntegersSizesList(OpAsmPrinter &printer, Operation *op,
+ OperandRange values, ArrayAttr integers);
-/// Parse a mixed list with either (1) static integer values or (2) SSA values.
-/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
-/// encode the position of SSA values. Add the parsed SSA values to `ssa`
-/// in-order.
+/// Pasrer hook for custom directive in assemblyFormat.
+///
+/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
+///
+/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
+/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
+/// either (1) static integer values or (2) SSA values. Fill `integers` with
+/// the integer ArrayAttr, where ShapedType::kDynamicStrideOrOffset encodes the
+/// position of SSA values. Add the parsed SSA values to `values` in-order.
//
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
/// 2. `ssa` is filled with "[%arg0, %arg1]".
-ParseResult
-parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
- StringRef attrName, int64_t dynVal,
- SmallVectorImpl<OpAsmParser::OperandType> &ssa);
+ParseResult parseOperandsOrIntegersOffsetsOrStridesList(
+ OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
+ ArrayAttr &integers);
-/// Parse trailing part of an op of the form:
-/// ```
-/// <optional-offset-prefix>`[` offset-list `]`
-/// <optional-size-prefix>`[` size-list `]`
-/// <optional-stride-prefix>[` stride-list `]`
-/// ```
-/// Each entry in the offset, size and stride list either resolves to an integer
-/// constant or an operand of index type.
-/// Constants are added to the `result` as named integer array attributes with
-/// name `OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName()` (resp.
-/// `getStaticSizesAttrName()`, `getStaticStridesAttrName()`).
+/// Pasrer hook for custom directive in assemblyFormat.
///
-/// Append the number of offset, size and stride operands to `segmentSizes`
-/// before adding it to `result` as the named attribute:
-/// `OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()`.
+/// custom<OperandsOrIntegersSizesList>($values, $integers)
///
-/// Offset, size and stride operands resolution occurs after `preResolutionFn`
-/// to give a chance to leading operands to resolve first, after parsing the
-/// types.
-ParseResult parseOffsetsSizesAndStrides(
- OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
- llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
- preResolutionFn = nullptr,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix =
- nullptr,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix =
- nullptr,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
- nullptr);
-/// `preResolutionFn`-less version of `parseOffsetsSizesAndStrides`.
-ParseResult parseOffsetsSizesAndStrides(
- OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix =
- nullptr,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix =
- nullptr,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
- nullptr);
+/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
+/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
+/// either (1) static integer values or (2) SSA values. Fill `integers` with
+/// the integer ArrayAttr, where ShapedType::kDynamicSize encodes the
+/// position of SSA values. Add the parsed SSA values to `values` in-order.
+//
+/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
+/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
+/// 2. `ssa` is filled with "[%arg0, %arg1]".
+ParseResult parseOperandsOrIntegersSizesList(
+ OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
+ ArrayAttr &integers);
/// Verify that a the `values` has as many elements as the number of entries in
/// `attr` for which `isDynamic` evaluates to true.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 21c9108985e6..7d2685f8166a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -676,30 +676,6 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
// InitTensorOp
//===----------------------------------------------------------------------===//
-static ParseResult parseInitTensorOp(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::OperandType srcInfo;
- Type dstType;
- SmallVector<OpAsmParser::OperandType, 2> sizeInfo;
- IndexType indexType = parser.getBuilder().getIndexType();
- if (failed(parseListOfOperandsOrIntegers(
- parser, result, InitTensorOp::getStaticSizesAttrName(),
- ShapedType::kDynamicSize, sizeInfo)) ||
- failed(parser.parseOptionalAttrDict(result.attributes)) ||
- failed(parser.parseColonType(dstType)) ||
- failed(parser.resolveOperands(sizeInfo, indexType, result.operands)))
- return failure();
- return parser.addTypeToList(dstType, result.types);
-}
-
-static void print(OpAsmPrinter &p, InitTensorOp op) {
- p << op.getOperation()->getName() << ' ';
- printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
- ShapedType::isDynamic);
- p.printOptionalAttrDict(op.getAttrs(),
- InitTensorOp::getStaticSizesAttrName());
- p << " : " << op.getType();
-}
static LogicalResult verify(InitTensorOp op) {
RankedTensorType resultType = op.getType();
@@ -981,8 +957,6 @@ static LogicalResult verify(PadTensorOp op) {
}
auto ®ion = op.region();
- if (!llvm::hasSingleElement(region))
- return op.emitOpError("expected region with 1 block");
unsigned rank = resultType.getRank();
Block &block = region.front();
if (block.getNumArguments() != rank)
@@ -1020,67 +994,6 @@ RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
return RankedTensorType::get(resultShape, sourceType.getElementType());
}
-static ParseResult parsePadTensorOp(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::OperandType baseInfo;
- SmallVector<OpAsmParser::OperandType, 8> operands;
- SmallVector<Type, 8> types;
- if (parser.parseOperand(baseInfo))
- return failure();
-
- IndexType indexType = parser.getBuilder().getIndexType();
- SmallVector<OpAsmParser::OperandType, 4> lowPadding, highPadding;
- if (parser.parseKeyword("low") ||
- parseListOfOperandsOrIntegers(parser, result,
- PadTensorOp::getStaticLowAttrName(),
- ShapedType::kDynamicSize, lowPadding))
- return failure();
- if (parser.parseKeyword("high") ||
- parseListOfOperandsOrIntegers(parser, result,
- PadTensorOp::getStaticHighAttrName(),
- ShapedType::kDynamicSize, highPadding))
- return failure();
-
- SmallVector<OpAsmParser::OperandType, 8> regionOperands;
- std::unique_ptr<Region> region = std::make_unique<Region>();
- SmallVector<Type, 8> operandTypes, regionTypes;
- if (parser.parseRegion(*region, regionOperands, regionTypes))
- return failure();
- result.addRegion(std::move(region));
-
- Type srcType, dstType;
- if (parser.parseColonType(srcType) || parser.parseKeywordType("to", dstType))
- return failure();
-
- if (parser.addTypeToList(dstType, result.types))
- return failure();
-
- SmallVector<int, 4> segmentSizesFinal = {1}; // source tensor
- segmentSizesFinal.append({static_cast<int>(lowPadding.size()),
- static_cast<int>(highPadding.size())});
- result.addAttribute(
- OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
- parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
- return failure(
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.resolveOperand(baseInfo, srcType, result.operands) ||
- parser.resolveOperands(lowPadding, indexType, result.operands) ||
- parser.resolveOperands(highPadding, indexType, result.operands));
-}
-
-static void print(OpAsmPrinter &p, PadTensorOp op) {
- p << op->getName().getStringRef() << ' ';
- p << op.source();
- p << " low";
- printListOfOperandsOrIntegers(p, op.low(), op.static_low(),
- ShapedType::isDynamic);
- p << " high";
- printListOfOperandsOrIntegers(p, op.high(), op.static_high(),
- ShapedType::isDynamic);
- p.printRegion(op.region());
- p << " : " << op.source().getType() << " to " << op.getType();
-}
-
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
/// it is a Value or into `staticVec` if it is an IntegerAttr.
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 7226d89f835f..04f4b1083e4a 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2148,67 +2148,6 @@ void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
}
-/// Print a memref_reinterpret_cast op of the form:
-/// ```
-/// `memref_reinterpret_cast` ssa-name to
-/// offset: `[` offset `]`
-/// sizes: `[` size-list `]`
-/// strides:`[` stride-list `]`
-/// `:` any-memref-type to strided-memref-type
-/// ```
-static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) {
- int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
- p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
- p << op.source() << " ";
- printOffsetsSizesAndStrides(
- p, op, /*offsetPrefix=*/"to offset: ", /*sizePrefix=*/", sizes: ",
- /*stridePrefix=*/", strides: ");
- p << ": " << op.source().getType() << " to " << op.getType();
-}
-
-/// Parse a memref_reinterpret_cast op of the form:
-/// ```
-/// `memref_reinterpret_cast` ssa-name to
-/// offset: `[` offset `]`
-/// sizes: `[` size-list `]`
-/// strides:`[` stride-list `]`
-/// `:` any-memref-type to strided-memref-type
-/// ```
-static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
- OperationState &result) {
- // Parse `operand`
- OpAsmParser::OperandType srcInfo;
- if (parser.parseOperand(srcInfo))
- return failure();
-
- auto parseOffsetPrefix = [](OpAsmParser &parser) {
- return failure(parser.parseKeyword("to") || parser.parseKeyword("offset") ||
- parser.parseColon());
- };
- auto parseSizePrefix = [](OpAsmParser &parser) {
- return failure(parser.parseComma() || parser.parseKeyword("sizes") ||
- parser.parseColon());
- };
- auto parseStridePrefix = [](OpAsmParser &parser) {
- return failure(parser.parseComma() || parser.parseKeyword("strides") ||
- parser.parseColon());
- };
-
- Type srcType, dstType;
- auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
- return failure(parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(srcType) ||
- parser.parseKeywordType("to", dstType) ||
- parser.resolveOperand(srcInfo, srcType, result.operands));
- };
- if (failed(parseOffsetsSizesAndStrides(parser, result,
- /*segmentSizes=*/{1}, // source memref
- preResolutionFn, parseOffsetPrefix,
- parseSizePrefix, parseStridePrefix)))
- return failure();
- return parser.addTypeToList(dstType, result.types);
-}
-
// TODO: ponder whether we want to allow missing trailing sizes/strides that are
// completed automatically, like we have for subview and subtensor.
static LogicalResult verify(MemRefReinterpretCastOp op) {
@@ -2892,45 +2831,6 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
sourceMemRefType.getMemorySpace());
}
-/// Print a subview op of the form:
-/// ```
-/// `subview` ssa-name
-/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
-/// `:` strided-memref-type `to` strided-memref-type
-/// ```
-static void print(OpAsmPrinter &p, SubViewOp op) {
- int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
- p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
- p << op.source();
- printOffsetsSizesAndStrides(p, op);
- p << " : " << op.source().getType() << " to " << op.getType();
-}
-
-/// Parse a subview op of the form:
-/// ```
-/// `subview` ssa-name
-/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
-/// `:` strided-memref-type `to` strided-memref-type
-/// ```
-static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::OperandType srcInfo;
- if (parser.parseOperand(srcInfo))
- return failure();
- Type srcType, dstType;
- auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
- return failure(parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(srcType) ||
- parser.parseKeywordType("to", dstType) ||
- parser.resolveOperand(srcInfo, srcType, result.operands));
- };
-
- if (failed(parseOffsetsSizesAndStrides(parser, result,
- /*segmentSizes=*/{1}, // source memref
- preResolutionFn)))
- return failure();
- return parser.addTypeToList(dstType, result.types);
-}
-
// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
@@ -3466,46 +3366,6 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
// SubTensorOp
//===----------------------------------------------------------------------===//
-/// Print a subtensor op of the form:
-/// ```
-/// `subtensor` ssa-name
-/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
-/// `:` ranked-tensor-type `to` ranked-tensor-type
-/// ```
-static void print(OpAsmPrinter &p, SubTensorOp op) {
- int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
- p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
- p << op.source();
- printOffsetsSizesAndStrides(p, op);
- p << " : " << op.getSourceType() << " to " << op.getType();
-}
-
-/// Parse a subtensor op of the form:
-/// ```
-/// `subtensor` ssa-name
-/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
-/// `:` ranked-tensor-type `to` ranked-tensor-type
-/// ```
-static ParseResult parseSubTensorOp(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::OperandType srcInfo;
- if (parser.parseOperand(srcInfo))
- return failure();
- Type srcType, dstType;
- auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
- return failure(parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(srcType) ||
- parser.parseKeywordType("to", dstType) ||
- parser.resolveOperand(srcInfo, srcType, result.operands));
- };
-
- if (failed(parseOffsetsSizesAndStrides(parser, result,
- /*segmentSizes=*/{1}, // source tensor
- preResolutionFn)))
- return failure();
- return parser.addTypeToList(dstType, result.types);
-}
-
/// A subtensor result type can be fully inferred from the source type and the
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
@@ -3612,49 +3472,6 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// SubTensorInsertOp
//===----------------------------------------------------------------------===//
-/// Print a subtensor_insert op of the form:
-/// ```
-/// `subtensor_insert` ssa-name `into` ssa-name
-/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
-/// `:` ranked-tensor-type `into` ranked-tensor-type
-/// ```
-static void print(OpAsmPrinter &p, SubTensorInsertOp op) {
- int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
- p << op->getName().getStringRef().drop_front(stdDotLen) << ' ';
- p << op.source() << " into " << op.dest();
- printOffsetsSizesAndStrides(p, op);
- p << " : " << op.getSourceType() << " into " << op.getType();
-}
-
-/// Parse a subtensor_insert op of the form:
-/// ```
-/// `subtensor_insert` ssa-name `into` ssa-name
-/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
-/// `:` ranked-tensor-type `into` ranked-tensor-type
-/// ```
-static ParseResult parseSubTensorInsertOp(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::OperandType srcInfo, dstInfo;
- if (parser.parseOperand(srcInfo) || parser.parseKeyword("into") ||
- parser.parseOperand(dstInfo))
- return failure();
- Type srcType, dstType;
- auto preResolutionFn = [&](OpAsmParser &parser, OperationState &result) {
- return failure(parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(srcType) ||
- parser.parseKeywordType("into", dstType) ||
- parser.resolveOperand(srcInfo, srcType, result.operands) ||
- parser.resolveOperand(dstInfo, dstType, result.operands));
- };
-
- if (failed(parseOffsetsSizesAndStrides(
- parser, result,
- /*segmentSizes=*/{1, 1}, // source tensor, destination tensor
- preResolutionFn)))
- return failure();
- return parser.addTypeToList(dstType, result.types);
-}
-
// Build a SubTensorInsertOp with mixed static and dynamic entries.
void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest,
@@ -3691,13 +3508,6 @@ void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
-/// Verifier for SubViewOp.
-static LogicalResult verify(SubTensorInsertOp op) {
- if (op.getType() != op.dest().getType())
- return op.emitError("expected result type to be ") << op.dest().getType();
- return success();
-}
-
//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 001d66d87be8..aae84803a91e 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -69,14 +69,18 @@ LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
return success();
}
-void mlir::printListOfOperandsOrIntegers(
- OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr,
- llvm::function_ref<bool(int64_t)> isDynamic) {
+template <int64_t dynVal>
+static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
+ ArrayAttr arrayAttr) {
p << '[';
+ if (arrayAttr.empty()) {
+ p << "]";
+ return;
+ }
unsigned idx = 0;
llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
int64_t val = a.cast<IntegerAttr>().getInt();
- if (isDynamic(val))
+ if (val == dynVal)
p << values[idx++];
else
p << val;
@@ -84,32 +88,31 @@ void mlir::printListOfOperandsOrIntegers(
p << ']';
}
-void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p,
- OffsetSizeAndStrideOpInterface op,
- StringRef offsetPrefix,
- StringRef sizePrefix,
- StringRef stridePrefix,
- ArrayRef<StringRef> elidedAttrs) {
- p << offsetPrefix;
- printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
- ShapedType::isDynamicStrideOrOffset);
- p << sizePrefix;
- printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
- ShapedType::isDynamic);
- p << stridePrefix;
- printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
- ShapedType::isDynamicStrideOrOffset);
- p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
+void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
+ Operation *op,
+ OperandRange values,
+ ArrayAttr integers) {
+ return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
+ p, values, integers);
+}
+
+void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
+ OperandRange values,
+ ArrayAttr integers) {
+ return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
+ integers);
}
-ParseResult mlir::parseListOfOperandsOrIntegers(
- OpAsmParser &parser, OperationState &result, StringRef attrName,
- int64_t dynVal, SmallVectorImpl<OpAsmParser::OperandType> &ssa) {
+template <int64_t dynVal>
+static ParseResult
+parseOperandsOrIntegersImpl(OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::OperandType> &values,
+ ArrayAttr &integers) {
if (failed(parser.parseLSquare()))
return failure();
// 0-D.
if (succeeded(parser.parseOptionalRSquare())) {
- result.addAttribute(attrName, parser.getBuilder().getArrayAttr({}));
+ integers = parser.getBuilder().getArrayAttr({});
return success();
}
@@ -118,7 +121,7 @@ ParseResult mlir::parseListOfOperandsOrIntegers(
OpAsmParser::OperandType operand;
auto res = parser.parseOptionalOperand(operand);
if (res.hasValue() && succeeded(res.getValue())) {
- ssa.push_back(operand);
+ values.push_back(operand);
attrVals.push_back(dynVal);
} else {
IntegerAttr attr;
@@ -134,59 +137,20 @@ ParseResult mlir::parseListOfOperandsOrIntegers(
return failure();
break;
}
-
- auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals);
- result.addAttribute(attrName, arrayAttr);
+ integers = parser.getBuilder().getI64ArrayAttr(attrVals);
return success();
}
-ParseResult mlir::parseOffsetsSizesAndStrides(
- OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
- return parseOffsetsSizesAndStrides(
- parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix,
- parseOptionalSizePrefix, parseOptionalStridePrefix);
+ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
+ OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
+ ArrayAttr &integers) {
+ return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
+ parser, values, integers);
}
-ParseResult mlir::parseOffsetsSizesAndStrides(
- OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
- llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
- preResolutionFn,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
- llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
- SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
- auto indexType = parser.getBuilder().getIndexType();
- if ((parseOptionalOffsetPrefix && parseOptionalOffsetPrefix(parser)) ||
- parseListOfOperandsOrIntegers(
- parser, result,
- OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
- ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
- (parseOptionalSizePrefix && parseOptionalSizePrefix(parser)) ||
- parseListOfOperandsOrIntegers(
- parser, result,
- OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
- ShapedType::kDynamicSize, sizesInfo) ||
- (parseOptionalStridePrefix && parseOptionalStridePrefix(parser)) ||
- parseListOfOperandsOrIntegers(
- parser, result,
- OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
- ShapedType::kDynamicStrideOrOffset, stridesInfo))
- return failure();
- // Add segment sizes to result
- SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(),
- segmentSizes.end());
- segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()),
- static_cast<int>(sizesInfo.size()),
- static_cast<int>(stridesInfo.size())});
- result.addAttribute(
- OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
- parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
- return failure(
- (preResolutionFn && preResolutionFn(parser, result)) ||
- parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
- parser.resolveOperands(sizesInfo, indexType, result.operands) ||
- parser.resolveOperands(stridesInfo, indexType, result.operands));
+ParseResult mlir::parseOperandsOrIntegersSizesList(
+ OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &values,
+ ArrayAttr &integers) {
+ return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
+ integers);
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 6579add14c50..44e51d615006 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -643,7 +643,7 @@ func @pad_number_of_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9
// -----
func @pad_no_block(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
- // expected-error @+1 {{expected region with 1 block}}
+ // expected-error @+1 {{op region #0 ('region') failed to verify constraint: region with 1 blocks}}
%0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
} : tensor<?x4xi32> to tensor<?x9xi32>
return %0 : tensor<?x9xi32>
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 81ac241513b0..33c0c24f57c0 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1581,6 +1581,8 @@ static void genSpacePrinter(bool value, OpMethodBody &body,
if (value) {
body << " p << ' ';\n";
lastWasPunctuation = false;
+ } else {
+ lastWasPunctuation = true;
}
shouldEmitSpace = false;
}
More information about the Mlir-commits
mailing list