[Mlir-commits] [mlir] 5da010a - [mlir][linalg] Add optional output operand to PadTensorOp

Matthias Springer llvmlistbot at llvm.org
Wed Jul 14 18:27:02 PDT 2021


Author: Matthias Springer
Date: 2021-07-15T10:20:36+09:00
New Revision: 5da010af9a058a70fc301b9c02d4ff370ab2f9a7

URL: https://github.com/llvm/llvm-project/commit/5da010af9a058a70fc301b9c02d4ff370ab2f9a7
DIFF: https://github.com/llvm/llvm-project/commit/5da010af9a058a70fc301b9c02d4ff370ab2f9a7.diff

LOG: [mlir][linalg] Add optional output operand to PadTensorOp

This optional operand will be used for tiling in a subsequent commit.

Differential Revision: https://reviews.llvm.org/D105459

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index ffd65f7138efc..cb0507d6f903b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -146,6 +146,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
            dimension, i.e `low`.
     * high: A list contains the padding along the end of each
            dimension, i.e. `high`.
+    * output: An optional output operand.
 
     The result tensor dimensions are `low` + `dim` + `high` along that
     dimension. The number of elements of `low` and `high` must match
@@ -194,16 +195,21 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     Variadic<Index>:$low,
     Variadic<Index>:$high,
     I64ArrayAttr:$static_low,
-    I64ArrayAttr:$static_high);
+    I64ArrayAttr:$static_high,
+    Optional<AnyTensor>:$output);
 
   let regions = (region SizedRegion<1>:$region);
 
   let results = (outs AnyTensor:$result);
 
+  // TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands.
   let assemblyFormat = [{
-    $source `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
+    $source
+    `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
     `high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
+    (`into` $output^ )?
     $region attr-dict `:` type($source) `to` type($result)
+    custom<InferType>(ref($output), type($output), ref(type($result)))
   }];
 
   let extraClassDeclaration = [{
@@ -292,7 +298,12 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     // result type. If the type passed is nullptr, it is inferred.
     OpBuilder<(ins "Type":$resultType, "Value":$source,
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build a PadTensorOp with mixed static and dynamic entries and custom
+    // result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$source,
+      "ArrayRef<Value>":$low, "ArrayRef<Value>":$high, "ArrayAttr":$staticLow,
+      "ArrayAttr":$staticHigh)>
   ];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4ef942d547765..32efdc2fe19ee 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -855,6 +855,27 @@ LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
 // PadTensorOp
 //===----------------------------------------------------------------------===//
 
+// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
+// supports optional types.
+
+/// Printer for the custom<InferType> directive. Intentionally prints nothing:
+/// the type of the optional operand is not part of the textual form and is
+/// recovered from the result type by parseInferType.
+static void printInferType(OpAsmPrinter &printer, Operation *op,
+                           Value optOperand, Type typeToInfer,
+                           Type typeToInferFrom) {}
+
+/// Parser for the custom<InferType> directive. If the optional operand was
+/// parsed, `typeToInfer` is set to `typeToInferFrom`; otherwise it is left
+/// unchanged. Always succeeds.
+static ParseResult parseInferType(OpAsmParser &parser,
+                                  Optional<OpAsmParser::OperandType> optOperand,
+                                  Type &typeToInfer, Type typeToInferFrom) {
+  if (optOperand)
+    typeToInfer = typeToInferFrom;
+  return success();
+}
+
 static LogicalResult verify(PadTensorOp op) {
   auto sourceType = op.source().getType().cast<RankedTensorType>();
   auto resultType = op.result().getType().cast<RankedTensorType>();
@@ -870,6 +883,10 @@ static LogicalResult verify(PadTensorOp op) {
            << resultType << " does not match the inferred type "
            << expectedType;
   }
+  // Return the diagnostic so that a type mismatch actually fails verification.
+  if (op.output() && op.output().getType() != op.getResultType())
+    return op.emitError(
+        "expected that output operand type equals result type");
 
   auto &region = op.region();
   unsigned rank = resultType.getRank();
@@ -916,7 +932,7 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
   auto sourceType = source.getType().cast<RankedTensorType>();
   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
   build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
-        b.getI64ArrayAttr(staticHigh));
+        b.getI64ArrayAttr(staticHigh), /*output=*/Value());
   result.addAttributes(attrs);
 }
 
@@ -953,7 +969,17 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
         PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
   }
   build(b, result, resultType, source, dynamicLow, dynamicHigh,
-        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
+        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
+        /*output=*/Value());
+}
+
+/// Builds a PadTensorOp without an `output` operand from dynamic (`low`,
+/// `high`) and static (`staticLow`, `staticHigh`) padding entries.
+void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                        Value source, ArrayRef<Value> low, ArrayRef<Value> high,
+                        ArrayAttr staticLow, ArrayAttr staticHigh) {
+  build(b, result, resultType, source, low, high, staticLow, staticHigh,
+        /*output=*/Value());
 }
 
 PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
@@ -1038,11 +1062,30 @@ struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
     }
   };
 
+// Folds tensor.dim(pad_tensor(%input) into %output) to tensor.dim(%output).
+// NOTE(review): this relies on the padded result always having the same
+// runtime shape as %output; the verifier only checks that the types match.
+struct FoldToDimOfOutputOperand : public OpRewritePattern<tensor::DimOp> {
+  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out unless the dim is taken of a pad_tensor with an output operand.
+    auto padOp = dimOp.source().getDefiningOp<PadTensorOp>();
+    Value output = padOp ? padOp.output() : Value();
+    if (!output)
+      return failure();
+    // Ask the output operand for the same dimension index instead.
+    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, output, dimOp.index());
+    return success();
+  }
+};
 } // namespace
 
 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<FoldStaticZeroPadding>(context);
+  results.add<FoldToDimOfOutputOperand>(context);
 }
 
 /// Return the padding value of the PadTensorOp if it constant. In this context,

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 5f7ad8ddfde28..03e19909fa284 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -902,3 +902,23 @@ func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
   %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
   return %r: tensor<2xf32>
 }
+
+// -----
+
+// tensor.dim of a pad_tensor result that has an `into` operand is expected to
+// fold to tensor.dim of that operand.
+// CHECK-LABEL: func @dim_of_pad_tensor(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]*]]: tensor<?x?xf32>, %[[ARG1:[a-zA-Z0-9_]*]]: tensor<?x?xf32>
+//       CHECK:     %[[C0:.*]] = constant 0 : index
+//       CHECK:     %[[RESULT:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+//       CHECK:     return %[[RESULT]]
+func @dim_of_pad_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                        %pad_value: f32) -> index {
+  %c0 = constant 0 : index
+  %0 = linalg.pad_tensor %arg0 low[2, 3] high[4, 5] into %arg1 {
+    ^bb0(%arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  %r = tensor.dim %0, %c0 : tensor<?x?xf32>
+  return %r : index
+}

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 6d8536a730d7a..1de4747088037 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -584,6 +584,20 @@ func @pad_result_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32) -> t
 
 // -----
 
+// The type of the `into` operand is inferred from the result type, so a
+// mismatching operand is rejected by the parser.
+// expected-note@+1 {{prior use here}}
+func @pad_output_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32, %output: tensor<?x6x6x7xf32>) -> tensor<?x?x?x8xf32> {
+  // expected-error @+1 {{use of value '%output' expects different type than prior uses: 'tensor<?x5x6x7xf32>' vs 'tensor<?x6x6x7xf32>'}}
+  %0 = linalg.pad_tensor %arg0 low[1, 1, 1, 1] high[2, 2, 2, 2] into %output {
+  ^bb0(%arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %arg2 : i32
+  } : tensor<?x2x3x4xi32> to tensor<?x5x6x7xf32>
+  return %0 : tensor<?x5x6x7xf32>
+}
+
+// -----
+
 func @pad_number_of_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
   // expected-error @+1 {{expected the block to have 2 arguments}}
   %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index d4f071a74e508..5d842bba59602 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -51,6 +51,24 @@ func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> {
 
 // -----
 
+func @pad_static_with_output(%arg0: tensor<3x4xf32>,
+                             %out_tensor : tensor<6x9xf32>,
+                             %pad_value: f32)
+    -> tensor<6x9xf32> {
+  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] into %out_tensor {
+    ^bb0(%arg1 : index, %arg2 : index):
+      linalg.yield %pad_value : f32
+    } : tensor<3x4xf32> to tensor<6x9xf32>
+  return %0 : tensor<6x9xf32>
+}
+// CHECK-LABEL: func @pad_static_with_output
+//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<3x4xf32>,
+//  CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: tensor<6x9xf32>,
+//       CHECK:   linalg.pad_tensor %[[ARG0]] low[1, 2] high[2, 3] into %[[ARG1]]
+//       CHECK:    : tensor<3x4xf32> to tensor<6x9xf32>
+
+// -----
+
 func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
                        %pad_value: f32) -> tensor<?x?xf32> {
   %0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {


        


More information about the Mlir-commits mailing list