[Mlir-commits] [mlir] b686fdb - [mlir][Linalg] Drop output tensor from `linalg.pad_tensor` op.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 31 11:12:51 PDT 2021


Author: MaheshRavishankar
Date: 2021-08-31T11:12:24-07:00
New Revision: b686fdbf92ea5b495804afdf1c7c4d4aab30ef33

URL: https://github.com/llvm/llvm-project/commit/b686fdbf92ea5b495804afdf1c7c4d4aab30ef33
DIFF: https://github.com/llvm/llvm-project/commit/b686fdbf92ea5b495804afdf1c7c4d4aab30ef33.diff

LOG: [mlir][Linalg] Drop output tensor from `linalg.pad_tensor` op.

The output tensor was added for tiling purposes. Now that pad operations
are tiled through `TilingInterface`, there is no need for an explicit
operand that specifies the shape of the result of the `linalg.pad_tensor`
op. The interface lets the tiling pattern query the value to use as the
"init" tensor needed for tiling, computed dynamically.

Differential Revision: https://reviews.llvm.org/D108613

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 9b062f2ebd746..332993d8b0022 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -147,12 +147,11 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
            dimension, i.e `low`.
     * high: A list contains the padding along the end of each
            dimension, i.e. `high`.
-    * output: An optional output operand.
 
     The result tensor dimensions are `low` + `dim` + `high` along that
     dimension. The number of elements of `low` and `high` must match
-    the rank of the input tensor (which is also the rank of the output
-    tensor). They can be either a constant or a dynamic value.
+    the rank of the input tensor. They can be either a constant or a
+    dynamic value.
 
     The region of the `pad_tensor` operation returns the value to use
     for the padding. The arguments of the region represent the index
@@ -196,8 +195,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     Variadic<Index>:$low,
     Variadic<Index>:$high,
     I64ArrayAttr:$static_low,
-    I64ArrayAttr:$static_high,
-    Optional<AnyTensor>:$output);
+    I64ArrayAttr:$static_high);
 
   let regions = (region SizedRegion<1>:$region);
 
@@ -208,9 +206,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     $source
     `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
     `high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
-    (`into` $output^ )?
     $region attr-dict `:` type($source) `to` type($result)
-    custom<InferType>(ref($output), type($output), ref(type($result)))
   }];
 
   let extraClassDeclaration = [{
@@ -300,11 +296,6 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     OpBuilder<(ins "Type":$resultType, "Value":$source,
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
-    // Build a PadTensorOp with mixed static and dynamic entries and custom
-    // result type.
-    OpBuilder<(ins "Type":$resultType, "Value":$source,
-      "ArrayRef<Value>":$low, "ArrayRef<Value>":$high, "ArrayAttr":$staticLow,
-      "ArrayAttr":$staticHigh)>
   ];
 
   let hasCanonicalizer = 1;

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1e759baecf1f4..abb3c44158381 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1040,9 +1040,6 @@ static LogicalResult verify(PadTensorOp op) {
            << resultType << " does not match the inferred type "
            << expectedType;
   }
-  if (op.output() && op.output().getType() != op.getResultType()) {
-    op.emitError("expected that output operand type equals result type");
-  }
 
   auto &region = op.region();
   unsigned rank = resultType.getRank();
@@ -1089,7 +1086,7 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
   auto sourceType = source.getType().cast<RankedTensorType>();
   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
   build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
-        b.getI64ArrayAttr(staticHigh), /*output=*/Value());
+        b.getI64ArrayAttr(staticHigh));
   result.addAttributes(attrs);
 }
 
@@ -1126,15 +1123,7 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
         PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
   }
   build(b, result, resultType, source, dynamicLow, dynamicHigh,
-        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
-        /*output=*/Value());
-}
-
-void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
-                        Value source, ArrayRef<Value> low, ArrayRef<Value> high,
-                        ArrayAttr staticLow, ArrayAttr staticHigh) {
-  build(b, result, resultType, source, low, high, staticLow, staticHigh,
-        /*output=*/{});
+        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
 }
 
 PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
@@ -1221,7 +1210,8 @@ static Value getAsValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
 SmallVector<Value> PadTensorOp::getDestinationOperands(OpBuilder &b) {
   ReifiedRankedShapedTypeDims reifiedShapes;
   (void)reifyResultShapes(b, reifiedShapes);
-  Value initTensor = b.create<InitTensorOp>(getLoc(), reifiedShapes[0],
+  SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
+  Value initTensor = b.create<InitTensorOp>(getLoc(), mixedSizes,
                                             getResultType().getElementType());
   return {initTensor};
 }
@@ -1465,21 +1455,6 @@ struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
   }
 };
 
-// Fold tensor.dim(pad_tensor(%input, %output)) to tensor.dim(%output).
-struct FoldToDimOfOutputOperand : public OpRewritePattern<tensor::DimOp> {
-  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    auto padTensorOp = dimOp.source().getDefiningOp<PadTensorOp>();
-    if (!padTensorOp || !padTensorOp.output())
-      return failure();
-    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, padTensorOp.output(),
-                                               dimOp.index());
-    return success();
-  }
-};
-
 // Fold CastOp into PadTensorOp when adding static information.
 struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
@@ -1503,7 +1478,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
       auto newOp = rewriter.create<PadTensorOp>(
           padTensorOp->getLoc(), newResultType, padTensorOp.source(),
           padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
-          padTensorOp.static_high(), /*output=*/nullptr);
+          padTensorOp.static_high());
       BlockAndValueMapping mapper;
       padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
 
@@ -1517,8 +1492,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
 
 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldStaticZeroPadding, FoldToDimOfOutputOperand,
-              FoldSourceTensorCast>(context);
+  results.add<FoldStaticZeroPadding, FoldSourceTensorCast>(context);
 }
 
 /// Return the padding value of the PadTensorOp if it constant. In this context,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index dff35353405bb..acef26a281437 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -357,10 +357,6 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
 static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
                                      PadTensorOp &newPadOp, LoopNest &loopNest,
                                      const LinalgTilingOptions &options) {
-  // Can tile only PadTensorOp that have an output operand.
-  if (!op.output())
-    return failure();
-
   Location loc = op.getLoc();
   OpBuilder::InsertionGuard g(builder);
   builder.setInsertionPoint(op);
@@ -383,8 +379,9 @@ static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
     }
   }
   // Generate loop nest: One loop per dimension.
+  SmallVector<Value> destOperand = op.getDestinationOperands(builder);
   loopNest = mlir::scf::buildLoopNest(
-      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(op.output()),
+      builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
       [&](OpBuilder &b, Location loc, ValueRange localIvs,
           ValueRange iterArgs) -> scf::ValueVector {
         // Compute offsets and sizes of ExtractSliceOp.

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 41a4bfe9c9800..9b6be7920a999 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -904,24 +904,6 @@ func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
 
 // -----
 
-// CHECK-LABEL: func @dim_of_pad_tensor(
-//  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>
-//       CHECK:     %[[C0:.*]] = constant 0 : index
-//       CHECK:     %[[RESULT:.*]] = tensor.dim %[[ARG1]], %[[C0]]
-//       CHECK:     return %[[RESULT]]
-func @dim_of_pad_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
-                        %pad_value: f32) -> index {
-  %c0 = constant 0 : index
-  %0 = linalg.pad_tensor %arg0 low[2, 3] high[4, 5] into %arg1 {
-    ^bb0(%arg2: index, %arg3: index):
-      linalg.yield %pad_value : f32
-    } : tensor<?x?xf32> to tensor<?x?xf32>
-  %r = tensor.dim %0, %c0 : tensor<?x?xf32>
-  return %r : index
-}
-
-// -----
-
 // CHECK-LABEL: func @dim_of_tiled_loop_input(
 //  CHECK-SAME:     %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
 //       CHECK:   %[[c0:.*]] = constant 0 : index

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 36860272a80c8..3592d592acc3a 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -459,18 +459,6 @@ func @pad_result_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32) -> t
 
 // -----
 
-// expected-note at +1 {{prior use here}}
-func @pad_output_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32, %output: tensor<?x6x6x7xf32>) -> tensor<?x?x?x8xf32> {
-  // expected-error @+1 {{use of value '%output' expects different type than prior uses: 'tensor<?x5x6x7xf32>' vs 'tensor<?x6x6x7xf32>'}}
-  %0 = linalg.pad_tensor %arg0 low[1, 1, 1, 1] high[2, 2, 2, 2] into %output {
-  ^bb0(%arg3: index, %arg4: index):  // no predecessors
-    linalg.yield %arg2 : i32
-  } : tensor<?x2x3x4xi32> to tensor<?x5x6x7xf32>
-  return %0 : tensor<?x5x6x7xf32>
-}
-
-// -----
-
 func @pad_number_of_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
   // expected-error @+1 {{expected the block to have 2 arguments}}
   %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index e0d7ab2dfb24f..23e29e0ab082a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -49,24 +49,6 @@ func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> {
 
 // -----
 
-func @pad_static_with_output(%arg0: tensor<3x4xf32>,
-                             %out_tensor : tensor<6x9xf32>,
-                             %pad_value: f32)
-    -> tensor<6x9xf32> {
-  %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] into %out_tensor {
-    ^bb0(%arg1 : index, %arg2 : index):
-      linalg.yield %pad_value : f32
-    } : tensor<3x4xf32> to tensor<6x9xf32>
-  return %0 : tensor<6x9xf32>
-}
-// CHECK-LABEL: func @pad_static
-//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<3x4xf32>,
-//  CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: tensor<6x9xf32>,
-//       CHECK:   linalg.pad_tensor %[[ARG0]] low[1, 2] high[2, 3] into %[[ARG1]]
-//       CHECK:    : tensor<3x4xf32> to tensor<6x9xf32>
-
-// -----
-
 func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
                        %pad_value: f32) -> tensor<?x?xf32> {
   %0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {

diff  --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
index 67be544db78ad..e2a22fa104b22 100644
--- a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
+++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
@@ -1,12 +1,12 @@
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -cse -split-input-file | \
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
 // RUN: FileCheck %s -check-prefix=TILE2
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,3" -cse -split-input-file | \
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
 // RUN: FileCheck %s -check-prefix=TILE1
 
 //  TILE2-DAG:  #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)>
 //  TILE2-DAG:  #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 7)>
 //       TILE2: func @dynamic_pad_tensor(
-//  TILE2-SAME:     %[[IN:.*]]: tensor<?x?xf32>, %[[OUT:.*]]: tensor<?x?xf32>
+//  TILE2-SAME:     %[[IN:.*]]: tensor<?x?xf32>
 //   TILE2-DAG:   %[[C0:.*]] = constant 0 : index
 //   TILE2-DAG:   %[[C1:.*]] = constant 1 : index
 //   TILE2-DAG:   %[[C2:.*]] = constant 2 : index
@@ -25,16 +25,18 @@
 //       TILE2:       tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
 //       TILE2:   return %[[RESULT]]
 
-//   TILE1-DAG: #[[MAP:.*]] = affine_map<()[s0] -> (s0 + 7)>
+//   TILE1-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 7)>
+//   TILE1-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 8)>
 //       TILE1: func @dynamic_pad_tensor(
-//  TILE1-SAME:     %[[IN:.*]]: tensor<?x?xf32>, %[[OUT:.*]]: tensor<?x?xf32>
+//  TILE1-SAME:     %[[IN:.*]]: tensor<?x?xf32>
 //   TILE1-DAG:   %[[C0:.*]] = constant 0 : index
 //   TILE1-DAG:   %[[C1:.*]] = constant 1 : index
 //   TILE1-DAG:   %[[C3:.*]] = constant 3 : index
 //       TILE1:   %[[DIM_IN1:.*]] = tensor.dim %[[IN]], %[[C1]]
-//       TILE1:   %[[DIM1:.*]] = affine.apply #[[MAP]]()[%[[DIM_IN1]]]
+//       TILE1:   %[[DIM1:.*]] = affine.apply #[[MAP0]]()[%[[DIM_IN1]]]
+//       TILE1:   %[[DIM_IN0:.*]] = tensor.dim %[[IN]], %[[C0]]
+//       TILE1:   %[[DIM0:.*]] = affine.apply #[[MAP1]]()[%[[DIM_IN0]]]
 //       TILE1:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
-//       TILE1:     %[[DIM0:.*]] = tensor.dim %[[OUT]], %[[C0]]
 //       TILE1:     %[[SWAP_RESULT:.*]] = scf.if
 //       TILE1:       tensor.generate
 //       TILE1:     else
@@ -44,10 +46,8 @@
 //       TILE1:   return %[[RESULT]]
 
 func @dynamic_pad_tensor(%input_tensor: tensor<?x?xf32>,
-                         %output_tensor: tensor<?x?xf32>,
                          %pad_value: f32) -> tensor<?x?xf32> {
-  %0 = linalg.pad_tensor %input_tensor
-    low[3, 4] high[5, 3] into %output_tensor{
+  %0 = linalg.pad_tensor %input_tensor low[3, 4] high[5, 3] {
     ^bb0(%arg1: index, %arg2: index):
       linalg.yield %pad_value : f32
     } : tensor<?x?xf32> to tensor<?x?xf32>
@@ -57,7 +57,7 @@ func @dynamic_pad_tensor(%input_tensor: tensor<?x?xf32>,
 // -----
 
 // TILE2-LABEL: func @static_pad_tensor(
-//  TILE2-SAME:     %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<15x16xf32>
+//  TILE2-SAME:     %[[IN:.*]]: tensor<7x9xf32>
 //   TILE2-DAG:   %[[C0:.*]] = constant 0 : index
 //   TILE2-DAG:   %[[C2:.*]] = constant 2 : index
 //   TILE2-DAG:   %[[C3:.*]] = constant 3 : index
@@ -75,7 +75,7 @@ func @dynamic_pad_tensor(%input_tensor: tensor<?x?xf32>,
 
 
 // TILE1-LABEL: func @static_pad_tensor(
-//  TILE1-SAME:     %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<15x16xf32>
+//  TILE1-SAME:     %[[IN:.*]]: tensor<7x9xf32>
 //   TILE1-DAG:   %[[C0:.*]] = constant 0 : index
 //   TILE1-DAG:   %[[C3:.*]] = constant 3 : index
 //   TILE1-DAG:   %[[C16:.*]] = constant 16 : index
@@ -89,10 +89,8 @@ func @dynamic_pad_tensor(%input_tensor: tensor<?x?xf32>,
 //       TILE1:   return %[[RESULT]]
 
 func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
-                        %output_tensor: tensor<15x16xf32>,
                         %pad_value: f32) -> tensor<15x16xf32> {
-  %0 = linalg.pad_tensor %input_tensor
-    low[3, 4] high[5, 3] into %output_tensor {
+  %0 = linalg.pad_tensor %input_tensor low[3, 4] high[5, 3] {
     ^bb0(%arg1: index, %arg2: index):
       linalg.yield %pad_value : f32
     } : tensor<7x9xf32> to tensor<15x16xf32>
@@ -112,7 +110,7 @@ func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
 //       TILE1:       scf.yield %[[GEN]] : tensor<14x3xf32>
 //       TILE1:     else
 //       TILE1:       %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
-//       TILE1:       %[[PAD:.*]] = linalg.pad_tensor %8 low[0, 0] high[7, %{{.*}}]
+//       TILE1:       %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[0, 0] high[7, %{{.*}}]
 //       TILE1:       %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
 //       TILE1:       scf.yield %[[CAST]] : tensor<14x3xf32>
 //       TILE1:     %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
@@ -121,8 +119,7 @@ func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
 func @static_pad_tile_evenly(%input_tensor: tensor<7x9xf32>,
                              %output_tensor: tensor<14x15xf32>,
                              %pad_value: f32) -> tensor<14x15xf32> {
-  %0 = linalg.pad_tensor %input_tensor
-    low[0, 0] high[7, 6] into %output_tensor {
+  %0 = linalg.pad_tensor %input_tensor low[0, 0] high[7, 6] {
     ^bb0(%arg1: index, %arg2: index):
       linalg.yield %pad_value : f32
     } : tensor<7x9xf32> to tensor<14x15xf32>


        


More information about the Mlir-commits mailing list