[Mlir-commits] [mlir] c247081 - [mlir] NFC - Refactor and expose a printOffsetsSizesAndStrides helper function.
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Nov 24 12:02:14 PST 2020
Author: Nicolas Vasilache
Date: 2020-11-24T20:00:59Z
New Revision: c24708102501115efae27f82c24d5991059a5770
URL: https://github.com/llvm/llvm-project/commit/c24708102501115efae27f82c24d5991059a5770
DIFF: https://github.com/llvm/llvm-project/commit/c24708102501115efae27f82c24d5991059a5770.diff
LOG: [mlir] NFC - Refactor and expose a printOffsetsSizesAndStrides helper function.
Print part of an op of the form:
```
<optional-offset-prefix>`[` offset-list `]`
<optional-size-prefix>`[` size-list `]`
<optional-stride-prefix>`[` stride-list `]`
```
Also address some leftover nits.
Differential revision: https://reviews.llvm.org/D92031
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 7a775c3a317b..c44d99b1620d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -232,15 +232,6 @@ class BaseOpWithOffsetSizesAndStrides<string mnemonic, list<OpTrait> traits = []
SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc) {
return mlir::getOrCreateRanges(*this, b, loc);
}
-
- static ArrayRef<StringRef> getSpecialAttrNames() {
- static SmallVector<StringRef, 4> names{
- OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
- OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
- OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
- getOperandSegmentSizeAttr()};
- return names;
- }
}];
}
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 15ba5d18a6d6..b7d796f39f4d 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -29,6 +29,24 @@ struct Range {
class OffsetSizeAndStrideOpInterface;
LogicalResult verify(OffsetSizeAndStrideOpInterface op);
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/ViewLikeInterface.h.inc"
+
+namespace mlir {
+/// Print part of an op of the form:
+/// ```
+/// <optional-offset-prefix>`[` offset-list `]`
+/// <optional-size-prefix>`[` size-list `]`
+/// <optional-stride-prefix>`[` stride-list `]`
+/// ```
+void printOffsetsSizesAndStrides(
+ OpAsmPrinter &p, OffsetSizeAndStrideOpInterface op,
+ StringRef offsetPrefix = "", StringRef sizePrefix = " ",
+ StringRef stridePrefix = " ",
+ ArrayRef<StringRef> elidedAttrs =
+ OffsetSizeAndStrideOpInterface::getSpecialAttrNames());
/// Parse trailing part of an op of the form:
/// ```
@@ -59,10 +77,16 @@ ParseResult parseOffsetsSizesAndStrides(
nullptr,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
nullptr);
+/// `preResolutionFn`-less version of `parseOffsetsSizesAndStrides`.
+ParseResult parseOffsetsSizesAndStrides(
+ OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
+ llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix =
+ nullptr,
+ llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix =
+ nullptr,
+ llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix =
+ nullptr);
} // namespace mlir
-/// Include the generated interface declarations.
-#include "mlir/Interfaces/ViewLikeInterface.h.inc"
-
#endif // MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index d3a7bf185d13..31f9bca8d7fb 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -357,6 +357,14 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
static StringRef getStaticStridesAttrName() {
return "static_strides";
}
+ static ArrayRef<StringRef> getSpecialAttrNames() {
+ static SmallVector<StringRef, 4> names{
+ OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(),
+ OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(),
+ OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(),
+ OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()};
+ return names;
+ }
}];
let verify = [{
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 3160e8f8be0b..6ed17c744f8c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -793,12 +793,6 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
return fusableDependences;
}
-static bool isZero(Value v) {
- if (auto cst = v.getDefiningOp<ConstantIndexOp>())
- return cst.getValue() == 0;
- return false;
-}
-
/// Tile the fused loops in the root operation, by setting the tile sizes for
/// all other loops to zero (those will be tiled later).
static Optional<TiledLinalgOp> tileRootOperation(
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 1fe52b70992e..1437552f6e2a 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -248,49 +248,6 @@ OpFoldResult AddIOp::fold(ArrayRef<Attribute> operands) {
[](APInt a, APInt b) { return a + b; });
}
-//===----------------------------------------------------------------------===//
-// BaseOpWithOffsetSizesAndStridesOp
-//===----------------------------------------------------------------------===//
-
-/// Print a list with either (1) the static integer value in `arrayAttr` if
-/// `isDynamic` evaluates to false or (2) the next value otherwise.
-/// This allows idiomatic printing of mixed value and integer attributes in a
-/// list. E.g. `[%arg0, 7, 42, %arg42]`.
-static void
-printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
- ArrayAttr arrayAttr,
- llvm::function_ref<bool(int64_t)> isDynamic) {
- p << '[';
- unsigned idx = 0;
- llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
- int64_t val = a.cast<IntegerAttr>().getInt();
- if (isDynamic(val))
- p << values[idx++];
- else
- p << val;
- });
- p << ']';
-}
-
-/// Verify that a particular offset/size/stride static attribute is well-formed.
-static LogicalResult verifyOpWithOffsetSizesAndStridesPart(
- OffsetSizeAndStrideOpInterface op, StringRef name,
- unsigned expectedNumElements, StringRef attrName, ArrayAttr attr,
- llvm::function_ref<bool(int64_t)> isDynamic, ValueRange values) {
- /// Check static and dynamic offsets/sizes/strides breakdown.
- if (attr.size() != expectedNumElements)
- return op.emitError("expected ")
- << expectedNumElements << " " << name << " values";
- unsigned expectedNumDynamicEntries =
- llvm::count_if(attr.getValue(), [&](Attribute attr) {
- return isDynamic(attr.cast<IntegerAttr>().getInt());
- });
- if (values.size() != expectedNumDynamicEntries)
- return op.emitError("expected ")
- << expectedNumDynamicEntries << " dynamic " << name << " values";
- return success();
-}
-
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
return llvm::to_vector<4>(
@@ -2390,9 +2347,9 @@ void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
staticStridesVector, offset, sizes, strides, attrs);
}
-/// Print of the form:
+/// Print a memref_reinterpret_cast op of the form:
/// ```
-/// `name` ssa-name to
+/// `memref_reinterpret_cast` ssa-name to
/// offset: `[` offset `]`
/// sizes: `[` size-list `]`
/// strides:`[` stride-list `]`
@@ -2400,19 +2357,11 @@ void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result,
/// ```
static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) {
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
- p << op.getOperationName().drop_front(stdDotLen) << " " << op.source()
- << " to offset: ";
- printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
- ShapedType::isDynamicStrideOrOffset);
- p << ", sizes: ";
- printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
- ShapedType::isDynamic);
- p << ", strides: ";
- printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
- ShapedType::isDynamicStrideOrOffset);
- p.printOptionalAttrDict(
- op.getAttrs(),
- /*elidedAttrs=*/MemRefReinterpretCastOp::getSpecialAttrNames());
+ p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
+ p << op.source() << " ";
+ printOffsetsSizesAndStrides(
+ p, op, /*offsetPrefix=*/"to offset: ", /*sizePrefix=*/", sizes: ",
+ /*stridePrefix=*/", strides: ");
p << ": " << op.source().getType() << " to " << op.getType();
}
@@ -2451,8 +2400,8 @@ static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser,
parser.parseKeywordType("to", dstType) ||
parser.resolveOperand(srcInfo, srcType, result.operands));
};
- SmallVector<int, 4> segmentSizes{1}; // source memref
- if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
+ if (failed(parseOffsetsSizesAndStrides(parser, result,
+ /*segmentSizes=*/{1}, // source memref
preResolutionFn, parseOffsetPrefix,
parseSizePrefix, parseStridePrefix)))
return failure();
@@ -3122,38 +3071,18 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
sourceMemRefType.getMemorySpace());
}
-/// Print SubViewOp in the form:
+/// Print a subview op of the form:
/// ```
-/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+/// `subview` ssa-name
+/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
/// `:` strided-memref-type `to` strided-memref-type
/// ```
-template <typename OpType>
-static void printOpWithOffsetsSizesAndStrides(
- OpAsmPrinter &p, OpType op,
- llvm::function_ref<void(OpAsmPrinter &p, OpType op)> printExtraOperands =
- [](OpAsmPrinter &p, OpType op) {},
- StringRef resultTypeKeyword = "to") {
+static void print(OpAsmPrinter &p, SubViewOp op) {
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
p << op.source();
- printExtraOperands(p, op);
- printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
- ShapedType::isDynamicStrideOrOffset);
- p << ' ';
- printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
- ShapedType::isDynamic);
- p << ' ';
- printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
- ShapedType::isDynamicStrideOrOffset);
- p << ' ';
- p.printOptionalAttrDict(op.getAttrs(),
- /*elidedAttrs=*/{OpType::getSpecialAttrNames()});
- p << " : " << op.getSourceType() << " " << resultTypeKeyword << " "
- << op.getType();
-}
-
-static void print(OpAsmPrinter &p, SubViewOp op) {
- return printOpWithOffsetsSizesAndStrides<SubViewOp>(p, op);
+ printOffsetsSizesAndStrides(p, op);
+ p << " : " << op.getSourceType() << " to " << op.getType();
}
/// Parse a subview op of the form:
@@ -3173,8 +3102,9 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
parser.parseKeywordType("to", dstType) ||
parser.resolveOperand(srcInfo, srcType, result.operands));
};
- SmallVector<int, 4> segmentSizes{1}; // source memref
- if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
+
+ if (failed(parseOffsetsSizesAndStrides(parser, result,
+ /*segmentSizes=*/{1}, // source memref
preResolutionFn)))
return failure();
return parser.addTypeToList(dstType, result.types);
@@ -3750,8 +3680,18 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
// SubTensorOp
//===----------------------------------------------------------------------===//
+/// Print a subtensor op of the form:
+/// ```
+/// `subtensor` ssa-name
+/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+/// `:` ranked-tensor-type `to` ranked-tensor-type
+/// ```
static void print(OpAsmPrinter &p, SubTensorOp op) {
- return printOpWithOffsetsSizesAndStrides<SubTensorOp>(p, op);
+ int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+ p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
+ p << op.source();
+ printOffsetsSizesAndStrides(p, op);
+ p << " : " << op.getSourceType() << " to " << op.getType();
}
/// Parse a subtensor op of the form:
@@ -3772,8 +3712,9 @@ static ParseResult parseSubTensorOp(OpAsmParser &parser,
parser.parseKeywordType("to", dstType) ||
parser.resolveOperand(srcInfo, srcType, result.operands));
};
- SmallVector<int, 4> segmentSizes{1}; // source tensor
- if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
+
+ if (failed(parseOffsetsSizesAndStrides(parser, result,
+ /*segmentSizes=*/{1}, // source tensor
preResolutionFn)))
return failure();
return parser.addTypeToList(dstType, result.types);
@@ -3853,11 +3794,18 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// SubTensorInsertOp
//===----------------------------------------------------------------------===//
+/// Print a subtensor_insert op of the form:
+/// ```
+/// `subtensor_insert` ssa-name `into` ssa-name
+/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+/// `:` ranked-tensor-type `into` ranked-tensor-type
+/// ```
static void print(OpAsmPrinter &p, SubTensorInsertOp op) {
- return printOpWithOffsetsSizesAndStrides<SubTensorInsertOp>(
- p, op,
- [](OpAsmPrinter &p, SubTensorInsertOp op) { p << " into " << op.dest(); },
- /*resultTypeKeyword=*/"into");
+ int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+ p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
+ p << op.source() << " into " << op.dest();
+ printOffsetsSizesAndStrides(p, op);
+ p << " : " << op.getSourceType() << " into " << op.getType();
}
/// Parse a subtensor_insert op of the form:
@@ -3880,9 +3828,11 @@ static ParseResult parseSubTensorInsertOp(OpAsmParser &parser,
parser.resolveOperand(srcInfo, srcType, result.operands) ||
parser.resolveOperand(dstInfo, dstType, result.operands));
};
- SmallVector<int, 4> segmentSizes{1, 1}; // source tensor, destination tensor
- if (failed(parseOffsetsSizesAndStrides(parser, result, segmentSizes,
- preResolutionFn)))
+
+ if (failed(parseOffsetsSizesAndStrides(
+ parser, result,
+ /*segmentSizes=*/{1, 1}, // source tensor, destination tensor
+ preResolutionFn)))
return failure();
return parser.addTypeToList(dstType, result.types);
}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index d8a540fa72ff..6127d08a8fc5 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -57,6 +57,44 @@ LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
return success();
}
+/// Print a list with either (1) the static integer value in `arrayAttr` if
+/// `isDynamic` evaluates to false or (2) the next value otherwise.
+/// This allows idiomatic printing of mixed value and integer attributes in a
+/// list. E.g. `[%arg0, 7, 42, %arg42]`.
+static void
+printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values,
+ ArrayAttr arrayAttr,
+ llvm::function_ref<bool(int64_t)> isDynamic) {
+ p << '[';
+ unsigned idx = 0;
+ llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
+ int64_t val = a.cast<IntegerAttr>().getInt();
+ if (isDynamic(val))
+ p << values[idx++];
+ else
+ p << val;
+ });
+ p << ']';
+}
+
+void mlir::printOffsetsSizesAndStrides(OpAsmPrinter &p,
+ OffsetSizeAndStrideOpInterface op,
+ StringRef offsetPrefix,
+ StringRef sizePrefix,
+ StringRef stridePrefix,
+ ArrayRef<StringRef> elidedAttrs) {
+ p << offsetPrefix;
+ printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
+ ShapedType::isDynamicStrideOrOffset);
+ p << sizePrefix;
+ printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
+ ShapedType::isDynamic);
+ p << stridePrefix;
+ printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(),
+ ShapedType::isDynamicStrideOrOffset);
+ p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
+}
+
/// Parse a mixed list with either (1) static integer values or (2) SSA values.
/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal`
/// encode the position of SSA values. Add the parsed SSA values to `ssa`
@@ -105,9 +143,17 @@ parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result,
}
ParseResult mlir::parseOffsetsSizesAndStrides(
- OpAsmParser &parser,
- OperationState &result,
- ArrayRef<int> segmentSizes,
+ OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
+ llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
+ llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalSizePrefix,
+ llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalStridePrefix) {
+ return parseOffsetsSizesAndStrides(
+ parser, result, segmentSizes, nullptr, parseOptionalOffsetPrefix,
+ parseOptionalSizePrefix, parseOptionalStridePrefix);
+}
+
+ParseResult mlir::parseOffsetsSizesAndStrides(
+ OpAsmParser &parser, OperationState &result, ArrayRef<int> segmentSizes,
llvm::function_ref<ParseResult(OpAsmParser &, OperationState &)>
preResolutionFn,
llvm::function_ref<ParseResult(OpAsmParser &)> parseOptionalOffsetPrefix,
@@ -132,14 +178,14 @@ ParseResult mlir::parseOffsetsSizesAndStrides(
ShapedType::kDynamicStrideOrOffset, stridesInfo))
return failure();
// Add segment sizes to result
- SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(), segmentSizes.end());
+ SmallVector<int, 4> segmentSizesFinal(segmentSizes.begin(),
+ segmentSizes.end());
segmentSizesFinal.append({static_cast<int>(offsetsInfo.size()),
- static_cast<int>(sizesInfo.size()),
- static_cast<int>(stridesInfo.size())});
- auto b = parser.getBuilder();
+ static_cast<int>(sizesInfo.size()),
+ static_cast<int>(stridesInfo.size())});
result.addAttribute(
OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
- b.getI32VectorAttr(segmentSizesFinal));
+ parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
return failure(
(preResolutionFn && preResolutionFn(parser, result)) ||
parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
More information about the Mlir-commits
mailing list