[Mlir-commits] [mlir] cf9503c - [mlir] Add subtensor_insert operation
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Oct 2 03:34:50 PDT 2020
Author: Nicolas Vasilache
Date: 2020-10-02T06:32:31-04:00
New Revision: cf9503c1b752062d9abfb2c7922a50574d9c5de4
URL: https://github.com/llvm/llvm-project/commit/cf9503c1b752062d9abfb2c7922a50574d9c5de4
DIFF: https://github.com/llvm/llvm-project/commit/cf9503c1b752062d9abfb2c7922a50574d9c5de4.diff
LOG: [mlir] Add subtensor_insert operation
Differential revision: https://reviews.llvm.org/D88657
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/IR/core-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 3d9daee964b6..c62be7571aad 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2922,15 +2922,20 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
The SubView operation supports the following arguments:
- * Memref: the "base" memref on which to create a "view" memref.
- * Offsets: memref-rank number of dynamic offsets or static integer
- attributes into the "base" memref at which to create the "view"
- memref.
- * Sizes: memref-rank number of dynamic sizes or static integer attributes
- which specify the sizes of the result "view" memref type.
- * Strides: memref-rank number of dynamic strides or static integer
- attributes that compose multiplicatively with the base memref
- strides in each dimension.
+ * memref: the "base" memref on which to create a "view" memref.
+ * offsets: memref-rank number of offsets into the "base" memref at which to
+ create the "view" memref.
+ * sizes: memref-rank number of sizes which specify the sizes of the result
+ "view" memref type.
+ * strides: memref-rank number of strides that compose multiplicatively with
+ the base memref strides in each dimension.
+
+ The representation based on offsets, sizes and strides supports a
+ partially-static specification via attributes specified through the
+ `static_offsets`, `static_sizes` and `static_strides` arguments. The special
+ sentinel values ShapedType::kDynamicSize and
+ ShapedType::kDynamicStrideOrOffset encode that the corresponding entry has
+ a dynamic value.
A subview operation may additionally reduce the rank of the resulting view
by removing dimensions that are statically known to be of size 1.
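To make "strides that compose multiplicatively with the base memref strides" concrete, here is a minimal C++ sketch that models a strided layout as a linear offset plus one stride per dimension. `StridedLayout` and `composeSubView` are hypothetical illustrations, not MLIR APIs, and rank reduction is ignored.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a strided layout: element (i_0, ..., i_{n-1})
// lives at linear position offset + sum_k i_k * strides[k].
struct StridedLayout {
  int64_t offset;
  std::vector<int64_t> strides;
};

// Layout of a subview taken at `offsets` with subview `strides`.
// The offsets shift the base offset, and each subview stride is
// multiplied by the base stride of the same dimension.
StridedLayout composeSubView(const StridedLayout &base,
                             const std::vector<int64_t> &offsets,
                             const std::vector<int64_t> &strides) {
  assert(offsets.size() == base.strides.size());
  assert(strides.size() == base.strides.size());
  StridedLayout result{base.offset, {}};
  for (std::size_t k = 0; k < base.strides.size(); ++k) {
    result.offset += offsets[k] * base.strides[k];
    result.strides.push_back(base.strides[k] * strides[k]);
  }
  return result;
}
```

For an 8x16 row-major base (strides [16, 1], offset 0), a subview at offsets [2, 4] with strides [1, 2] starts at element 2 * 16 + 4 = 36 and has strides [16, 2].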
@@ -3076,7 +3081,7 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
let extraClassDeclaration = extraBaseClassDeclaration # [{
/// Returns the type of the base memref operand.
- MemRefType getSourceMemRefType() {
+ MemRefType getSourceType() {
return source().getType().cast<MemRefType>();
}
@@ -3108,13 +3113,19 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
The subtensor operation supports the following arguments:
* tensor: the "base" tensor from which to extract a subtensor.
- * offsets: tensor-rank number of dynamic offsets or static integer
- attributes into the "base" tensor from which to extract the
- subtensor.
- * sizes: tensor-rank number of dynamic sizes or static integer attributes
- which specify the sizes of the result tensor type.
- * strides: tensor-rank number of dynamic strides or static integer
- attributes specifying susampling in each dimension.
+ * offsets: tensor-rank number of offsets into the "base" tensor from which
+ to extract the subtensor.
+ * sizes: tensor-rank number of sizes which specify the sizes of the result
+ tensor type.
+ * strides: tensor-rank number of strides specifying subsampling in each
+ dimension.
+
+ The representation based on offsets, sizes and strides supports a
+ partially-static specification via attributes specified through the
+ `static_offsets`, `static_sizes` and `static_strides` arguments. The special
+ sentinel values ShapedType::kDynamicSize and
+ ShapedType::kDynamicStrideOrOffset encode that the corresponding entry has
+ a dynamic value.
After buffer-allocation, the "subtensor" op is expected to lower into a
"subview" op.
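The partially-static encoding can be illustrated by merging a `static_*` list with the values of its dynamic operands. In the C++ sketch below, the sentinel constants are local stand-ins assumed to match `ShapedType::kDynamicSize` (-1) and `ShapedType::kDynamicStrideOrOffset` (the minimum `int64_t`) at the time of this commit, and `resolveMixedList` is a hypothetical helper rather than an MLIR function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>

// Assumed stand-ins for the ShapedType sentinel values.
constexpr int64_t kDynamicSize = -1;
constexpr int64_t kDynamicStrideOrOffset = std::numeric_limits<int64_t>::min();

// Merges a partially-static list (e.g. `static_sizes`) with the runtime
// values of its dynamic operands (e.g. `sizes`): each sentinel entry is
// replaced, in order, by the next dynamic value.
std::vector<int64_t> resolveMixedList(const std::vector<int64_t> &staticValues,
                                      const std::vector<int64_t> &dynamicValues,
                                      int64_t sentinel) {
  std::vector<int64_t> resolved;
  std::size_t nextDynamic = 0;
  for (int64_t v : staticValues) {
    if (v != sentinel) {
      resolved.push_back(v); // static entry, known at compile time
      continue;
    }
    if (nextDynamic >= dynamicValues.size())
      throw std::invalid_argument("too few dynamic values");
    resolved.push_back(dynamicValues[nextDynamic++]);
  }
  if (nextDynamic != dynamicValues.size())
    throw std::invalid_argument("too many dynamic values");
  return resolved;
}
```

For a sizes list written as `[%idx, 4, %idx]`, the static list would be `{kDynamicSize, 4, kDynamicSize}` with two dynamic operands; with `%idx = 3` the resolved sizes are `{3, 4, 3}`.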
@@ -3144,9 +3155,22 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
);
let results = (outs AnyRankedTensor:$result);
+ let builders = [
+ // Build a SubTensorOp with mixed static and dynamic entries.
+ OpBuilder<
+ "Value source, ArrayRef<int64_t> staticOffsets, "
+ "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+ "ValueRange offsets, ValueRange sizes, ValueRange strides, "
+ "ArrayRef<NamedAttribute> attrs = {}">,
+ // Build a SubTensorOp with all dynamic entries.
+ OpBuilder<
+ "Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, "
+ "ArrayRef<NamedAttribute> attrs = {}">
+ ];
+
let extraClassDeclaration = extraBaseClassDeclaration # [{
/// Returns the type of the base tensor operand.
- RankedTensorType getSourceRankedTensorType() {
+ RankedTensorType getSourceType() {
return source().getType().cast<RankedTensorType>();
}
@@ -3167,6 +3191,80 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// SubTensorInsertOp
+//===----------------------------------------------------------------------===//
+
+def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<"subtensor_insert"> {
+ let summary = "subtensor_insert operation";
+ let description = [{
+ The "subtensor_insert" operation inserts a tensor `source` into another
+ tensor `dest` as specified by the operation's offsets, sizes and strides
+ arguments.
+
+ It returns a copy of `dest` in which the subtensor selected by the offsets,
+ sizes and strides is replaced by the value of `source`.
+
+ The subtensor_insert operation encodes the following information:
+
+ * source: the tensor that is inserted.
+ * dest: the tensor into which the source tensor is inserted.
+ * offsets: tensor-rank number of offsets into the `dest` tensor at which
+ to insert the `source` tensor.
+ * sizes: tensor-rank number of sizes which specify the sizes of the
+ `source` tensor type.
+ * strides: tensor-rank number of strides that specify subsampling in each
+ dimension.
+
+ The representation based on offsets, sizes and strides supports a
+ partially-static specification via attributes specified through the
+ `static_offsets`, `static_sizes` and `static_strides` arguments. The special
+ sentinel values ShapedType::kDynamicSize and
+ ShapedType::kDynamicStrideOrOffset encode that the corresponding entry has
+ a dynamic value.
+
+ After buffer-allocation, the "subtensor_insert" op is expected to become
+ an in-place buffer update.
+ }];
+
+ let arguments = (ins
+ AnyRankedTensor:$source,
+ AnyRankedTensor:$dest,
+ Variadic<Index>:$offsets,
+ Variadic<Index>:$sizes,
+ Variadic<Index>:$strides,
+ I64ArrayAttr:$static_offsets,
+ I64ArrayAttr:$static_sizes,
+ I64ArrayAttr:$static_strides
+ );
+ let results = (outs AnyRankedTensor:$result);
+
+ let builders = [
+ // Build a SubTensorInsertOp with mixed static and dynamic entries.
+ OpBuilder<
+ "Value source, Value dest, ArrayRef<int64_t> staticOffsets, "
+ "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+ "ValueRange offsets, ValueRange sizes, ValueRange strides, "
+ "ArrayRef<NamedAttribute> attrs = {}">,
+ // Build a SubTensorInsertOp with all dynamic entries.
+ OpBuilder<
+ "Value source, Value dest, ValueRange offsets, ValueRange sizes, "
+ "ValueRange strides, ArrayRef<NamedAttribute> attrs = {}">
+ ];
+
+ let extraClassDeclaration = extraBaseClassDeclaration # [{
+ /// Returns the type of the source tensor operand.
+ RankedTensorType getSourceType() {
+ return source().getType().cast<RankedTensorType>();
+ }
+
+ /// The result of a subtensor_insert is always a ranked tensor.
+ RankedTensorType getType() {
+ return getResult().getType().cast<RankedTensorType>();
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// TanhOp
//===----------------------------------------------------------------------===//
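The value semantics of subtensor_insert (a copy of `dest` with a strided window overwritten by `source`) can be sketched on a dense row-major tensor. This assumes the sizes equal the shape of `source` with no rank reduction; `DenseTensor` and `subtensorInsert` are hypothetical illustrations, whereas the real op works on SSA tensor values and is expected to become an in-place buffer update after buffer allocation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Minimal row-major dense tensor, for illustration only.
struct DenseTensor {
  std::vector<int64_t> shape;
  std::vector<float> data;
};

// Returns a copy of `dest` where dest index offsets[k] + i_k * strides[k]
// holds source element (i_0, ..., i_{n-1}) for every i_k < source.shape[k].
DenseTensor subtensorInsert(const DenseTensor &source, const DenseTensor &dest,
                            const std::vector<int64_t> &offsets,
                            const std::vector<int64_t> &strides) {
  const std::size_t rank = dest.shape.size();
  assert(source.shape.size() == rank && offsets.size() == rank &&
         strides.size() == rank);
  DenseTensor result = dest; // value semantics: `dest` itself is untouched
  std::vector<int64_t> idx(rank, 0); // current multi-index into `source`
  for (std::size_t linear = 0; linear < source.data.size(); ++linear) {
    // Row-major linear position of the target element in `dest`.
    int64_t destLinear = 0;
    for (std::size_t k = 0; k < rank; ++k) {
      int64_t d = offsets[k] + idx[k] * strides[k];
      assert(d >= 0 && d < dest.shape[k] && "insertion out of bounds");
      destLinear = destLinear * dest.shape[k] + d;
    }
    result.data[static_cast<std::size_t>(destLinear)] = source.data[linear];
    // Advance the source multi-index in row-major order.
    for (std::size_t k = rank; k-- > 0;) {
      if (++idx[k] < source.shape[k])
        break;
      idx[k] = 0;
    }
  }
  return result;
}
```

Inserting a 2x2 tensor into a 4x4 tensor at offsets [1, 1] with unit strides overwrites dest positions (1,1), (1,2), (2,1) and (2,2); with strides [2, 2] and offsets [0, 0] it instead writes every other row and column.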
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 5548274eee18..7f4e2ffa5262 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -23,6 +23,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
@@ -2639,10 +2640,15 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
/// `:` strided-memref-type `to` strided-memref-type
/// ```
template <typename OpType>
-static void printOpWithOffsetsSizesAndStrides(OpAsmPrinter &p, OpType op) {
+static void printOpWithOffsetsSizesAndStrides(
+ OpAsmPrinter &p, OpType op,
+ llvm::function_ref<void(OpAsmPrinter &p, OpType op)> printExtraOperands =
+ [](OpAsmPrinter &p, OpType op) {},
+ StringLiteral resultTypeKeyword = "to") {
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
- p << op.getOperand(0);
+ p << op.source();
+ printExtraOperands(p, op);
printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
ShapedType::isDynamicStrideOrOffset);
printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
@@ -2651,27 +2657,35 @@ static void printOpWithOffsetsSizesAndStrides(OpAsmPrinter &p, OpType op) {
ShapedType::isDynamicStrideOrOffset);
p.printOptionalAttrDict(op.getAttrs(),
/*elidedAttrs=*/{OpType::getSpecialAttrNames()});
- p << " : " << op.getOperand(0).getType() << " to " << op.getType();
+ p << " : " << op.getSourceType() << " " << resultTypeKeyword << " "
+ << op.getType();
}
static void print(OpAsmPrinter &p, SubViewOp op) {
return printOpWithOffsetsSizesAndStrides<SubViewOp>(p, op);
}
-/// Parse SubViewOp of the form:
+/// Parse an op of the form:
/// ```
-/// `name` ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
-/// `:` strided-memref-type `to` strided-memref-type
+/// `name` ssa-name (extra-operands)?
+/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+/// `:` strided-memref-type `resultTypeKeyword` strided-memref-type
/// ```
template <typename OpType>
-static ParseResult parseOpWithOffsetsSizesAndStrides(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::OperandType srcInfo;
+static ParseResult parseOpWithOffsetsSizesAndStrides(
+ OpAsmParser &parser, OperationState &result,
+ std::function<ParseResult(OpAsmParser &p,
+ OpAsmParser::OperandType &dstInfo)>
+ parseExtraOperand = nullptr,
+ StringLiteral resultTypeKeyword = "to") {
+ OpAsmParser::OperandType srcInfo, dstInfo;
SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
auto indexType = parser.getBuilder().getIndexType();
Type srcType, dstType;
if (parser.parseOperand(srcInfo))
return failure();
+ if (parseExtraOperand && parseExtraOperand(parser, dstInfo))
+ return failure();
if (parseListOfOperandsOrIntegers(
parser, result, OpType::getStaticOffsetsAttrName(),
ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
@@ -2683,21 +2697,27 @@ static ParseResult parseOpWithOffsetsSizesAndStrides(OpAsmParser &parser,
ShapedType::kDynamicStrideOrOffset, stridesInfo))
return failure();
+ // Handle segment sizes.
auto b = parser.getBuilder();
- SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()),
- static_cast<int>(sizesInfo.size()),
- static_cast<int>(stridesInfo.size())};
+ SmallVector<int, 4> segmentSizes = {1, static_cast<int>(offsetsInfo.size()),
+ static_cast<int>(sizesInfo.size()),
+ static_cast<int>(stridesInfo.size())};
+ // If an extra operand was parsed, it also needs an entry in segmentSizes.
+ if (parseExtraOperand)
+ segmentSizes.insert(segmentSizes.begin(), 1);
result.addAttribute(OpType::getOperandSegmentSizeAttr(),
b.getI32VectorAttr(segmentSizes));
return failure(
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(srcType) ||
+ parser.parseKeywordType(resultTypeKeyword.str().c_str(), dstType) ||
parser.resolveOperand(srcInfo, srcType, result.operands) ||
+ (parseExtraOperand &&
+ parser.resolveOperand(dstInfo, dstType, result.operands)) ||
parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
parser.resolveOperands(sizesInfo, indexType, result.operands) ||
parser.resolveOperands(stridesInfo, indexType, result.operands) ||
- parser.parseKeywordType("to", dstType) ||
parser.addTypeToList(dstType, result.types));
}
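The segment-size bookkeeping in the parser above can be modelled in isolation. The hypothetical C++ helper below reproduces the vector that is attached as the operand segment sizes attribute: one entry for `source`, an optional entry for the extra operand, then the counts of dynamic offsets, sizes and strides.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical model of the segment sizes recorded by
// parseOpWithOffsetsSizesAndStrides: one entry per operand group.
std::vector<int> operandSegmentSizes(std::size_t numOffsets,
                                     std::size_t numSizes,
                                     std::size_t numStrides,
                                     bool hasExtraOperand) {
  // `source` is a single operand, followed by the three variadic groups.
  std::vector<int> segments = {1, static_cast<int>(numOffsets),
                               static_cast<int>(numSizes),
                               static_cast<int>(numStrides)};
  // The extra single operand (`dest` for subtensor_insert) adds another
  // size-1 segment. Both leading segments have size 1, so inserting at the
  // front yields the same vector as inserting right after `source`.
  if (hasExtraOperand)
    segments.insert(segments.begin(), 1);
  return segments;
}

// Total operand count implied by the segments.
int totalOperands(const std::vector<int> &segments) {
  return std::accumulate(segments.begin(), segments.end(), 0);
}
```

For `%2` in the test, `[%c0, %idx, %c0][%idx, 4, %idx][%c1, 1, %c1]` has three dynamic offsets, two dynamic sizes and two dynamic strides, giving segments `{1, 1, 3, 2, 2}` and nine operands in total.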
@@ -2894,7 +2914,7 @@ static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
- MemRefType baseType = op.getSourceMemRefType();
+ MemRefType baseType = op.getSourceType();
MemRefType subViewType = op.getType();
// The base memref and the view memref should be in the same memory space.
@@ -3273,8 +3293,7 @@ static LogicalResult verify(SubTensorOp op) {
// Verify result type against inferred type.
auto expectedType = SubTensorOp::inferResultType(
- op.getSourceRankedTensorType(),
- extractFromI64ArrayAttr(op.static_offsets()),
+ op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
extractFromI64ArrayAttr(op.static_sizes()),
extractFromI64ArrayAttr(op.static_strides()));
if (!isRankReducedType(expectedType, op.getType()))
@@ -3291,6 +3310,72 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
context);
}
+//===----------------------------------------------------------------------===//
+// SubTensorInsertOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, SubTensorInsertOp op) {
+ return printOpWithOffsetsSizesAndStrides<SubTensorInsertOp>(
+ p, op,
+ [](OpAsmPrinter &p, SubTensorInsertOp op) { p << " into " << op.dest(); },
+ /*resultTypeKeyword=*/"into");
+}
+
+static ParseResult parseSubTensorInsertOp(OpAsmParser &parser,
+ OperationState &result) {
+ return parseOpWithOffsetsSizesAndStrides<SubTensorInsertOp>(
+ parser, result,
+ [](OpAsmParser &parser, OpAsmParser::OperandType &dstInfo) {
+ return failure(parser.parseKeyword("into") ||
+ parser.parseOperand(dstInfo));
+ },
+ "into");
+}
+
+void mlir::SubTensorInsertOp::build(
+ OpBuilder &b, OperationState &result, Value source, Value dest,
+ ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
+ ArrayRef<int64_t> staticStrides, ValueRange offsets, ValueRange sizes,
+ ValueRange strides, ArrayRef<NamedAttribute> attrs) {
+ build(b, result, dest.getType(), source, dest, offsets, sizes, strides,
+ b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
+ b.getI64ArrayAttr(staticStrides));
+ result.addAttributes(attrs);
+}
+
+/// Build a SubTensorInsertOp with all dynamic entries: `staticOffsets`,
+/// `staticSizes` and `staticStrides` are automatically filled with
+/// source-tensor-rank sentinel values that encode dynamic entries.
+void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
+ Value source, Value dest,
+ ValueRange offsets, ValueRange sizes,
+ ValueRange strides,
+ ArrayRef<NamedAttribute> attrs) {
+ auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
+ unsigned rank = sourceRankedTensorType.getRank();
+ SmallVector<int64_t, 4> staticOffsetsVector(
+ rank, ShapedType::kDynamicStrideOrOffset);
+ SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+ SmallVector<int64_t, 4> staticStridesVector(
+ rank, ShapedType::kDynamicStrideOrOffset);
+ build(b, result, source, dest, staticOffsetsVector, staticSizesVector,
+ staticStridesVector, offsets, sizes, strides, attrs);
+}
+
+SmallVector<Range, 8> SubTensorInsertOp::getOrCreateRanges(OpBuilder &b,
+ Location loc) {
+ return ::getOrCreateRangesImpl(*this, b, loc);
+}
+
+/// Verifier for SubTensorInsertOp.
+static LogicalResult verify(SubTensorInsertOp op) {
+ if (failed(verifyOpWithOffsetSizesAndStrides(op)))
+ return failure();
+ if (op.getType() != op.dest().getType())
+ return op.emitError("expected result type to be ") << op.dest().getType();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TensorCastOp
//===----------------------------------------------------------------------===//
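The all-dynamic builder and the op-specific verifier rule both reduce to simple checks. The C++ sketch below models them on plain shape vectors: types are simplified to shapes (element type ignored), the sentinel constants are assumed stand-ins for the `ShapedType` values, and `allDynamicEntries` and `resultTypeMatchesDest` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Assumed stand-ins for ShapedType::kDynamicSize and
// ShapedType::kDynamicStrideOrOffset.
constexpr int64_t kDynamicSize = -1;
constexpr int64_t kDynamicStrideOrOffset = std::numeric_limits<int64_t>::min();

struct StaticEntries {
  std::vector<int64_t> offsets, sizes, strides;
};

// Mirrors the all-dynamic builder: one sentinel per dimension of the
// source tensor in each of the three static lists.
StaticEntries allDynamicEntries(unsigned sourceRank) {
  return {std::vector<int64_t>(sourceRank, kDynamicStrideOrOffset),
          std::vector<int64_t>(sourceRank, kDynamicSize),
          std::vector<int64_t>(sourceRank, kDynamicStrideOrOffset)};
}

// Mirrors the extra check in verify(SubTensorInsertOp): the result type
// must be identical to the type of `dest`.
bool resultTypeMatchesDest(const std::vector<int64_t> &resultShape,
                           const std::vector<int64_t> &destShape) {
  return resultShape == destShape;
}
```

Inserting a `tensor<8x16x4xf32>` into a `tensor<16x32x8xf32>` must therefore produce a `tensor<16x32x8xf32>`; declaring the source type as the result would be rejected.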
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 72a063ff9d51..2590dc0105c4 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -901,7 +901,6 @@ func @assume_alignment(%0: memref<4x4xf16>) {
return
}
-
// CHECK-LABEL: func @subtensor({{.*}}) {
func @subtensor(%t: tensor<8x16x4xf32>, %idx : index) {
%c0 = constant 0 : index
@@ -924,3 +923,21 @@ func @subtensor(%t: tensor<8x16x4xf32>, %idx : index) {
return
}
+
+// CHECK-LABEL: func @subtensor_insert({{.*}}) {
+func @subtensor_insert(%t: tensor<8x16x4xf32>, %t2: tensor<16x32x8xf32>, %idx : index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+
+ // CHECK: subtensor_insert
+ // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
+ %1 = subtensor_insert %t into %t2[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
+ : tensor<8x16x4xf32> into tensor<16x32x8xf32>
+
+ // CHECK: subtensor_insert
+ // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
+ %2 = subtensor_insert %t into %t2[%c0, %idx, %c0][%idx, 4, %idx][%c1, 1, %c1]
+ : tensor<8x16x4xf32> into tensor<16x32x8xf32>
+
+ return
+}
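The custom assembly exercised by these tests mixes SSA operands and integer literals inside each bracketed list. A simplified, hypothetical C++ parser for a single list shows how each entry splits into a static value (or sentinel) plus an optional dynamic operand; the actual `parseListOfOperandsOrIntegers` works on the `OpAsmParser` token stream rather than on raw strings.

```cpp
#include <cassert>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Result of parsing one bracketed list such as "[%idx, 4, %idx]".
struct MixedList {
  std::vector<int64_t> staticValues;     // integer, or sentinel if dynamic
  std::vector<std::string> ssaOperands;  // names of the dynamic entries
};

// Hypothetical helper: simplified whitespace handling, no error recovery.
MixedList parseMixedList(const std::string &text, int64_t sentinel) {
  assert(text.size() >= 2 && text.front() == '[' && text.back() == ']');
  MixedList list;
  std::stringstream entries(text.substr(1, text.size() - 2));
  std::string entry;
  while (std::getline(entries, entry, ',')) {
    // Trim surrounding spaces; skip empty entries.
    const auto first = entry.find_first_not_of(' ');
    if (first == std::string::npos)
      continue;
    const auto last = entry.find_last_not_of(' ');
    entry = entry.substr(first, last - first + 1);
    if (entry[0] == '%') {
      // SSA value: record a sentinel and keep the operand name.
      list.staticValues.push_back(sentinel);
      list.ssaOperands.push_back(entry);
    } else {
      // Integer literal: a static entry with no operand.
      list.staticValues.push_back(std::stoll(entry));
    }
  }
  return list;
}
```

Applied to the sizes list of `%2`, `[%idx, 4, %idx]` yields two dynamic operands and a static list of sentinel, 4, sentinel.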