[Mlir-commits] [mlir] 16d4bbe - [mlir][Linalg] Introduce linalg.pad_tensor op.

Hanhan Wang llvmlistbot at llvm.org
Thu Jan 21 22:16:42 PST 2021


Author: Hanhan Wang
Date: 2021-01-21T22:09:28-08:00
New Revision: 16d4bbef30a9e625e04653047759d5636f9e58a5

URL: https://github.com/llvm/llvm-project/commit/16d4bbef30a9e625e04653047759d5636f9e58a5
DIFF: https://github.com/llvm/llvm-project/commit/16d4bbef30a9e625e04653047759d5636f9e58a5.diff

LOG: [mlir][Linalg] Introduce linalg.pad_tensor op.

`linalg.pad_tensor` is an operation that pads the `source` tensor
with the given `low` and `high` padding configuration.

Example 1:

```mlir
  %pad_value = ... : f32
  %1 = linalg.pad_tensor %0 low[1, 2] high[2, 3] {
  ^bb0(%arg0 : index, %arg1 : index):
    linalg.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
```

Example 2:
```mlir
  %pad_value = ... : f32
  %1 = linalg.pad_tensor %arg0 low[2, %arg1, 3, 3] high[3, 3, %arg1, 2] {
  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index):
    linalg.yield %pad_value : f32
  } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
```

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D93704

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 0ce86e403681..ae9f81d043f5 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -117,6 +117,101 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
   let hasCanonicalizer = 1;
 }
 
+def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
+    [AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "tensor pad operation";
+  let description = [{
+    `linalg.pad_tensor` is an operation that pads the `source` tensor
+    with the given `low` and `high` padding configuration.
+
+    The PadTensor operation supports the following arguments:
+
+    * source: the "base" tensor to pad.
+    * low: a list with the amount of padding inserted before the start
+           of each dimension.
+    * high: a list with the amount of padding inserted after the end
+           of each dimension.
+
+    The result tensor dimensions are `low` + `dim` + `high` along that
+    dimension. The number of elements of `low` and `high` must match
+    the rank of the input tensor (which is also the rank of the output
+    tensor). Each entry can be either a constant or a dynamic value.
+
+    The region of the `pad_tensor` operation returns the value to use
+    for the padding. The arguments of the region represent the index
+    of the source being accessed. There should be as many arguments as
+    the rank of the `source` tensor. The value `yield`-ed by the
+    region is used as the value of the result at the given position.
+
+    Example 1:
+
+    ```mlir
+      %pad_value = ... : f32
+      %1 = linalg.pad_tensor %0 low[1, 2] high[2, 3] {
+      ^bb0(%arg0 : index, %arg1 : index):
+        linalg.yield %pad_value : f32
+      } : tensor<?x?xf32> to tensor<?x?xf32>
+    ```
+
+    Example 2:
+
+    ```mlir
+      %pad_value = ... : f32
+      %0 = linalg.pad_tensor %arg0 low[2, %arg1, 3, 3] high[3, 3, %arg1, 2] {
+      ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index):
+        linalg.yield %pad_value : f32
+      } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+    ```
+
+    Example 3:
+
+    ```mlir
+      %pad_value = ... : f32
+      %0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {
+      ^bb0(%arg1: index, %arg2: index):
+        linalg.yield %pad_value : f32
+      } : tensor<2x3xf32> to tensor<?x?xf32>
+    ```
+  }];
+
+  // The verifier, builders and result-type inference all operate on
+  // RankedTensorType, so only ranked tensors are accepted. `static_low` and
+  // `static_high` hold one entry per dimension; entries equal to
+  // ShapedType::kDynamicSize mark padding supplied by the `low`/`high` operands.
+  let arguments = (ins
+    AnyRankedTensor:$source,
+    Variadic<Index>:$low,
+    Variadic<Index>:$high,
+    I64ArrayAttr:$static_low,
+    I64ArrayAttr:$static_high);
+
+  let regions = (region AnyRegion:$region);
+
+  let results = (outs AnyRankedTensor:$result);
+
+  let extraClassDeclaration = [{
+    static StringRef getStaticLowAttrName() {
+      return "static_low";
+    }
+
+    static StringRef getStaticHighAttrName() {
+      return "static_high";
+    }
+
+    // Infer the type of the result tensor from the source tensor type and
+    // the static low/high padding amounts.
+    static RankedTensorType inferResultType(RankedTensorType sourceType,
+                                ArrayRef<int64_t> staticLow,
+                                ArrayRef<int64_t> staticHigh);
+  }];
+
+  let builders = [
+    // Build a PadTensorOp with mixed static and dynamic entries.
+    OpBuilderDAG<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
+      "ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a PadTensorOp with all dynamic entries.
+    OpBuilderDAG<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+  ];
+}
+
 def Linalg_RangeOp :
     Linalg_Op<"range", [NoSideEffect]>,
     Arguments<(ins Index:$min, Index:$max, Index:$step)>,

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fa98ed0cfbc9..b500eefa9d0c 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -915,6 +915,151 @@ void InitTensorOp::getCanonicalizationPatterns(
                  ReplaceStaticShapeDims>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// PadTensorOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the integer values stored in `attr`, which is assumed to be an
+/// ArrayAttr whose elements are all IntegerAttr (e.g. an I64ArrayAttr).
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+  SmallVector<int64_t, 4> values;
+  for (Attribute element : attr.cast<ArrayAttr>())
+    values.push_back(element.cast<IntegerAttr>().getInt());
+  return values;
+}
+
+/// Verifies a PadTensorOp:
+///   * source and result are ranked tensors,
+///   * `static_low`/`static_high` have one entry per source dimension and each
+///     dynamic entry is backed by exactly one `low`/`high` operand,
+///   * the result type equals the type inferred from source and padding,
+///   * the region has a single block with one `index` argument per dimension.
+/// The yielded padding value itself is checked by the YieldOp verifier.
+static LogicalResult verify(PadTensorOp op) {
+  auto sourceType = op.source().getType().dyn_cast<RankedTensorType>();
+  auto resultType = op.result().getType().dyn_cast<RankedTensorType>();
+  if (!sourceType || !resultType)
+    return op.emitOpError("expected ranked source and result tensor types");
+
+  // inferResultType asserts on these sizes; diagnose a mismatch here instead
+  // of crashing on malformed IR.
+  unsigned rank = sourceType.getRank();
+  SmallVector<int64_t, 4> staticLow = extractFromI64ArrayAttr(op.static_low());
+  SmallVector<int64_t, 4> staticHigh =
+      extractFromI64ArrayAttr(op.static_high());
+  if (staticLow.size() != rank || staticHigh.size() != rank)
+    return op.emitOpError("expected ")
+           << rank << " low and high padding entries, got " << staticLow.size()
+           << " and " << staticHigh.size();
+
+  // Each dynamic (kDynamicSize) entry must be backed by one SSA operand, as
+  // the printer pairs them up positionally.
+  auto numDynamic = [](ArrayRef<int64_t> values) -> size_t {
+    return llvm::count_if(values, ShapedType::isDynamic);
+  };
+  if (numDynamic(staticLow) != op.low().size() ||
+      numDynamic(staticHigh) != op.high().size())
+    return op.emitOpError(
+        "expected one low/high operand per dynamic padding entry");
+
+  auto expectedType =
+      PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
+  if (resultType != expectedType) {
+    return op.emitError("specified type ")
+           << resultType << " does not match the inferred type "
+           << expectedType;
+  }
+
+  auto &region = op.region();
+  if (!llvm::hasSingleElement(region))
+    return op.emitOpError("expected region with 1 block");
+  Block &block = region.front();
+  if (block.getNumArguments() != rank)
+    return op.emitError("expected the block to have ") << rank << " arguments";
+
+  // Note: the number and type of yield values are checked in the YieldOp.
+  for (auto en : llvm::enumerate(block.getArgumentTypes())) {
+    if (!en.value().isIndex())
+      return op.emitOpError("expected block argument ")
+             << (en.index() + 1) << " to be an index";
+  }
+
+  return success();
+}
+
+/// Returns the ranked tensor type obtained by padding `sourceType` with
+/// `staticLow`/`staticHigh`; the element type is taken from the source.
+RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
+                                              ArrayRef<int64_t> staticLow,
+                                              ArrayRef<int64_t> staticHigh) {
+  unsigned rank = sourceType.getRank();
+  assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
+  assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
+
+  // A result extent is static only when the source extent and both padding
+  // amounts are static; otherwise it is ShapedType::kDynamicSize.
+  SmallVector<int64_t, 4> resultShape;
+  resultShape.reserve(rank);
+  for (unsigned dim = 0; dim < rank; ++dim) {
+    int64_t lo = staticLow[dim];
+    int64_t hi = staticHigh[dim];
+    bool isStatic = !sourceType.isDynamicDim(dim) &&
+                    lo != ShapedType::kDynamicSize &&
+                    hi != ShapedType::kDynamicSize;
+    resultShape.push_back(isStatic ? sourceType.getDimSize(dim) + lo + hi
+                                   : ShapedType::kDynamicSize);
+  }
+
+  return RankedTensorType::get(resultShape, sourceType.getElementType());
+}
+
+/// Parses a PadTensorOp of the form:
+///   linalg.pad_tensor %source low[...] high[...] { region }
+///       : source-type to result-type attr-dict?
+static ParseResult parsePadTensorOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  OpAsmParser::OperandType baseInfo;
+  if (parser.parseOperand(baseInfo))
+    return failure();
+
+  // Each list may mix integer literals and SSA values. The literals go into
+  // the static_* attribute, with ShapedType::kDynamicSize passed as the marker
+  // for SSA entries; the SSA operands are collected here.
+  SmallVector<OpAsmParser::OperandType, 4> lowPadding, highPadding;
+  if (parser.parseKeyword("low") ||
+      parseListOfOperandsOrIntegers(parser, result,
+                                    PadTensorOp::getStaticLowAttrName(),
+                                    ShapedType::kDynamicSize, lowPadding))
+    return failure();
+  if (parser.parseKeyword("high") ||
+      parseListOfOperandsOrIntegers(parser, result,
+                                    PadTensorOp::getStaticHighAttrName(),
+                                    ShapedType::kDynamicSize, highPadding))
+    return failure();
+
+  // The region declares its own entry block arguments (`^bb0(...)`).
+  std::unique_ptr<Region> region = std::make_unique<Region>();
+  if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+  result.addRegion(std::move(region));
+
+  Type srcType, dstType;
+  if (parser.parseColonType(srcType) || parser.parseKeywordType("to", dstType))
+    return failure();
+
+  if (parser.addTypeToList(dstType, result.types))
+    return failure();
+
+  // Operand segments: source, dynamic low padding, dynamic high padding.
+  SmallVector<int, 4> segmentSizesFinal = {1}; // source tensor
+  segmentSizesFinal.append({static_cast<int>(lowPadding.size()),
+                            static_cast<int>(highPadding.size())});
+  result.addAttribute(
+      OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
+      parser.getBuilder().getI32VectorAttr(segmentSizesFinal));
+  IndexType indexType = parser.getBuilder().getIndexType();
+  return failure(
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.resolveOperand(baseInfo, srcType, result.operands) ||
+      parser.resolveOperands(lowPadding, indexType, result.operands) ||
+      parser.resolveOperands(highPadding, indexType, result.operands));
+}
+
+/// Prints a PadTensorOp in the form accepted by parsePadTensorOp. Attributes
+/// already encoded by the custom syntax are elided; any other attributes are
+/// printed in a trailing attribute dictionary so that they round-trip.
+static void print(OpAsmPrinter &p, PadTensorOp op) {
+  p << op->getName().getStringRef() << ' ';
+  p << op.source();
+  p << " low";
+  printListOfOperandsOrIntegers(p, op.low(), op.static_low(),
+                                ShapedType::isDynamic);
+  p << " high";
+  printListOfOperandsOrIntegers(p, op.high(), op.static_high(),
+                                ShapedType::isDynamic);
+  p.printRegion(op.region());
+  p << " : " << op.source().getType() << " to " << op.getType();
+  // Printed last because parsePadTensorOp parses the attr-dict after the types.
+  p.printOptionalAttrDict(
+      op->getAttrs(),
+      /*elidedAttrs=*/{
+          OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
+          PadTensorOp::getStaticLowAttrName(),
+          PadTensorOp::getStaticHighAttrName()});
+}
+
+/// Builds a PadTensorOp from per-dimension static padding amounts, where
+/// ShapedType::kDynamicSize entries correspond to values in `low`/`high`. The
+/// result type is inferred from the source type and the static padding.
+void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
+                        ArrayRef<int64_t> staticLow,
+                        ArrayRef<int64_t> staticHigh, ValueRange low,
+                        ValueRange high, ArrayRef<NamedAttribute> attrs) {
+  RankedTensorType resultType = inferResultType(
+      source.getType().cast<RankedTensorType>(), staticLow, staticHigh);
+  ArrayAttr staticLowAttr = b.getI64ArrayAttr(staticLow);
+  ArrayAttr staticHighAttr = b.getI64ArrayAttr(staticHigh);
+  build(b, result, resultType, source, low, high, staticLowAttr,
+        staticHighAttr);
+  result.addAttributes(attrs);
+}
+
+/// Builds a PadTensorOp whose padding is entirely dynamic. Every entry of
+/// `static_low`/`static_high` is ShapedType::kDynamicSize, and `low`/`high`
+/// must each provide one value per source dimension.
+void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
+                        ValueRange low, ValueRange high,
+                        ArrayRef<NamedAttribute> attrs) {
+  auto sourceType = source.getType().cast<RankedTensorType>();
+  unsigned rank = sourceType.getRank();
+  assert(low.size() == rank && high.size() == rank &&
+         "expected one low and one high padding value per dimension");
+  // SmallVector's fill constructor takes (count, value): `rank` copies of
+  // kDynamicSize, one per dimension.
+  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
+  build(b, result, source, staticVector, staticVector, low, high, attrs);
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
@@ -1557,6 +1702,13 @@ static LogicalResult verify(linalg::YieldOp op) {
   if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
     return verifyYield(op, cast<LinalgOp>(parentOp));
 
+  // Inside a PadTensorOp region the yield carries the padding value, which
+  // must be a single operand of the padded result's element type. Emit a
+  // diagnostic instead of failing silently.
+  if (auto padTensorOp = dyn_cast<linalg::PadTensorOp>(parentOp)) {
+    if (op.getNumOperands() != 1)
+      return op.emitOpError("expected single yield operand (got ")
+             << op.getNumOperands() << ")";
+    Type elementType =
+        padTensorOp.getType().cast<ShapedType>().getElementType();
+    if (op.getOperand(0).getType() != elementType)
+      return op.emitOpError("expected yield type ")
+             << op.getOperand(0).getType()
+             << " to match the result element type " << elementType;
+    return success();
+  }
+
   return op.emitOpError("expected parent op with LinalgOp interface");
 }
 

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 4359eebebbc1..a3ef242c29f9 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -617,3 +617,45 @@ func @illegal_expanding_reshape_mixed_memref_2(%arg0 : memref<?x4x5xf32>) -> mem
        memref<?x4x5xf32> into memref<?x?xf32>
   return %0 : memref<?x?xf32>
 }
+
+// -----
+
+// The result type must equal the type inferred from the source type and the
+// static padding: the last dimension is 4 + 2 + 3 = 9 and the element type
+// stays i32. The region is otherwise valid so only the type error fires.
+func @pad_result_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32) -> tensor<?x?x?x8xf32> {
+  // expected-error @+1 {{specified type 'tensor<?x?x?x8xf32>' does not match the inferred type 'tensor<?x?x?x9xi32>'}}
+  %0 = linalg.pad_tensor %arg0 low[1, %arg1, 2, 2] high[1, 2, %arg1, 3] {
+  ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):  // no predecessors
+    linalg.yield %arg2 : i32
+  } : tensor<?x2x3x4xi32> to tensor<?x?x?x8xf32>
+  return %0 : tensor<?x?x?x8xf32>
+}
+
+// -----
+
+// The region block must have one argument per dimension (rank 2 here), so a
+// three-argument block is rejected.
+func @pad_number_of_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+  // expected-error @+1 {{expected the block to have 2 arguments}}
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %arg1 : i32
+  } : tensor<?x4xi32> to tensor<?x9xi32>
+  return %0 : tensor<?x9xi32>
+}
+
+// -----
+
+// A pad_tensor whose region contains no block is rejected.
+func @pad_no_block(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+  // expected-error @+1 {{expected region with 1 block}}
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+  } : tensor<?x4xi32> to tensor<?x9xi32>
+  return %0 : tensor<?x9xi32>
+}
+
+// -----
+
+// Every region block argument must be of `index` type; the first offending
+// argument (1-based) is reported.
+func @pad_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
+  // expected-error @+1 {{op expected block argument 1 to be an index}}
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+  ^bb0(%arg2: i32, %arg3: i32):  // no predecessors
+    linalg.yield %arg1 : i32
+  } : tensor<?x4xi32> to tensor<?x9xi32>
+  return %0 : tensor<?x9xi32>
+}

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index d0121b0c90c7..c4a3247fdc88 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -5,6 +5,58 @@
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
 
+// Mixed static and dynamic padding round-trips: SSA padding values are
+// printed inline within the low/high lists alongside integer literals.
+func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
+                  %pad_value: f32) -> tensor<6x?x?x?xf32> {
+  %0 = linalg.pad_tensor %arg0 low[2, %low, 3, 3] high[3, 3, %high, 2] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+      linalg.yield %pad_value : f32
+    } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+  return %0 : tensor<6x?x?x?xf32>
+}
+// CHECK-LABEL: func @pad_dynamic
+//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+//  CHECK-SAME: %[[LOW:[a-zA-Z0-9_]*]]
+//  CHECK-SAME: %[[HIGH:[a-zA-Z0-9_]*]]
+//       CHECK:   linalg.pad_tensor %[[ARG0]]
+//  CHECK-SAME:     low[2, %[[LOW]], 3, 3]
+//  CHECK-SAME:     high[3, 3, %[[HIGH]], 2]
+//       CHECK:    : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
+
+// -----
+
+// Fully static padding: the result shape 6x9 is 3x4 grown by low[1, 2] and
+// high[2, 3].
+func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> {
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
+    ^bb0(%arg1 : index, %arg2 : index):
+      linalg.yield %pad_value : f32
+    } : tensor<3x4xf32> to tensor<6x9xf32>
+  return %0 : tensor<6x9xf32>
+}
+// CHECK-LABEL: func @pad_static
+//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+//       CHECK:   linalg.pad_tensor %[[ARG0]] low[1, 2] high[2, 3]
+//       CHECK:    : tensor<3x4xf32> to tensor<6x9xf32>
+
+// -----
+
+// Padding only at the high end, with dynamic amounts, makes every result
+// dimension dynamic even though the source shape is static.
+func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
+                       %pad_value: f32) -> tensor<?x?xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad_value : f32
+    } : tensor<2x3xf32> to tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @pad_asymmetrical
+//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+//  CHECK-SAME: %[[UB0:[a-zA-Z0-9_]*]]
+//  CHECK-SAME: %[[UB1:[a-zA-Z0-9_]*]]
+//       CHECK:   linalg.pad_tensor %[[ARG0]]
+//  CHECK-SAME:     low[0, 0]
+//  CHECK-SAME:     high[%[[UB0]], %[[UB1]]]
+//       CHECK:    : tensor<2x3xf32> to tensor<?x?xf32>
+
+// -----
+
 func @range(%arg0: index, %arg1: index, %arg2: index) {
   %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
   return


        


More information about the Mlir-commits mailing list