[Mlir-commits] [mlir] cf9503c - [mlir] Add subtensor_insert operation

Nicolas Vasilache llvmlistbot at llvm.org
Fri Oct 2 03:34:50 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-02T06:32:31-04:00
New Revision: cf9503c1b752062d9abfb2c7922a50574d9c5de4

URL: https://github.com/llvm/llvm-project/commit/cf9503c1b752062d9abfb2c7922a50574d9c5de4
DIFF: https://github.com/llvm/llvm-project/commit/cf9503c1b752062d9abfb2c7922a50574d9c5de4.diff

LOG: [mlir] Add subtensor_insert operation

Differential revision: https://reviews.llvm.org/D88657
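
For a quick sense of the new op, here is a minimal example in MLIR assembly
(mirroring the test added below in core-ops.mlir; the value names are
illustrative):

    func @example(%t: tensor<8x16x4xf32>, %t2: tensor<16x32x8xf32>, %idx: index) {
      %c0 = constant 0 : index
      %c1 = constant 1 : index
      // Insert %t into %t2 at offsets [0, 0, 0], with dynamic sizes and unit
      // strides; the result is a copy of %t2 with the subtensor updated.
      %r = subtensor_insert %t into %t2[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
        : tensor<8x16x4xf32> into tensor<16x32x8xf32>
      return
    }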

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/IR/core-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 3d9daee964b6..c62be7571aad 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2922,15 +2922,20 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
 
     The SubView operation supports the following arguments:
 
-    * Memref: the "base" memref on which to create a "view" memref.
-    * Offsets: memref-rank number of dynamic offsets or static integer
-               attributes into the "base" memref at which to create the "view"
-               memref.
-    * Sizes: memref-rank number of dynamic sizes or static integer attributes
-             which specify the sizes of the result "view" memref type.
-    * Strides: memref-rank number of dynamic strides or static integer
-               attributes that compose multiplicatively with  the base memref
-               strides in each dimension.
+    * memref: the "base" memref on which to create a "view" memref.
+    * offsets: memref-rank number of offsets into the "base" memref at which to
+               create the "view" memref.
+    * sizes: memref-rank number of sizes which specify the sizes of the result
+             "view" memref type.
+    * strides: memref-rank number of strides that compose multiplicatively with
+               the base memref strides in each dimension.
+
+    The representation based on offsets, sizes and strides supports a
+    partially-static specification via attributes specified through the
+    `static_offsets`, `static_sizes` and `static_strides` arguments. The
+    special sentinel values ShapedType::kDynamicSize and
+    ShapedType::kDynamicStrideOrOffset encode that the corresponding entry has
+    a dynamic value.
 
     A subview operation may additionally reduce the rank of the resulting view
     by removing dimensions that are statically known to be of size 1.
@@ -3076,7 +3081,7 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
     /// Returns the type of the base memref operand.
-    MemRefType getSourceMemRefType() {
+    MemRefType getSourceType() {
       return source().getType().cast<MemRefType>();
     }
 
@@ -3108,13 +3113,19 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
     The subtensor operation supports the following arguments:
 
     * tensor: the "base" tensor from which to extract a subtensor.
-    * offsets: tensor-rank number of dynamic offsets or static integer
-               attributes into the "base" tensor from which to extract the
-               subtensor.
-    * sizes: tensor-rank number of dynamic sizes or static integer attributes
-             which specify the sizes of the result tensor type.
-    * strides: tensor-rank number of dynamic strides or static integer
-               attributes specifying susampling in each dimension.
+    * offsets: tensor-rank number of offsets into the "base" tensor from which
+               to extract the subtensor.
+    * sizes: tensor-rank number of sizes which specify the sizes of the result
+             tensor type.
+    * strides: tensor-rank number of strides specifying subsampling in each 
+               dimension.
+
+    The representation based on offsets, sizes and strides supports a
+    partially-static specification via attributes specified through the
+    `static_offsets`, `static_sizes` and `static_strides` arguments. The
+    special sentinel values ShapedType::kDynamicSize and
+    ShapedType::kDynamicStrideOrOffset encode that the corresponding entry has
+    a dynamic value.
 
     After buffer-allocation, the "subtensor" op is expected to lower into a
     "subview" op.
@@ -3144,9 +3155,22 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
   );
   let results = (outs AnyRankedTensor:$result);
 
+  let builders = [
+    // Build a SubTensorOp with mixed static and dynamic entries.
+    OpBuilder<
+      "Value source, ArrayRef<int64_t> staticOffsets, "
+      "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+      "ValueRange offsets, ValueRange sizes, ValueRange strides, "
+      "ArrayRef<NamedAttribute> attrs = {}">,
+    // Build a SubTensorOp with all dynamic entries.
+    OpBuilder<
+      "Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, "
+      "ArrayRef<NamedAttribute> attrs = {}">
+  ];
+
   let extraClassDeclaration = extraBaseClassDeclaration # [{
     /// Returns the type of the base tensor operand.
-    RankedTensorType getSourceRankedTensorType() {
+    RankedTensorType getSourceType() {
       return source().getType().cast<RankedTensorType>();
     }
 
@@ -3167,6 +3191,80 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> {
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SubTensorInsertOp
+//===----------------------------------------------------------------------===//
+
+def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<"subtensor_insert"> {
+  let summary = "subtensor_insert operation";
+  let description = [{
+    The "subtensor_insert" operation insert a tensor `source` into another
+    tensor `dest` as specified by the operation's offsets, sizes and strides
+    arguments.
+
+    It returns a copy of `dest` with the proper subtensor updated with the value
+    of `source`.
+
+    The subtensor_insert operation encodes the following information:
+
+    * source: the tensor that is inserted.
+    * dest: the tensor into which the source tensor is inserted.
+    * offsets: tensor-rank number of offsets into the `dest` tensor into which
+               the source tensor is inserted.
+    * sizes: tensor-rank number of sizes which specify the sizes of the source
+             tensor type.
+    * strides: tensor-rank number of strides that specify subsampling in each
+               dimension.
+
+    The representation based on offsets, sizes and strides supports a
+    partially-static specification via attributes specified through the
+    `static_offsets`, `static_sizes` and `static_strides` arguments. The
+    special sentinel values ShapedType::kDynamicSize and
+    ShapedType::kDynamicStrideOrOffset encode that the corresponding entry has
+    a dynamic value.
+
+    After buffer-allocation, the "subtensor_insert" op is expected to become
+    an in-place buffer update.
+  }];
+
+  let arguments = (ins
+    AnyRankedTensor:$source,
+    AnyRankedTensor:$dest,
+    Variadic<Index>:$offsets,
+    Variadic<Index>:$sizes,
+    Variadic<Index>:$strides,
+    I64ArrayAttr:$static_offsets,
+    I64ArrayAttr:$static_sizes,
+    I64ArrayAttr:$static_strides
+  );
+  let results = (outs AnyRankedTensor:$result);
+
+  let builders = [
+    // Build a SubTensorInsertOp with mixed static and dynamic entries.
+    OpBuilder<
+      "Value source, Value dest, ArrayRef<int64_t> staticOffsets, "
+      "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+      "ValueRange offsets, ValueRange sizes, ValueRange strides, "
+      "ArrayRef<NamedAttribute> attrs = {}">,
+    // Build a SubTensorInsertOp with all dynamic entries.
+    OpBuilder<
+      "Value source, Value dest, ValueRange offsets, ValueRange sizes, "
+      "ValueRange strides, ArrayRef<NamedAttribute> attrs = {}">
+  ];
+
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    /// Returns the type of the base tensor operand.
+    RankedTensorType getSourceType() {
+      return source().getType().cast<RankedTensorType>();
+    }
+
+    /// The result of a subtensor_insert is always a tensor.
+    RankedTensorType getType() {
+      return getResult().getType().cast<RankedTensorType>();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // TanhOp
 //===----------------------------------------------------------------------===//

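As a side note before the implementation below: a sketch of how the new
all-dynamic SubTensorInsertOp builder might be used from C++. The helper name
and surrounding setup are illustrative, not part of this patch:

    #include "mlir/Dialect/StandardOps/IR/Ops.h"
    using namespace mlir;

    // Inserts `source` into `dest`; each ValueRange must hold rank-many
    // index-typed values. The builder fills static_offsets/static_sizes/
    // static_strides with sentinels marking every entry as dynamic.
    static Value buildAllDynamicInsert(OpBuilder &b, Location loc, Value source,
                                       Value dest, ValueRange offsets,
                                       ValueRange sizes, ValueRange strides) {
      return b.create<SubTensorInsertOp>(loc, source, dest, offsets, sizes,
                                         strides)
          .getResult();
    }
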
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 5548274eee18..7f4e2ffa5262 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
@@ -2639,10 +2640,15 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
 ///     `:` strided-memref-type `to` strided-memref-type
 /// ```
 template <typename OpType>
-static void printOpWithOffsetsSizesAndStrides(OpAsmPrinter &p, OpType op) {
+static void printOpWithOffsetsSizesAndStrides(
+    OpAsmPrinter &p, OpType op,
+    llvm::function_ref<void(OpAsmPrinter &p, OpType op)> printExtraOperands =
+        [](OpAsmPrinter &p, OpType op) {},
+    StringLiteral resultTypeKeyword = "to") {
   int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
   p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' ';
-  p << op.getOperand(0);
+  p << op.source();
+  printExtraOperands(p, op);
   printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(),
                                        ShapedType::isDynamicStrideOrOffset);
   printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(),
@@ -2651,27 +2657,35 @@ static void printOpWithOffsetsSizesAndStrides(OpAsmPrinter &p, OpType op) {
                                        ShapedType::isDynamicStrideOrOffset);
   p.printOptionalAttrDict(op.getAttrs(),
                           /*elidedAttrs=*/{OpType::getSpecialAttrNames()});
-  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
+  p << " : " << op.getSourceType() << " " << resultTypeKeyword << " "
+    << op.getType();
 }
 
 static void print(OpAsmPrinter &p, SubViewOp op) {
   return printOpWithOffsetsSizesAndStrides<SubViewOp>(p, op);
 }
 
-/// Parse SubViewOp of the form:
+/// Parse an op of the form:
 /// ```
-///   `name` ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
-///     `:` strided-memref-type `to` strided-memref-type
+///   `name` ssa-name (extra-operands)?
+///       `[` offset-list `]` `[` size-list `]` `[` stride-list `]`
+///     `:` strided-memref-type `resultTypeKeyword` strided-memref-type
 /// ```
 template <typename OpType>
-static ParseResult parseOpWithOffsetsSizesAndStrides(OpAsmParser &parser,
-                                                     OperationState &result) {
-  OpAsmParser::OperandType srcInfo;
+static ParseResult parseOpWithOffsetsSizesAndStrides(
+    OpAsmParser &parser, OperationState &result,
+    std::function<ParseResult(OpAsmParser &p,
+                              OpAsmParser::OperandType &dstInfo)>
+        parseExtraOperand = nullptr,
+    StringLiteral resultTypeKeyword = "to") {
+  OpAsmParser::OperandType srcInfo, dstInfo;
   SmallVector<OpAsmParser::OperandType, 4> offsetsInfo, sizesInfo, stridesInfo;
   auto indexType = parser.getBuilder().getIndexType();
   Type srcType, dstType;
   if (parser.parseOperand(srcInfo))
     return failure();
+  if (parseExtraOperand && parseExtraOperand(parser, dstInfo))
+    return failure();
   if (parseListOfOperandsOrIntegers(
           parser, result, OpType::getStaticOffsetsAttrName(),
           ShapedType::kDynamicStrideOrOffset, offsetsInfo) ||
@@ -2683,21 +2697,27 @@ static ParseResult parseOpWithOffsetsSizesAndStrides(OpAsmParser &parser,
           ShapedType::kDynamicStrideOrOffset, stridesInfo))
     return failure();
 
+  // Handle segment sizes.
   auto b = parser.getBuilder();
-  SmallVector<int, 4> segmentSizes{1, static_cast<int>(offsetsInfo.size()),
-                                   static_cast<int>(sizesInfo.size()),
-                                   static_cast<int>(stridesInfo.size())};
+  SmallVector<int, 4> segmentSizes = {1, static_cast<int>(offsetsInfo.size()),
+                                      static_cast<int>(sizesInfo.size()),
+                                      static_cast<int>(stridesInfo.size())};
+  // If we parsed an extra operand, it needs to appear in the segment sizes.
+  if (parseExtraOperand)
+    segmentSizes.insert(segmentSizes.begin(), 1);
   result.addAttribute(OpType::getOperandSegmentSizeAttr(),
                       b.getI32VectorAttr(segmentSizes));
 
   return failure(
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonType(srcType) ||
+      parser.parseKeywordType(resultTypeKeyword.str().c_str(), dstType) ||
       parser.resolveOperand(srcInfo, srcType, result.operands) ||
+      (parseExtraOperand &&
+       parser.resolveOperand(dstInfo, dstType, result.operands)) ||
       parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
       parser.resolveOperands(sizesInfo, indexType, result.operands) ||
       parser.resolveOperands(stridesInfo, indexType, result.operands) ||
-      parser.parseKeywordType("to", dstType) ||
       parser.addTypeToList(dstType, result.types));
 }
 
@@ -2894,7 +2914,7 @@ static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) {
 
 /// Verifier for SubViewOp.
 static LogicalResult verify(SubViewOp op) {
-  MemRefType baseType = op.getSourceMemRefType();
+  MemRefType baseType = op.getSourceType();
   MemRefType subViewType = op.getType();
 
   // The base memref and the view memref should be in the same memory space.
@@ -3273,8 +3293,7 @@ static LogicalResult verify(SubTensorOp op) {
 
   // Verify result type against inferred type.
   auto expectedType = SubTensorOp::inferResultType(
-      op.getSourceRankedTensorType(),
-      extractFromI64ArrayAttr(op.static_offsets()),
+      op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
   if (!isRankReducedType(expectedType, op.getType()))
@@ -3291,6 +3310,72 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
           context);
 }
 
+//===----------------------------------------------------------------------===//
+// SubTensorInsertOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, SubTensorInsertOp op) {
+  return printOpWithOffsetsSizesAndStrides<SubTensorInsertOp>(
+      p, op,
+      [](OpAsmPrinter &p, SubTensorInsertOp op) { p << " into " << op.dest(); },
+      /*resultTypeKeyword=*/"into");
+}
+
+static ParseResult parseSubTensorInsertOp(OpAsmParser &parser,
+                                          OperationState &result) {
+  return parseOpWithOffsetsSizesAndStrides<SubTensorInsertOp>(
+      parser, result,
+      [](OpAsmParser &parser, OpAsmParser::OperandType &dstInfo) {
+        return failure(parser.parseKeyword("into") ||
+                       parser.parseOperand(dstInfo));
+      },
+      "into");
+}
+
+void mlir::SubTensorInsertOp::build(
+    OpBuilder &b, OperationState &result, Value source, Value dest,
+    ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
+    ArrayRef<int64_t> staticStrides, ValueRange offsets, ValueRange sizes,
+    ValueRange strides, ArrayRef<NamedAttribute> attrs) {
+  build(b, result, dest.getType(), source, dest, offsets, sizes, strides,
+        b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
+        b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+/// Build a SubTensorInsertOp with all dynamic entries: `staticOffsets`,
+/// `staticSizes` and `staticStrides` are automatically filled with
+/// source-tensor-rank sentinel values that encode dynamic entries.
+void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
+                                    Value source, Value dest,
+                                    ValueRange offsets, ValueRange sizes,
+                                    ValueRange strides,
+                                    ArrayRef<NamedAttribute> attrs) {
+  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
+  unsigned rank = sourceRankedTensorType.getRank();
+  SmallVector<int64_t, 4> staticOffsetsVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> staticStridesVector(
+      rank, ShapedType::kDynamicStrideOrOffset);
+  build(b, result, source, dest, staticOffsetsVector, staticSizesVector,
+        staticStridesVector, offsets, sizes, strides, attrs);
+}
+
+SmallVector<Range, 8> SubTensorInsertOp::getOrCreateRanges(OpBuilder &b,
+                                                           Location loc) {
+  return ::getOrCreateRangesImpl(*this, b, loc);
+}
+
+/// Verifier for SubTensorInsertOp.
+static LogicalResult verify(SubTensorInsertOp op) {
+  if (failed(verifyOpWithOffsetSizesAndStrides(op)))
+    return failure();
+  if (op.getType() != op.dest().getType())
+    return op.emitError("expected result type to be ") << op.dest().getType();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TensorCastOp
 //===----------------------------------------------------------------------===//

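One detail of the parser change above worth spelling out: the extra `dest`
operand is prepended to the operand segment sizes. For the mixed test case
added below (`%2 = subtensor_insert %t into %t2[%c0, %idx, %c0][%idx, 4,
%idx][%c1, 1, %c1] ...`), the recorded attribute would plausibly be:

    operand_segment_sizes = dense<[1, 1, 3, 2, 2]> : vector<5xi32>

i.e. one source, one dest, three dynamic offsets, two dynamic sizes and two
dynamic strides; the literal `4` and `1` land in the `static_sizes` and
`static_strides` attributes instead.
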
diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 72a063ff9d51..2590dc0105c4 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -901,7 +901,6 @@ func @assume_alignment(%0: memref<4x4xf16>) {
   return
 }
 
-
 // CHECK-LABEL: func @subtensor({{.*}}) {
 func @subtensor(%t: tensor<8x16x4xf32>, %idx : index) {
   %c0 = constant 0 : index
@@ -924,3 +923,21 @@ func @subtensor(%t: tensor<8x16x4xf32>, %idx : index) {
 
   return
 }
+
+// CHECK-LABEL: func @subtensor_insert({{.*}}) {
+func @subtensor_insert(%t: tensor<8x16x4xf32>, %t2: tensor<16x32x8xf32>, %idx : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // CHECK: subtensor_insert
+  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
+  %1 = subtensor_insert %t into %t2[%c0, %c0, %c0][%idx, %idx, %idx][%c1, %c1, %c1]
+    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
+
+  // CHECK: subtensor_insert
+  // CHECK-SAME: tensor<8x16x4xf32> into tensor<16x32x8xf32>
+  %2 = subtensor_insert %t into %t2[%c0, %idx, %c0][%idx, 4, %idx][%c1, 1, %c1]
+    : tensor<8x16x4xf32> into tensor<16x32x8xf32>
+
+  return
+}


        

