[Mlir-commits] [mlir] 5988a3b - [mlir] Linalg: ensure tile-and-pad always creates padding as requested

Alex Zinenko llvmlistbot at llvm.org
Fri Sep 24 09:40:20 PDT 2021


Author: Alex Zinenko
Date: 2021-09-24T18:40:13+02:00
New Revision: 5988a3b7a09126aff982944ecb36f533c450388e

URL: https://github.com/llvm/llvm-project/commit/5988a3b7a09126aff982944ecb36f533c450388e
DIFF: https://github.com/llvm/llvm-project/commit/5988a3b7a09126aff982944ecb36f533c450388e.diff

LOG: [mlir] Linalg: ensure tile-and-pad always creates padding as requested

Initially, the padding transformation and the related operation were only used
to guarantee static shapes of subtensors in tiled operations. The
transformation would not insert the padding operation if the shapes were
already static, and the overall code generation would actively remove such
"noop" pads. However, this transformation can also be used to pack data into
smaller tensors and marshal them into faster memory, regardless of size
mismatches. In the context of expert-driven transformations, we should assume
that, if padding is requested, a potentially padded tensor must always be created.
Update the transformation accordingly. To do this, introduce an optional
`packing` attribute to the `pad_tensor` op that serves as an indication that
the padding is an intentional choice (as opposed to a side effect of type
normalization) and should be left alone by cleanups.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110425

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir
    mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index dd568ba367067..6608c7ac7773d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -147,6 +147,8 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
            dimension, i.e `low`.
     * high: A list contains the padding along the end of each
            dimension, i.e. `high`.
+    * packing: whether the padding operation is guaranteed to create a new
+           tensor suitable for packing, i.e. a copy.
 
     The result tensor dimensions are `low` + `dim` + `high` along that
     dimension. The number of elements of `low` and `high` must match
@@ -159,6 +161,11 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     the rank of the `source` tensor. The value `yield`-ed by the
     region is used as the value of the view at the given position.
 
+    If `packing` is indicated, the padding is guaranteed to produce a new
+    tensor, e.g., to use for packing or promotion to faster memory. Such
+    operations are not optimized away even when the source type has the same
+    static shape.
+
     Example 1:
 
     ```mlir
@@ -188,6 +195,17 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
         linalg.yield %pad_value : f32
       } : tensor<2x3xf32> to tensor<?x?xf32>
     ```
+
+    Example 4:
+
+    ```mlir
+      // Force a padded value to always exist with `packing`.
+      %pad_value = ... : f32
+      %0 = linalg.pad_tensor %arg0 packing low[0, 0] high[0, 0] {
+      ^bb0(%arg1: index, %arg2: index):
+        linalg.yield %pad_value : f32
+      } : tensor<2x3xf32> to tensor<2x3xf32>
+    ```
   }];
 
   let arguments = (ins
@@ -195,7 +213,8 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     Variadic<Index>:$low,
     Variadic<Index>:$high,
     I64ArrayAttr:$static_low,
-    I64ArrayAttr:$static_high);
+    I64ArrayAttr:$static_high,
+    UnitAttr:$packing);
 
   let regions = (region SizedRegion<1>:$region);
 
@@ -204,6 +223,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
   // TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands.
   let assemblyFormat = [{
     $source
+    (`packing` $packing^)?
     `low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
     `high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
     $region attr-dict `:` type($source) `to` type($result)
@@ -240,14 +260,16 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     // "high" padding (i.e. it adds trailing padding values until the desired
     // size is met).
     static linalg::PadTensorOp createPadHighOp(
-        Type type, Value source, Value pad, Location loc, OpBuilder & builder);
+        Type type, Value source, Value pad, bool packing, Location loc,
+        OpBuilder & builder);
 
     // Return a PadTensorOp that pads `source` to `type` size with `pad` value.
     // I.e., a block will be created and the `pad` value will be yielded
     // directly. If the type passed is nullptr, it is inferred.
     static linalg::PadTensorOp createPadScalarOp(
         Type type, Value source, Value pad, ArrayRef<OpFoldResult> low,
-        ArrayRef<OpFoldResult> high, Location loc, OpBuilder & builder);
+        ArrayRef<OpFoldResult> high, bool packing, Location loc,
+        OpBuilder & builder);
 
     // Return the pad value if it is a constant. Return null value otherwise.
     Value getConstantPaddingValue();
@@ -291,14 +313,17 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
     // Build a PadTensorOp with mixed static and dynamic entries.
     OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$staticLow,
       "ArrayRef<int64_t>":$staticHigh, "ValueRange":$low, "ValueRange":$high,
+      CArg<"bool", "false">:$packing,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadTensorOp with all dynamic entries.
     OpBuilder<(ins "Value":$source, "ValueRange":$low, "ValueRange":$high,
+      CArg<"bool", "false">:$packing,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
     // Build a PadTensorOp with mixed static and dynamic entries and custom
     // result type. If the type passed is nullptr, it is inferred.
     OpBuilder<(ins "Type":$resultType, "Value":$source,
       "ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
+      CArg<"bool", "false">:$packing,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
   ];
 

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 4184b9cadf4dc..ed1b03cdc64d3 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -87,7 +87,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
 
   return linalg::PadTensorOp::createPadScalarOp(
              RankedTensorType::get(paddedShape, inputETy), input, padValue,
-             lowIndices, highIndices, loc, rewriter)
+             lowIndices, highIndices, /*packing=*/false, loc, rewriter)
       .result();
 }
 
@@ -2350,7 +2350,8 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
     Value constant = rewriter.create<ConstantOp>(loc, constantAttr);
 
     auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
-        padOp.getType(), input, constant, lowValues, highValues, loc, rewriter);
+        padOp.getType(), input, constant, lowValues, highValues,
+        /*packing=*/false, loc, rewriter);
 
     rewriter.replaceOp(padOp, newPadOp.getResult());
     return success();

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index dfa4df8f58d69..7dd6bd90c78b9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1085,26 +1085,28 @@ RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
                         ArrayRef<int64_t> staticLow,
                         ArrayRef<int64_t> staticHigh, ValueRange low,
-                        ValueRange high, ArrayRef<NamedAttribute> attrs) {
+                        ValueRange high, bool packing,
+                        ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
   auto resultType = inferResultType(sourceType, staticLow, staticHigh);
   build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
-        b.getI64ArrayAttr(staticHigh));
+        b.getI64ArrayAttr(staticHigh), packing ? b.getUnitAttr() : UnitAttr());
   result.addAttributes(attrs);
 }
 
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
-                        ValueRange low, ValueRange high,
+                        ValueRange low, ValueRange high, bool packing,
                         ArrayRef<NamedAttribute> attrs) {
   auto sourceType = source.getType().cast<RankedTensorType>();
   unsigned rank = sourceType.getRank();
   SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
-  build(b, result, source, staticVector, staticVector, low, high, attrs);
+  build(b, result, source, staticVector, staticVector, low, high, packing,
+        attrs);
 }
 
 void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
                         Value source, ArrayRef<OpFoldResult> low,
-                        ArrayRef<OpFoldResult> high,
+                        ArrayRef<OpFoldResult> high, bool packing,
                         ArrayRef<NamedAttribute> attrs) {
   assert(resultType.isa<RankedTensorType>());
   auto sourceType = source.getType().cast<RankedTensorType>();
@@ -1126,15 +1128,18 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
         PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
   }
   build(b, result, resultType, source, dynamicLow, dynamicHigh,
-        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
+        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
+        packing ? b.getUnitAttr() : UnitAttr());
+  result.addAttributes(attrs);
 }
 
 PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
                                            ArrayRef<OpFoldResult> low,
                                            ArrayRef<OpFoldResult> high,
-                                           Location loc, OpBuilder &builder) {
-  auto padTensorOp =
-      builder.create<linalg::PadTensorOp>(loc, type, source, low, high);
+                                           bool packing, Location loc,
+                                           OpBuilder &builder) {
+  auto padTensorOp = builder.create<linalg::PadTensorOp>(loc, type, source, low,
+                                                         high, packing);
   int rank = padTensorOp.getResultType().getRank();
   SmallVector<Type, 4> blockArgTypes;
   blockArgTypes.assign(rank, builder.getIndexType());
@@ -1148,7 +1153,8 @@ PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
 }
 
 PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
-                                         Location loc, OpBuilder &builder) {
+                                         bool packing, Location loc,
+                                         OpBuilder &builder) {
   SmallVector<OpFoldResult, 4> low, high;
   auto rankedTensorType = type.cast<RankedTensorType>();
   assert(rankedTensorType.hasStaticShape());
@@ -1161,8 +1167,8 @@ PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
     high.push_back(highValue);
     low.push_back(builder.createOrFold<ConstantIndexOp>(loc, 0));
   }
-  return PadTensorOp::createPadScalarOp(type, source, pad, low, high, loc,
-                                        builder);
+  return PadTensorOp::createPadScalarOp(type, source, pad, low, high, packing,
+                                        loc, builder);
 }
 
 LogicalResult PadTensorOp::reifyResultShapes(
@@ -1434,7 +1440,8 @@ Operation *PadTensorOp::getTiledImplementation(OpBuilder &b, ValueRange dest,
 }
 
 namespace {
-// Folds linalg.pad_tensor when padding is static zeros.
+// Folds linalg.pad_tensor when padding is static zeros and packing is not
+// requested.
 struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
   using OpRewritePattern<PadTensorOp>::OpRewritePattern;
 
@@ -1442,6 +1449,8 @@ struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
                                 PatternRewriter &rewriter) const override {
     if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
       return failure();
+    if (padTensorOp.packing())
+      return failure();
     rewriter.replaceOpWithNewOp<tensor::CastOp>(
         padTensorOp, padTensorOp.result().getType(), padTensorOp.source());
     return success();
@@ -1472,7 +1481,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
       auto newOp = rewriter.create<PadTensorOp>(
           padTensorOp->getLoc(), newResultType, padTensorOp.source(),
           padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
-          padTensorOp.static_high());
+          padTensorOp.static_high(), padTensorOp.packing());
       BlockAndValueMapping mapper;
       padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
 
@@ -1503,7 +1512,8 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadTensorOp> {
     auto replacementOp = rewriter.create<PadTensorOp>(
         padTensorOp.getLoc(), tensorCastOp.dest().getType(),
         padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
-        padTensorOp.static_low(), padTensorOp.static_high());
+        padTensorOp.static_low(), padTensorOp.static_high(),
+        padTensorOp.packing());
     replacementOp.region().takeBody(padTensorOp.region());
 
     rewriter.replaceOp(padTensorOp, replacementOp.result());
@@ -1544,7 +1554,8 @@ Value PadTensorOp::getConstantPaddingValue() {
 }
 
 OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
-  if (getResultType().hasStaticShape() && getResultType() == getSourceType())
+  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
+      !packing())
     return source();
   return {};
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 533f295a35f56..cef9e5a030aed 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -145,17 +145,15 @@ LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
   return *this;
 }
 
-/// Try to compute a static bounding box for `operand`
-/// Return success if either:
-///   1. The operand is already statically shaped, `result` is left unchanged.
-///   2. The operand is (partially) dynamic, `result` is the result of a freshly
-///      created PadTensorOp.
-/// Return failure if the operand cannot be padded to a static shape.
+/// Try to compute a static bounding box for `operand`. The padding happens
+/// even if the operand already has static shape. `result` is the result of a
+/// freshly created PadTensorOp. Return failure if the operand cannot be padded
+/// to a static shape.
 static LogicalResult padOperandToSmallestStaticBoundingBox(
     PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
     const PaddingValueComputationFunction &paddingFunc, Value &result) {
-  // Already static shape, no need to pad.
-  if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic))
+  // Can't pad scalars.
+  if (opToPad.getShape(opOperand).empty())
     return success();
   auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
   // Not a slice op, cannot construct a static bounding box.
@@ -179,7 +177,8 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
   auto staticTensorType = RankedTensorType::get(
       staticSizes, getElementTypeOrSelf(opOperand->get()));
   result = linalg::PadTensorOp::createPadHighOp(
-      staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter);
+      staticTensorType, opOperand->get(), pad, /*packing=*/true,
+      opToPad->getLoc(), rewriter);
   return success();
 }
 
@@ -189,12 +188,9 @@ linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad,
                           LinalgOp &paddedOp) {
   Location loc = opToPad->getLoc();
 
-  // If the op is fully static, it does not need padding.
   // TODO: there are cases where we may still want to pad to larger sizes.
   assert(opToPad.hasTensorSemantics() &&
          "expected operation to have tensor semantics");
-  if (!opToPad.hasDynamicShape())
-    return success();
 
   OpBuilder::InsertionGuard g(rewriter);
   // Set IP after op because we also take the dims of the original output.

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 42d640a60246c..ec5fccf86afbd 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -630,6 +630,22 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
 
 // -----
 
+// CHECK-LABEL: func @pad_tensor_packing_same_static_shape(
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
+//       CHECK:   %[[PAD:.*]] = linalg.pad_tensor
+//       CHECK:   return %[[PAD]]
+func @pad_tensor_packing_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
+    -> tensor<5x6xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.pad_tensor %arg0 packing low[%a, 0] high[0, %a] {
+        ^bb0(%arg1: index, %arg2: index):
+          linalg.yield %cst : f32
+  } : tensor<5x6xf32> to tensor<5x6xf32>
+  return %0 : tensor<5x6xf32>
+}
+
+// -----
+
 // CHECK-LABEL:   func @pad_tensor_after_cast_different_shape(
 // CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
 // CHECK:           %[[CST:.*]] = constant 0.000000e+00 : f32
@@ -921,6 +937,22 @@ func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<
 
 // -----
 
+// CHECK-LABEL: func @pad_packing_static_zero(
+//  CHECK-SAME:                  %[[ARG0:.*]]: tensor<?x?x?xf32>
+//       CHECK:   %[[PAD:.*]] = linalg.pad_tensor
+//       CHECK:   return %[[PAD]]
+func @pad_packing_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+  %c0 = constant 0 : index
+  %0 = linalg.pad_tensor %arg0 packing low[0, %c0, 0] high[0, 0, %c0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+    } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+
+  return %0 : tensor<2x3x4xf32>
+}
+
+// -----
+
 func private @some_use(%i : index, %j : index)
 
 // CHECK-LABEL: func @init_canonicalize

diff  --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
index 4e5ae78c466ec..fb283814a4662 100644
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -20,11 +20,11 @@ func @matmul_tensors(
 //  CHECK-NOT:       linalg.matmul {{.*}} tensor<?x?xi8>
 
 // Padding injects static information.
-//      CHECK:       %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK:       %[[pA:.*]] = linalg.pad_tensor %[[sTA]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK:         : tensor<?x?xi8> to tensor<2x4xi8>
-//      CHECK:       %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK:       %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK:         : tensor<?x?xi8> to tensor<4x3xi8>
-//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK:         : tensor<?x?xi32> to tensor<2x3xi32>
 //      CHECK:       %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>)
 // CHECK-SAME:                                           outs(%[[pC]] : tensor<2x3xi32>)  -> tensor<2x3xi32>
@@ -55,7 +55,7 @@ func @generic_scalar_and_tensor(
 //      CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
 
 // Padding injects static information.
-//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}]
+//      CHECK:       %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}, %{{.*}}]
 //      CHECK:        : tensor<?x?x?xf32> to tensor<2x3x4xf32>
 //      CHECK:       %[[pD:.*]] = linalg.generic
 // CHECK-SAME:         ins(%[[VAL]] : f32) outs(%[[pC]] : tensor<2x3x4xf32>)
@@ -107,11 +107,50 @@ func @matmul_partially_padded_tensors(
 //      CHECK-1DIM-TILE:                %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x8xi8> to tensor<?x8xi8>
 //      CHECK-1DIM-TILE:                %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<8x?xi8> to tensor<8x?xi8>
 //      CHECK-1DIM-TILE:                %[[sTC:.*]] = tensor.extract_slice %[[TC1]][{{.*}}] : tensor<?x?xi32> to tensor<?x?xi32>
-//      CHECK-1DIM-TILE:                %[[pA:.*]] = linalg.pad_tensor %[[sTA]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK-1DIM-TILE:                %[[pA:.*]] = linalg.pad_tensor %[[sTA]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK-1DIM-TILE:                   : tensor<?x8xi8> to tensor<2x8xi8>
-//      CHECK-1DIM-TILE:                %[[pB:.*]] = linalg.pad_tensor %[[sTB]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK-1DIM-TILE:                %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK-1DIM-TILE:                   : tensor<8x?xi8> to tensor<8x3xi8>
-//      CHECK-1DIM-TILE:                %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
+//      CHECK-1DIM-TILE:                %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}]
 //      CHECK-1DIM-TILE:                   : tensor<?x?xi32> to tensor<2x3xi32>
 //      CHECK-1DIM-TILE:               %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>)
 //      CHECK-1DIM-TILE:                                           outs(%[[pC]] : tensor<2x3xi32>)  -> tensor<2x3xi32>
+
+// Check that the tile-and-pad transformation actually introduces the padding
+// as requested, even if original operation already operates on static
+// shapes.
+// CHECK-LABEL: @pad_to_same_static_size
+func @pad_to_same_static_size(%arg0: tensor<2x3x4xf32>, %arg1: f32) -> tensor<2x3x4xf32> {
+  // CHECK: %[[c0:.*]] = constant 0 : index
+  // CHECK-NOT: scf.for
+  // CHECK: linalg.pad_tensor %{{.*}} packing low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
+  // CHECK: tensor<2x3x4xf32> to tensor<2x3x4xf32>
+  %0 = linalg.generic {
+    indexing_maps =  [affine_map<(d0, d1, d2) -> ()>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)> ],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+  {__internal_linalg_transform__ = "tile"}
+  ins(%arg1 : f32) outs(%arg0 : tensor<2x3x4xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<2x3x4xf32>
+  return %0 : tensor<2x3x4xf32>
+}
+
+// CHECK-LABEL: @pad_static_divisible_size
+func @pad_static_divisible_size(%arg0: tensor<4x6x8xf32>, %arg1: f32) -> tensor<4x6x8xf32> {
+  // CHECK: %[[c0:.*]] = constant 0 : index
+  // CHECK-COUNT-3: scf.for
+  // CHECK: linalg.pad_tensor %{{.*}} packing low[%[[c0]], %[[c0]], %[[c0]]] high[%[[c0]], %[[c0]], %[[c0]]]
+  // CHECK: tensor<2x3x4xf32> to tensor<2x3x4xf32>
+  %0 = linalg.generic {
+    indexing_maps =  [affine_map<(d0, d1, d2) -> ()>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)> ],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+  {__internal_linalg_transform__ = "tile"}
+  ins(%arg1 : f32) outs(%arg0 : tensor<4x6x8xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<4x6x8xf32>
+  return %0 : tensor<4x6x8xf32>
+}


        


More information about the Mlir-commits mailing list