[Mlir-commits] [mlir] 16d4bbe - [mlir][Linalg] Introduce linalg.pad_tensor op.
Hanhan Wang
llvmlistbot at llvm.org
Thu Jan 21 22:16:42 PST 2021
Author: Hanhan Wang
Date: 2021-01-21T22:09:28-08:00
New Revision: 16d4bbef30a9e625e04653047759d5636f9e58a5
URL: https://github.com/llvm/llvm-project/commit/16d4bbef30a9e625e04653047759d5636f9e58a5
DIFF: https://github.com/llvm/llvm-project/commit/16d4bbef30a9e625e04653047759d5636f9e58a5.diff
LOG: [mlir][Linalg] Introduce linalg.pad_tensor op.
`linalg.pad_tensor` is an operation that pads the `source` tensor
with given `low` and `high` padding config.
Example 1:
```mlir
%pad_value = ... : f32
%1 = linalg.pad_tensor %0 low[1, 2] high[2, 3] {
^bb0(%arg0 : index, %arg1 : index):
linalg.yield %pad_value : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
```
Example 2:
```mlir
%pad_value = ... : f32
%1 = linalg.pad_tensor %arg0 low[2, %arg1, 3, 3] high[3, 3, %arg1, 2] {
^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index):
linalg.yield %pad_value : f32
} : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
```
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D93704
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 0ce86e403681..ae9f81d043f5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -117,6 +117,101 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
let hasCanonicalizer = 1;
}
+def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
+ [AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> {
+ let summary = "tensor pad operation";
+ let description = [{
+ `linalg.pad_tensor` is an operation that pads the `source` tensor
+ with given `low` and `high` padding config.
+
+ The PadTensor operation supports the following arguments:
+
+ * source: the "base" tensor on which to pad.
+ * low: A list containing the padding along the start of each
+ dimension, i.e. `low`.
+ * high: A list containing the padding along the end of each
+ dimension, i.e. `high`.
+
+ The result tensor dimensions are `low` + `dim` + `high` along that
+ dimension. The number of elements of `low` and `high` must match
+ the rank of the input tensor (which is also the rank of the output
+ tensor). They can be either a constant or a dynamic value.
+
+ The region of the `pad_tensor` operation returns the value to use
+ for the padding. The arguments of the region represent the index
+ of the source being accessed. There should be as many arguments as
+ the rank of the `source` tensor. The value `yield`-ed by the
+ region is used as the value of the view at the given position.
+
+ Example 1:
+
+ ```mlir
+ %pad_value = ... : f32
+ %1 = linalg.pad_tensor %0 low[1, 2] high[2, 3] {
+ ^bb0(%arg0 : index, %arg1 : index):
+ linalg.yield %pad_value : f32
+ } : tensor<?x?xf32> to tensor<?x?xf32>
+ ```
+
+ Example 2:
+
+ ```mlir
+ %pad_value = ... : f32
+ %0 = linalg.pad_tensor %arg0 low[2, %arg1, 3, 3] high[3, 3, %arg1, 2] {
+ ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index):
+ linalg.yield %pad_value : f32
+ } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+ ```
+
+ Example 3:
+
+ ```mlir
+ %pad_value = ... : f32
+ %0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %pad_value : f32
+ } : tensor<2x3xf32> to tensor<?x?xf32>
+ ```
+ }];
+
+ let arguments = (ins
+ AnyTensor:$source,
+ Variadic<Index>:$low,
+ Variadic<Index>:$high,
+ I64ArrayAttr:$static_low,
+ I64ArrayAttr:$static_high);
+
+ let regions = (region AnyRegion:$region);
+
+ let results = (outs AnyTensor:$result);
+
+ let extraClassDeclaration = [{
+ static StringRef getStaticLowAttrName() {
+ return "static_low";
+ }
+
+ static StringRef getStaticHighAttrName() {
+ return "static_high";
+ }
+
+ // Infer the type of the result tensor from the source tensor type
+ // (shape and element type) and the static low/high padding values.
+ static RankedTensorType inferResultType(RankedTensorType sourceType,
+ ArrayRef<int64_t> staticLow,
+ ArrayRef<int64_t> staticHigh);
+ }];
+
+ let builders = [
+ // Build a PadTensorOp with mixed static and dynamic entries.
+ OpBuilderDAG<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
+ "ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+ // Build a PadTensorOp with all dynamic entries.
+ OpBuilderDAG<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+ ];
+}
+
def Linalg_RangeOp :
Linalg_Op<"range", [NoSideEffect]>,
Arguments<(ins Index:$min, Index:$max, Index:$step)>,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fa98ed0cfbc9..b500eefa9d0c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -915,6 +915,151 @@ void InitTensorOp::getCanonicalizationPatterns(
ReplaceStaticShapeDims>(context);
}
+//===----------------------------------------------------------------------===//
+// PadTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+ return llvm::to_vector<4>(
+ llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+ return a.cast<IntegerAttr>().getInt();
+ }));
+}
+
+/// Verifies a PadTensorOp:
+///   * `static_low` and `static_high` each have one entry per source dimension;
+///   * the declared result type equals the type inferred from the source type
+///     and the static padding;
+///   * the region has exactly one block with `rank` arguments, all of type
+///     `index`.
+/// The number and type of yielded values are checked by the YieldOp verifier.
+static LogicalResult verify(PadTensorOp op) {
+  auto sourceType = op.source().getType().cast<RankedTensorType>();
+  auto resultType = op.result().getType().cast<RankedTensorType>();
+
+  SmallVector<int64_t, 4> staticLow = extractFromI64ArrayAttr(op.static_low());
+  SmallVector<int64_t, 4> staticHigh =
+      extractFromI64ArrayAttr(op.static_high());
+
+  // inferResultType only asserts on the padding list sizes; emit a proper
+  // diagnostic here so that malformed IR never reaches that assert (or an
+  // out-of-bounds read when asserts are compiled out).
+  unsigned sourceRank = sourceType.getRank();
+  if (staticLow.size() != sourceRank || staticHigh.size() != sourceRank)
+    return op.emitOpError("expected low and high to have ")
+           << sourceRank << " entries each";
+
+  auto expectedType =
+      PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
+  if (resultType != expectedType) {
+    return op.emitError("specified type ")
+           << resultType << " does not match the inferred type "
+           << expectedType;
+  }
+
+  auto &region = op.region();
+  if (!llvm::hasSingleElement(region))
+    return op.emitOpError("expected region with 1 block");
+  unsigned rank = resultType.getRank();
+  Block &block = region.front();
+  if (block.getNumArguments() != rank)
+    return op.emitError("expected the block to have ") << rank << " arguments";
+
+  // Note: the number and type of yield values are checked in the YieldOp.
+  for (auto en : llvm::enumerate(block.getArgumentTypes())) {
+    if (!en.value().isIndex())
+      return op.emitOpError("expected block argument ")
+             << (en.index() + 1) << " to be an index";
+  }
+
+  return success();
+}
+
+RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
+ ArrayRef<int64_t> staticLow,
+ ArrayRef<int64_t> staticHigh) {
+ unsigned rank = sourceType.getRank();
+ assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
+ assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
+
+ SmallVector<int64_t, 4> resultShape;
+ for (auto i : llvm::seq<unsigned>(0, rank)) {
+ if (sourceType.isDynamicDim(i) ||
+ staticLow[i] == ShapedType::kDynamicSize ||
+ staticHigh[i] == ShapedType::kDynamicSize) {
+ resultShape.push_back(ShapedType::kDynamicSize);
+ } else {
+ int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
+ resultShape.push_back(size);
+ }
+ }
+
+ return RankedTensorType::get(resultShape, sourceType.getElementType());
+}
+
+static ParseResult parsePadTensorOp(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::OperandType baseInfo;
+ SmallVector<OpAsmParser::OperandType, 8> operands;
+ SmallVector<Type, 8> types;
+ if (parser.parseOperand(baseInfo))
+ return failure();
+
+ IndexType indexType = parser.getBuilder().getIndexType();
+ SmallVector<OpAsmParser::OperandType, 4> lowPadding, highPadding;
+ if (parser.parseKeyword("low") ||
+ parseListOfOperandsOrIntegers(parser, result,
+ PadTensorOp::getStaticLowAttrName(),
+ ShapedType::kDynamicSize, lowPadding))
+ return failure();
+ if (parser.parseKeyword("high") ||
+ parseListOfOperandsOrIntegers(parser, result,
+ PadTensorOp::getStaticHighAttrName(),
+ ShapedType::kDynamicSize, highPadding))
+ return failure();
+
+ SmallVector<OpAsmParser::OperandType, 8> regionOperands;
+ std::unique_ptr<Region> region = std::make_unique<Region>();
+ SmallVector<Type, 8> operandTypes, regionTypes;
+ if (parser.parseRegion(*region, regionOperands, regionTypes))
+ return failure();
+ result.addRegion(std::move(region));
+
+ Type srcType, dstType;
+ if (parser.parseColonType(srcType) || parser.parseKeywordType("to", dstType))
+ return failure();
+
+ if (parser.addTypeToList(dstType, result.types))
+ return failure();
+
+ SmallVector<int, 4> segmentSizesFinal = {1}; // source tensor
+ segmentSizesFinal.append({static_cast<int>(lowPadding.size()),
+ static_cast<int>(highPadding.size())});
+ result.addAttribute(
+ OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
+ parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
+ return failure(
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.resolveOperand(baseInfo, srcType, result.operands) ||
+ parser.resolveOperands(lowPadding, indexType, result.operands) ||
+ parser.resolveOperands(highPadding, indexType, result.operands));
+}
+
+static void print(OpAsmPrinter &p, PadTensorOp op) {
+ p << op->getName().getStringRef() << ' ';
+ p << op.source();
+ p << " low";
+ printListOfOperandsOrIntegers(p, op.low(), op.static_low(),
+ ShapedType::isDynamic);
+ p << " high";
+ printListOfOperandsOrIntegers(p, op.high(), op.static_high(),
+ ShapedType::isDynamic);
+ p.printRegion(op.region());
+ p << " : " << op.source().getType() << " to " << op.getType();
+}
+
+void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
+ ArrayRef<int64_t> staticLow,
+ ArrayRef<int64_t> staticHigh, ValueRange low,
+ ValueRange high, ArrayRef<NamedAttribute> attrs) {
+ auto sourceType = source.getType().cast<RankedTensorType>();
+ auto resultType = inferResultType(sourceType, staticLow, staticHigh);
+ build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
+ b.getI64ArrayAttr(staticHigh));
+ result.addAttributes(attrs);
+}
+
+/// Builds a PadTensorOp whose `low` and `high` padding entries are all
+/// dynamic: every static entry is set to ShapedType::kDynamicSize and the
+/// actual amounts are supplied through the `low` / `high` operand ranges.
+/// Delegates to the mixed static/dynamic builder.
+void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
+                        ValueRange low, ValueRange high,
+                        ArrayRef<NamedAttribute> attrs) {
+  auto sourceType = source.getType().cast<RankedTensorType>();
+  unsigned rank = sourceType.getRank();
+  // SmallVector's fill constructor takes (count, value): create `rank`
+  // entries, each marked dynamic. The arguments were previously swapped,
+  // which requested kDynamicSize elements holding the value `rank`.
+  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
+  build(b, result, source, staticVector, staticVector, low, high, attrs);
+}
+
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
@@ -1557,6 +1702,13 @@ static LogicalResult verify(linalg::YieldOp op) {
if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
return verifyYield(op, cast<LinalgOp>(parentOp));
+ if (auto padTensorOp = dyn_cast<linalg::PadTensorOp>(parentOp)) {
+ return success(
+ op.getNumOperands() == 1 &&
+ op.getOperand(0).getType() ==
+ padTensorOp.getType().cast<ShapedType>().getElementType());
+ }
+
return op.emitOpError("expected parent op with LinalgOp interface");
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 4359eebebbc1..a3ef242c29f9 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -617,3 +617,45 @@ func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref<?x4x5xf32>) -> mem
memref<?x4x5xf32> into memref<?x?xf32>
return %0 : memref<?x?xf32>
}
+
+// -----
+
+func @pad_result_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32) -> tensor<?x?x?x8xf32> {
+ // expected-error @+1 {{specified type 'tensor<?x?x?x8xf32>' does not match the inferred type 'tensor<?x?x?x9xi32>}}
+ %0 = linalg.pad_tensor %arg0 low[1, %arg1, 2, 2] high[1, 2, %arg1, 3] {
+ ^bb0(%arg3: index, %arg4: index): // no predecessors
+ linalg.yield %arg2 : i32
+ } : tensor<?x2x3x4xi32> to tensor<?x?x?x8xf32>
+ return %0 : tensor<?x?x?x8xf32>
+}
+
+// -----
+
+func @pad_number_of_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+ // expected-error @+1 {{expected the block to have 2 arguments}}
+ %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+ ^bb0(%arg2: index, %arg3: index, %arg4: index): // no predecessors
+ linalg.yield %arg1 : i32
+ } : tensor<?x4xi32> to tensor<?x9xi32>
+ return %0 : tensor<?x9xi32>
+}
+
+// -----
+
+func @pad_no_block(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+ // expected-error @+1 {{expected region with 1 block}}
+ %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+ } : tensor<?x4xi32> to tensor<?x9xi32>
+ return %0 : tensor<?x9xi32>
+}
+
+// -----
+
+func @pad_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+ // expected-error @+1 {{op expected block argument 1 to be an index}}
+ %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+ ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ linalg.yield %arg1 : i32
+ } : tensor<?x4xi32> to tensor<?x9xi32>
+ return %0 : tensor<?x9xi32>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index d0121b0c90c7..c4a3247fdc88 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -5,6 +5,58 @@
// Test that we can lower all the way to LLVM without crashing, don't check results here.
// DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
+func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
+ %pad_value: f32) -> tensor<6x?x?x?xf32> {
+ %0 = linalg.pad_tensor %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+ linalg.yield %pad_value : f32
+ } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+ return %0 : tensor<6x?x?x?xf32>
+}
+// CHECK-LABEL: func @pad_dynamic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[LOW:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[HIGH:[a-zA-Z0-9_]*]]
+// CHECK: linalg.pad_tensor %[[ARG0]]
+// CHECK-SAME: low[2, %[[LOW]], 3, 3]
+// CHECK-SAME: high[3, 3, %[[HIGH]], 2]
+// CHECK: : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+
+// -----
+
+func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> {
+ %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+ ^bb0(%arg1 : index, %arg2 : index):
+ linalg.yield %pad_value : f32
+ } : tensor<3x4xf32> to tensor<6x9xf32>
+ return %0 : tensor<6x9xf32>
+}
+// CHECK-LABEL: func @pad_static
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK: linalg.pad_tensor %[[ARG0]] low[1, 2] high[2, 3]
+// CHECK: : tensor<3x4xf32> to tensor<6x9xf32>
+
+// -----
+
+func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
+ %pad_value: f32) -> tensor<?x?xf32> {
+ %0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %pad_value : f32
+ } : tensor<2x3xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @pad_asymmetrical
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[UB0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[UB1:[a-zA-Z0-9_]*]]
+// CHECK: linalg.pad_tensor %[[ARG0]]
+// CHECK-SAME: low[0, 0]
+// CHECK-SAME: high[%[[UB0]], %[[UB1]]]
+// CHECK: : tensor<2x3xf32> to tensor<?x?xf32>
+
+// -----
+
func @range(%arg0: index, %arg1: index, %arg2: index) {
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
return
More information about the Mlir-commits
mailing list