[Mlir-commits] [mlir] 01d696e - [mlir] rename the "packing" flag of linalg.pad_tensor to "nofold"
    Alex Zinenko 
    llvmlistbot at llvm.org
       
    Mon Oct  4 12:28:16 PDT 2021
    
    
  
Author: Alex Zinenko
Date: 2021-10-04T21:28:11+02:00
New Revision: 01d696e563545d54852013416dc570b26a085fd7
URL: https://github.com/llvm/llvm-project/commit/01d696e563545d54852013416dc570b26a085fd7
DIFF: https://github.com/llvm/llvm-project/commit/01d696e563545d54852013416dc570b26a085fd7.diff
LOG: [mlir] rename the "packing" flag of linalg.pad_tensor to "nofold"
The discussion in https://reviews.llvm.org/D110425 demonstrated that "packing"
may be a confusing term for describing the behavior of this op in the presence
of the attribute. Instead, name the attribute after its intended effect:
preventing the folder from being applied.
Reviewed By: nicolasvasilache, silvas
Differential Revision: https://reviews.llvm.org/D111046
Added: 
    
Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 6608c7ac7773d..bb61d3f088f91 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -146,9 +146,9 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     * low: A list contains the padding along the start of each
            dimension, i.e `low`.
     * high: A list contains the padding along the end of each
-           dimension, i.e. `high`.
-    * packing: whether the padding operation is guaranteed to create a new
-           tensor suitable for packing, i.e. a copy.
+            dimension, i.e. `high`.
+    * nofold: indicates that the operation should not be folded when source and
+              result types are equal.
 
     The result tensor dimensions are `low` + `dim` + `high` along that
     dimension. The number of elements of `low` and `high` must match
@@ -161,10 +161,9 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     the rank of the `source` tensor. The value `yield`-ed by the
     region is used as the value of the view at the given position.
 
-    If `packing` is indicated, the padding is guaranteed to produce a new
-    tensor, e.g., to use for packing or promotion to faster memory. Such
-    operations are not optimized away even when the source type has the same
-    static shape.
+    If `nofold` is set, the padding operation will not be folded away even
+    if the source type and the padded type have the same static shape. This can
+    be used, e.g., for packing or promotion to faster memory.
 
     Example 1:
 
@@ -199,9 +198,9 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     Example 4:
 
     ```mlir
-      // Force a padded value to be always exist with `packing`.
+      // Force a padded value to be always exist with `nofold`.
       %pad_value = ... : f32
-      %0 = linalg.pad_tensor %arg0 packing low[0, 0] high[0, 0] {
+      %0 = linalg.pad_tensor %arg0 nofold low[0, 0] high[0, 0] {
       ^bb0(%arg1: index, %arg2: index):
         linalg.yield %pad_value : f32
       } : tensor<2x3xf32> to tensor<2x3xf32>
@@ -214,7 +213,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     Variadic<Index>:$high,
     I64ArrayAttr:$static_low,
     I64ArrayAttr:$static_high,
-    UnitAttr:$packing);
+    UnitAttr:$nofold);
 
   let regions = (region SizedRegion<1>:$region);
 
@@ -223,7 +222,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
   // TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands.
   let assemblyFormat = [{
     $source
-    (`packing` $packing^)?
+    (`nofold` $nofold^)?
     `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
     `high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
     $region attr-dict `:` type($source) `to` type($result)
@@ -260,7 +259,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     // "high" padding (i.e. it adds trailing padding values until the desired
     // size is met).
     static linalg::PadTensorOp createPadHighOp(
-        Type type, Value source, Value pad, bool packing, Location loc,
+        Type type, Value source, Value pad, bool nofold, Location loc,
         OpBuilder & builder);
 
     // Return a PadTensorOp that pads `source to `type` size with `pad` value.
@@ -268,7 +267,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     // directly. If the type passed is nullptr, it is inferred.
     static linalg::PadTensorOp createPadScalarOp(
         Type type, Value source, Value pad, ArrayRef<OpFoldResult> low,
-        ArrayRef<OpFoldResult> high, bool packing, Location loc,
+        ArrayRef<OpFoldResult> high, bool nofold, Location loc,
         OpBuilder & builder);
 
     // Return the pad value if it is a constant. Return null value otherwise.
@@ -313,17 +312,17 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     // Build a PadTensorOp with mixed static and dynamic entries.
     OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
       "ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
-      CArg<"bool", "false">:$packing,
+      CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadTensorOp with all dynamic entries.
     OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
-      CArg<"bool", "false">:$packing,
+      CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadTensorOp with mixed static and dynamic entries and custom
     // result type. If the type passed is nullptr, it is inferred.
     OpBuilder<(ins "Type":$resultType, "Value":$source,
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
-      CArg<"bool", "false">:$packing,
+      CArg<"bool", "false">:$nofold,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
   ];
 
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8ee0c4334ee8d..1d4132fddbfc7 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -87,7 +87,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
 
   return linalg::PadTensorOp::createPadScalarOp(
              RankedTensorType::get(paddedShape, inputETy), input, padValue,
-             lowIndices, highIndices, /*packing=*/false, loc, rewriter)
+             lowIndices, highIndices, /*nofold=*/false, loc, rewriter)
       .result();
 }
 
@@ -2349,7 +2349,7 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
 
     auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
         padOp.getType(), input, constant, lowValues, highValues,
-        /*packing=*/false, loc, rewriter);
+        /*nofold=*/false, loc, rewriter);
 
     rewriter.replaceOp(padOp, newPadOp.getResult());
     return success();
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 4a05c577b3d3c..af292878d1f6e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1085,28 +1085,28 @@ RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
                         ArrayRef<int64_t> staticLow,
                         ArrayRef<int64_t> staticHigh, ValueRange low,
-                        ValueRange high, bool packing,
+                        ValueRange high, bool nofold,
                         ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
   build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
-        b.getI64ArrayAttr(staticHigh), packing ? b.getUnitAttr() : UnitAttr());
+        b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
   result.addAttributes(attrs);
 }
 
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
-                        ValueRange low, ValueRange high, bool packing,
+                        ValueRange low, ValueRange high, bool nofold,
                         ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
   unsigned rank = sourceType.getRank();
   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
-  build(b, result, source, staticVector, staticVector, low, high, packing,
+  build(b, result, source, staticVector, staticVector, low, high, nofold,
         attrs);
 }
 
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
                         Value source, ArrayRef<OpFoldResult> low,
-                        ArrayRef<OpFoldResult> high, bool packing,
+                        ArrayRef<OpFoldResult> high, bool nofold,
                         ArrayRef<NamedAttribute> attrs) {
   assert(resultType.isa<RankedTensorType>());
   auto sourceType = source.getType().cast<RankedTensorType>();
@@ -1129,17 +1129,17 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
   }
   build(b, result, resultType, source, dynamicLow, dynamicHigh,
         b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
-        packing ? b.getUnitAttr() : UnitAttr());
+        nofold ? b.getUnitAttr() : UnitAttr());
   result.addAttributes(attrs);
 }
 
 PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
                                            ArrayRef<OpFoldResult> low,
                                            ArrayRef<OpFoldResult> high,
-                                           bool packing, Location loc,
+                                           bool nofold, Location loc,
                                            OpBuilder &builder) {
-  auto padTensorOp = builder.create<linalg::PadTensorOp>(loc, type, source, low,
-                                                         high, packing);
+  auto padTensorOp =
+      builder.create<linalg::PadTensorOp>(loc, type, source, low, high, nofold);
   int rank = padTensorOp.getResultType().getRank();
   SmallVector<Type, 4> blockArgTypes;
   blockArgTypes.assign(rank, builder.getIndexType());
@@ -1153,7 +1153,7 @@ PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
 }
 
 PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
-                                         bool packing, Location loc,
+                                         bool nofold, Location loc,
                                          OpBuilder &builder) {
   SmallVector<OpFoldResult, 4> low, high;
   auto rankedTensorType = type.cast<RankedTensorType>();
@@ -1167,7 +1167,7 @@ PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
     high.push_back(highValue);
     low.push_back(builder.createOrFold<ConstantIndexOp>(loc, 0));
   }
-  return PadTensorOp::createPadScalarOp(type, source, pad, low, high, packing,
+  return PadTensorOp::createPadScalarOp(type, source, pad, low, high, nofold,
                                         loc, builder);
 }
 
@@ -1440,8 +1440,8 @@ Operation *PadTensorOp::getTiledImplementation(OpBuilder &b, ValueRange dest,
 }
 
 namespace {
-// Folds linalg.pad_tensor when padding is static zeros and packing is not
-// requested.
+// Folds linalg.pad_tensor when padding is static zeros and the attribute
+// doesn't request otherwise.
 struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
 
@@ -1449,7 +1449,7 @@ struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
                                 PatternRewriter &rewriter) const override {
     if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
       return failure();
-    if (padTensorOp.packing())
+    if (padTensorOp.nofold())
       return failure();
     rewriter.replaceOpWithNewOp<tensor::CastOp>(
         padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
@@ -1481,7 +1481,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
       auto newOp = rewriter.create<PadTensorOp>(
           padTensorOp->getLoc(), newResultType, padTensorOp.source(),
           padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
-          padTensorOp.static_high(), padTensorOp.packing());
+          padTensorOp.static_high(), padTensorOp.nofold());
       BlockAndValueMapping mapper;
       padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
 
@@ -1513,7 +1513,7 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadTensorOp> {
         padTensorOp.getLoc(), tensorCastOp.dest().getType(),
         padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
         padTensorOp.static_low(), padTensorOp.static_high(),
-        padTensorOp.packing());
+        padTensorOp.nofold());
     replacementOp.region().takeBody(padTensorOp.region());
 
     rewriter.replaceOp(padTensorOp, replacementOp.result());
@@ -1555,7 +1555,7 @@ Value PadTensorOp::getConstantPaddingValue() {
 
 OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
   if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
-      !packing())
+      !nofold())
     return source();
   return {};
 }
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index c471459da5c25..aacb20ca97269 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -182,7 +182,7 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
       staticSizes, getElementTypeOrSelf(opOperand->get()));
   result = linalg::PadTensorOp::createPadHighOp(
       staticTensorType, opOperand->get(), paddingValue.getValue(),
-      /*packing=*/true, opToPad->getLoc(), rewriter);
+      /*nofold=*/true, opToPad->getLoc(), rewriter);
   return success();
 }
 
diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index ec5fccf86afbd..1670bde221972 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -630,14 +630,14 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 
 // -----
 
-// CHECK-LABEL: func @pad_tensor_packing_same_static_shape(
+// CHECK-LABEL: func @pad_tensor_nofold_same_static_shape(
 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
 //       CHECK:   %[[PAD:.*]] = linalg.pad_tensor
 //       CHECK:   return %[[PAD]]
-func @pad_tensor_packing_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+func @pad_tensor_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
     -> tensor<5x6xf32> {
   %cst = constant 0.000000e+00 : f32
-  %0 = linalg.pad_tensor %arg0 packing low[%a, 0] high[0, %a] {
+  %0 = linalg.pad_tensor %arg0 nofold low[%a, 0] high[0, %a] {
         ^bb0(%arg1: index, %arg2: index):
           linalg.yield %cst : f32
   } : tensor<5x6xf32> to tensor<5x6xf32>
@@ -937,13 +937,13 @@ func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<
 
 // -----
 
-// CHECK-LABEL: func @pad_packing_static_zero(
+// CHECK-LABEL: func @pad_nofold_static_zero(
 //  CHECK-SAME:                  %[[ARG0:.*]]: tensor<?x?x?xf32>
 //       CHECK:   %[[PAD:.*]] = linalg.pad_tensor
 //       CHECK:   return %[[PAD]]
-func @pad_packing_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
   %c0 = constant 0 : index
-  %0 = linalg.pad_tensor %arg0 packing low[0, %c0, 0] high[0, 0, %c0] {
+  %0 = linalg.pad_tensor %arg0 nofold low[0, %c0, 0] high[0, 0, %c0] {
     ^bb0(%arg1: index, %arg2: index, %arg3: index):
       linalg.yield %pad_value : f32
     } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
diff  --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
index 91126ef956957..60d5457e86af4 100644
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -20,11 +20,11 @@ func @matmul_tensors(
 //  CHECK-NOT:       linalg.matmul {{.*}} tensor<?x?xi8>
 
 // Padding injects static information.
-//      CHECK:       %[[pA:.*]] = linalg.pad_tensor %[[sTA]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK:       %[[pA:.*]] = linalg.pad_tensor %[[sTA]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK:         : tensor<?x?xi8> to tensor<2x4xi8>
-//      CHECK:       %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK:       %[[pB:.*]] = linalg.pad_tensor %[[sTB]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK:         : tensor<?x?xi8> to tensor<4x3xi8>
-//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK:         : tensor<?x?xi32> to tensor<2x3xi32>
 //      CHECK:       %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
 // CHECK-SAME:                                           outs(%[[pC]] : tensor<2x3xi32>)  -> tensor<2x3xi32>
@@ -55,7 +55,7 @@ func @generic_scalar_and_tensor(
 //      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
 
 // Padding injects static information.
-//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}]
+//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] nofold low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}]
 //      CHECK:        : tensor<?x?x?xf32> to tensor<2x3x4xf32>
 //      CHECK:       %[[pD:.*]] = linalg.generic
 // CHECK-SAME:         ins(%[[VAL]] : f32) outs(%[[pC]] : tensor<2x3x4xf32>)
@@ -108,9 +108,9 @@ func @matmul_partially_padded_tensors(
 //      CHECK-1DIM-TILE:                %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x8xi8>
 //      CHECK-1DIM-TILE:                %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8>
 //      CHECK-1DIM-TILE:                %[[sTC:.*]] = tensor.extract_slice %[[TC1]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
-//      CHECK-1DIM-TILE:                %[[pA:.*]] = linalg.pad_tensor %[[sTA]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK-1DIM-TILE:                %[[pA:.*]] = linalg.pad_tensor %[[sTA]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK-1DIM-TILE:                   : tensor<?x8xi8> to tensor<2x8xi8>
-//      CHECK-1DIM-TILE:                %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK-1DIM-TILE:                %[[pB:.*]] = linalg.pad_tensor %[[sTB]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK-1DIM-TILE:                   : tensor<8x?xi8> to tensor<8x3xi8>
 //      CHECK-1DIM-TILE:                %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
 //      CHECK-1DIM-TILE:                                           outs(%[[sTC]] : tensor<?x?xi32>)  -> tensor<?x?xi32>
@@ -122,7 +122,7 @@ func @matmul_partially_padded_tensors(
 func @pad_to_same_static_size(%arg0: tensor<2x3x4xf32>, %arg1: f32) -> tensor<2x3x4xf32> {
   // CHECK: %[[c0:.*]] = constant 0 : index
   // CHECK-NOT: scf.for
-  // CHECK: linalg.pad_tensor %{{.*}} packing low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
+  // CHECK: linalg.pad_tensor %{{.*}} nofold low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
   // CHECK: tensor<2x3x4xf32> to tensor<2x3x4xf32>
   %0 = linalg.generic {
     indexing_maps =  [affine_map<(d0, d1, d2) -> ()>,
@@ -140,7 +140,7 @@ func @pad_to_same_static_size(%arg0: tensor<2x3x4xf32>, %arg1: f32) -> tensor<2x
 func @pad_static_divisible_size(%arg0: tensor<4x6x8xf32>, %arg1: f32) -> tensor<4x6x8xf32> {
   // CHECK: %[[c0:.*]] = constant 0 : index
   // CHECK-COUNT-3: scf.for
-  // CHECK: linalg.pad_tensor %{{.*}} packing low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
+  // CHECK: linalg.pad_tensor %{{.*}} nofold low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
   // CHECK: tensor<2x3x4xf32> to tensor<2x3x4xf32>
   %0 = linalg.generic {
     indexing_maps =  [affine_map<(d0, d1, d2) -> ()>,
        
    
    
More information about the Mlir-commits
mailing list