[llvm-branch-commits] [mlir] 3747eb9 - [mlir][Linalg] Add a padding option to Linalg tiling

Nicolas Vasilache via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Jan 25 01:26:04 PST 2021


Author: Nicolas Vasilache
Date: 2021-01-25T09:17:30Z
New Revision: 3747eb9c85b3393aa00ad12e9e7ef31ffec8bd4c

URL: https://github.com/llvm/llvm-project/commit/3747eb9c85b3393aa00ad12e9e7ef31ffec8bd4c
DIFF: https://github.com/llvm/llvm-project/commit/3747eb9c85b3393aa00ad12e9e7ef31ffec8bd4c.diff

LOG: [mlir][Linalg] Add a padding option to Linalg tiling

This revision allows the base Linalg tiling pattern to optionally require padding to
a constant bounding shape.
When requested, a simple analysis is performed, similar to buffer promotion.
A temporary `linalg.simple_pad` op is added to model padding for the purpose of
connecting the dots. This will be replaced by a more fleshed out `linalg.pad_tensor`
op when it is available.
In the meantime, this temporary op exhibits the properties that a more fleshed-out
pad op needs in order to compose properly with transformations.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D95149

Added: 
    mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/include/mlir/Interfaces/ViewLikeInterface.td
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index ae9f81d043f5..9ea1bc5a3587 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -475,6 +475,38 @@ def Linalg_SliceOp : Linalg_Op<"slice", [
   let hasFolder = 1;
 }
 
+def Linalg_SimplePadOp : Linalg_Op<"simple_pad", [NoSideEffect]> {
+  let summary = "TODO: replace with pad_tensors when ready.";
+
+  let description = [{
+    `linalg.simple_pad` is a temporary placeholder for padding and packing on
+    tensors. Its semantics are to pad a partially dynamic tensor to a fully
+    static tensor where the static sizes are assumed to be greater than the
+    dynamic sizes. The op performs "high" padding (i.e. it adds trailing
+    padding values until the desired size is met).
+  }];
+
+  let arguments = (ins AnyRankedTensor:$tensor, AnyType:$padding);
+  let results = (outs AnyRankedTensor:$result);
+
+  // TODO: verify that the result is fully static, the input is (partially)
+  // dynamic, static sizes match, element types match, ranks match, etc. Use
+  // pad_tensors when ready; for now let it be fully specified by traits.
+  let verifier = ?;
+
+  let extraClassDeclaration = [{
+    RankedTensorType getSourceType() {
+      return tensor().getType().cast<RankedTensorType>(); }
+    RankedTensorType getResultType() {
+      return getResult().getType().cast<RankedTensorType>(); }
+   }];
+
+  let assemblyFormat = [{
+    $tensor `pad` $padding attr-dict `:`
+      type($tensor) `to` type($result) `pad` type($padding)
+  }];
+}
+
 def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
     Arguments<(ins Variadic<AnyType>:$values)> {
   let summary = "Linalg yield operation";

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 611ab6867372..f359992e5ff1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -345,6 +345,9 @@ enum class LinalgTilingLoopType {
 using TileSizeComputationFunction =
     std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
 
+using PaddingValueComputationFunction =
+    std::function<Value(OpBuilder &, Operation *)>;
+
 struct LinalgTilingOptions {
   /// Computation function that returns the tile sizes for each operation.
   /// Delayed construction of constant tile sizes should occur to interoperate
@@ -393,6 +396,18 @@ struct LinalgTilingOptions {
     distribution = std::move(distributionOptions);
     return *this;
   }
+
+  /// Computation function that returns a padding value to use when padding to
+  /// force static sizes. When `paddingValueComputationFunction` is set, padding
+  /// operations are introduced that guarantee the underlying op is statically
+  /// shaped and can thus be vectorized.
+  PaddingValueComputationFunction paddingValueComputationFunction = nullptr;
+
+  LinalgTilingOptions &
+  setPaddingValueComputationFunction(PaddingValueComputationFunction fun) {
+    paddingValueComputationFunction = std::move(fun);
+    return *this;
+  }
 };
 
 /// Canonicalization patterns relevant to apply after tiling patterns. These are
@@ -403,6 +418,11 @@ getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
 void populateLinalgTilingCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
 
+/// Base pattern that applies the tiling transformation specified by `options`.
+/// Abort and return failure in 2 cases:
+///   1. if the tiling specification is invalid and tiling fails to occur.
+///   2. if tiling occurs but `options.paddingValueComputationFunction` is set
+///      and some operand shape cannot be bounded statically.
 struct LinalgBaseTilingPattern : public RewritePattern {
   // Entry point to match any LinalgOp OpInterface.
   LinalgBaseTilingPattern(LinalgTilingOptions options,

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index d5f44c3e63da..2b3a054338ab 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -14,6 +14,7 @@
 #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
 
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpImplementation.h"
 

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index 62c371b2f97d..6c72b47f2ac3 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -108,6 +108,28 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
         return $_op.sizes();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return each size of the op: a static Attribute or a dynamic Value.
+      }],
+      /*retTy=*/"SmallVector<OpFoldResult, 4>",
+      /*methodName=*/"getMixedSizes",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        SmallVector<OpFoldResult, 4> res;
+        std::array<unsigned, 3> ranks = $_op.getArrayAttrRanks();
+        unsigned numDynamic = 0;
+        unsigned count = ranks[getSizeOperandGroupPosition()];
+        for (unsigned idx = 0; idx < count; ++idx) {
+          if (isDynamicSize(idx))
+            res.push_back($_op.sizes()[numDynamic++]);
+          else
+            res.push_back($_op.static_sizes()[idx]);
+        }
+        return res;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the dynamic stride operands.
@@ -359,6 +381,9 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
   ];
 
   let extraClassDeclaration = [{
+    static unsigned getOffsetOperandGroupPosition() { return 0; }
+    static unsigned getSizeOperandGroupPosition() { return 1; }
+    static unsigned getStrideOperandGroupPosition() { return 2; }
     static StringRef getStaticOffsetsAttrName() {
       return "static_offsets";
     }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 283ff20f611b..a76d70c8cd5f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <type_traits>
@@ -105,6 +106,118 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
   return *this;
 }
 
+/// Try to compute a static bounding box for `operand`.
+/// Return success if either:
+///   1. The operand is already statically shaped, `result` is left unchanged.
+///   2. The operand is (partially) dynamic, `result` is the result of a freshly
+///      created SimplePadOp.
+/// Return failure if the operand cannot be padded to a static shape.
+static LogicalResult padOperandToSmallestStaticBoundingBox(
+    PatternRewriter &rewriter, linalg::LinalgOp opToPad, Value operand,
+    const LinalgTilingOptions &options, Value &result) {
+  auto tensorType = operand.getType().cast<RankedTensorType>();
+  // Already static shape, no need to pad.
+  if (tensorType.hasStaticShape())
+    return success();
+  auto subtensor = operand.getDefiningOp<SubTensorOp>();
+  // Not a subtensor, cannot construct a static bounding box.
+  if (!subtensor)
+    return failure();
+  SmallVector<int64_t> staticSizes;
+  staticSizes.reserve(tensorType.getRank());
+  auto shapedOp =
+      cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation());
+  for (auto size : shapedOp.getMixedSizes()) {
+    auto indexAttr = size.is<Attribute>()
+                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
+                         : linalg::getSmallestBoundingIndex(size.get<Value>());
+    // Every size needs a constant upper bound; without one, padding to a
+    // static shape is impossible, so report a match failure.
+    if (!indexAttr)
+      return rewriter.notifyMatchFailure(
+          opToPad, "No constant bounding box can be found for padding");
+    staticSizes.push_back(indexAttr.getInt());
+  }
+  Value pad = options.paddingValueComputationFunction(rewriter, opToPad);
+  auto staticTensorType =
+      RankedTensorType::get(staticSizes, tensorType.getElementType());
+  result = rewriter.create<linalg::SimplePadOp>(opToPad->getLoc(),
+                                                staticTensorType, operand, pad);
+  return success();
+}
+
+/// Try to create a static bounding box around each operand of `res.op`.
+/// If successful, `res.op` is cloned into a static form with padded operands.
+/// `res.op` is updated to the cloned static form of the op on success.
+static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
+                                       TiledLinalgOp &res,
+                                       const LinalgTilingOptions &options) {
+  LinalgOp opToPad = res.op;
+  Location loc = opToPad->getLoc();
+
+  // If the op is fully static, it does not need padding.
+  // TODO: there are cases where we may still want to pad to larger sizes.
+  if (llvm::all_of(opToPad.getShapedOperands(), [](Value v) {
+        return v.getType().cast<RankedTensorType>().hasStaticShape();
+      }))
+    return success();
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // Set IP after op because we also take the dims of the original output.
+  rewriter.setInsertionPointAfter(opToPad);
+  // Make a copy of the shaped operands and update it.
+  SmallVector<Value> operands = opToPad.getShapedOperands();
+  for (Value &v : operands) {
+    Value paddedOperand;
+    // If padding was requested but the shape cannot be bounded statically then
+    // the pattern fails to apply.
+    if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, v,
+                                                     options, paddedOperand))) {
+      return failure();
+    }
+    // Update v if we indeed got a padded operand.
+    v = paddedOperand ? paddedOperand : v;
+  }
+
+  // Copy the result types out now: appending to `operands` may reallocate.
+  SmallVector<Type, 4> resultTensorTypes = llvm::to_vector<4>(
+      ValueRange(operands).take_back(opToPad.getNumOutputs()).getTypes());
+  ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
+  operands.append(otherOperands.begin(), otherOperands.end());
+  linalg::LinalgOp paddedOp =
+      opToPad.clone(rewriter, loc, resultTensorTypes, operands);
+
+  // Recover the subtensor out of the new static results. This keeps the
+  // original linalg op around because it uses the dims of the original results.
+  // This later folds away.
+  SmallVector<Value> paddedSubviewResults;
+  paddedSubviewResults.reserve(opToPad->getNumResults());
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  llvm::SetVector<Operation *> newUsersOfOpToPad;
+  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
+    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
+    SmallVector<Value> offsets(rank, zero);
+    auto sizes = llvm::to_vector<4>(
+        llvm::map_range(llvm::seq<unsigned>(0, rank), [&](unsigned d) -> Value {
+          auto dimOp = rewriter.create<DimOp>(loc, std::get<0>(it), d);
+          newUsersOfOpToPad.insert(dimOp);
+          return dimOp;
+        }));
+    SmallVector<Value> strides(rank, one);
+    paddedSubviewResults.push_back(rewriter.create<SubTensorOp>(
+        loc, std::get<1>(it), offsets, sizes, strides));
+  }
+  // Replace the transient `opToPad` locally, except for uses that we just
+  // created for the purpose of extracting the dims.
+  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
+    return !newUsersOfOpToPad.contains(opOp.getOwner());
+  });
+
+  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
+  return success();
+}
+
 /// Linalg base tiling pattern.
 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
     StringRef opName, MLIRContext *context, LinalgTilingOptions options,
@@ -130,11 +243,34 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
   if (!res)
     return failure();
 
-  // Return relevant information to derived pattern.
-  result = *res;
+  // Scope-exit guard: on success, publish `result` and update the markers.
+  bool propagateResult = true;
+  LinalgOp tiledOp = res->op;
+  auto guard = llvm::make_scope_exit([&]() {
+    if (!propagateResult)
+      return;
+    // Return relevant information to derived pattern.
+    result = *res;
+    // Replace the marker on the tiled op and, if padding rewrote it, on res->op.
+    marker.replaceLinalgMarker(rewriter, tiledOp);
+    if (tiledOp != res->op)
+      marker.replaceLinalgMarker(rewriter, res->op);
+  });
+
+  // Consider padding on the fly only if the op has tensor semantics.
+  if (!options.paddingValueComputationFunction ||
+      !linalgOp.hasTensorSemantics())
+    return success();
+
+  // Try to pad on the fly by rewriting res->op as a padded op.
+  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
+    // Tell the guard not to propagate TiledLinalgOp to `result`.
+    propagateResult = false;
+    return failure();
+  }
 
-  // New marker if specified.
-  marker.replaceLinalgMarker(rewriter, res->op.getOperation());
+  // Do not perform replacement of `linalgOp`, let the derived patterns
+  // do this as they see fit, from the resulting TiledLinalgOp.
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 45dd0fd0086a..b8671cfe48fe 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1411,13 +1411,20 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
     return Value{*dynExtents};
   }
 
+  // The size at the given index is now known to be a dynamic size.
+  unsigned unsignedIndex = index.getValue().getZExtValue();
+  if (auto subtensor = dyn_cast_or_null<SubTensorOp>(definingOp)) {
+    // Only fold non rank-reducing subtensors: their sizes match result dims.
+    if (subtensor.getType().getRank() == subtensor.getSourceType().getRank() &&
+        subtensor.isDynamicSize(unsignedIndex))
+      return subtensor.getDynamicSize(unsignedIndex);
+  }
+
   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
   auto memrefType = argTy.dyn_cast<MemRefType>();
   if (!memrefType)
     return {};
 
-  // The size at the given index is now known to be a dynamic size of a memref.
-  unsigned unsignedIndex = index.getValue().getZExtValue();
   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
     return *(alloc.getDynamicSizes().begin() +
              memrefType.getDynamicDimIndex(unsignedIndex));

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 44743eaedc8c..6dc0768bc2e3 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -814,3 +814,13 @@ func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 // CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+
+// -----
+
+// TODO: remove this test once `linalg.pad_tensor` replaces `linalg.simple_pad`.
+// CHECK-LABEL: func @simple_pad
+func @simple_pad(%0: tensor<?x4x?xf32>, %pad: f32) {
+//     CHECK:   linalg.simple_pad %{{.+}} pad %{{.+}}: tensor<?x4x?xf32> to tensor<8x4x8xf32>
+  %1 = linalg.simple_pad %0 pad %pad: tensor<?x4x?xf32> to tensor<8x4x8xf32> pad f32
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
new file mode 100644
index 000000000000..e4121083e240
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-and-pad-pattern -canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @matmul_tensors(
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @matmul_tensors(
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+//      CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
+//      CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+//      CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
+//      CHECK:       %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+//      CHECK:       %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+
+// The dynamically shaped tiled matmul has been canonicalized away.
+//  CHECK-NOT:       linalg.matmul {{.*}} tensor<?x?xf32>
+
+// Padding to the tile sizes (2, 3, 4) makes every matmul operand static.
+//      CHECK:       %[[pA:.*]] = linalg.simple_pad %[[sTA]] pad %{{.*}} : tensor<?x?xf32> to tensor<2x4xf32> pad f32
+//      CHECK:       %[[pB:.*]] = linalg.simple_pad %[[sTB]] pad %{{.*}} : tensor<?x?xf32> to tensor<4x3xf32> pad f32
+//      CHECK:       %[[pC:.*]] = linalg.simple_pad %[[sTC]] pad %{{.*}} : tensor<?x?xf32> to tensor<2x3xf32> pad f32
+//      CHECK:       %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xf32>, tensor<4x3xf32>)
+// CHECK-SAME:                                  outs(%[[pC]] : tensor<2x3xf32>)  -> tensor<2x3xf32>
+//      CHECK:       %[[sTD:.*]] = subtensor %[[pD]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<2x3xf32> to tensor<?x?xf32>
+//      CHECK:       %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}]  : tensor<?x?xf32> into tensor<?x?xf32>
+//      CHECK:       scf.yield %[[TD]] : tensor<?x?xf32>
+//      CHECK:     scf.yield %[[TD2]] : tensor<?x?xf32>
+//      CHECK:   scf.yield %[[TD1]] : tensor<?x?xf32>
+  %0 = linalg.matmul {__internal_linalg_transform__ = "tile-and-pad"}
+      ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+     outs(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+//      CHECK: return %[[TD0]] : tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index c2b4c7b9c821..87f81dbbf1fd 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -79,6 +79,9 @@ struct TestLinalgTransforms
       *this, "test-affine-min-scf-canonicalization-patterns",
       llvm::cl::desc("Test affine-min + scf canonicalization patterns."),
       llvm::cl::init(false)};
+  Option<bool> testTileAndPadPattern{
+      *this, "test-tile-and-pad-pattern",
+      llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -487,6 +490,27 @@ static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
     applyOpPatternsAndFold(minOp, frozenPatterns);
   });
 }
+
+// For now, pad with a zero constant of the element type of the op's result 0.
+// In the future, the value should depend on both the type and the op kind.
+static Value getNeutralOfLinalgOp(OpBuilder &b, Operation *op) {
+  auto t = op->getResult(0).getType().cast<ShapedType>().getElementType();
+  return b.create<ConstantOp>(op->getLoc(), t, b.getZeroAttr(t));
+}
+
+static void applyTileAndPadPattern(FuncOp funcOp) {
+  MLIRContext *context = funcOp.getContext();
+  OwningRewritePatternList tilingPattern;
+  auto linalgTilingOptions =
+      linalg::LinalgTilingOptions()
+          .setTileSizes({2, 3, 4})
+          .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
+  tilingPattern.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
+      context, linalgTilingOptions,
+      linalg::LinalgMarker(Identifier::get("tile-and-pad", context)));
+  applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
   auto lambda = [&](void *) {
@@ -520,6 +544,8 @@ void TestLinalgTransforms::runOnFunction() {
     return applyLinalgToVectorPatterns(getFunction());
   if (testAffineMinSCFCanonicalizationPatterns)
     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
+  if (testTileAndPadPattern)
+    return applyTileAndPadPattern(getFunction());
 }
 
 namespace mlir {


        


More information about the llvm-branch-commits mailing list