[Mlir-commits] [mlir] d0ec4a8 - [mlir][linalg] Add pad and hoist test pass.

Tobias Gysi llvmlistbot at llvm.org
Fri Oct 29 08:09:27 PDT 2021


Author: Tobias Gysi
Date: 2021-10-29T15:08:16Z
New Revision: d0ec4a8ed9a39b4e3a35e0826b63ac6c5bc21da1

URL: https://github.com/llvm/llvm-project/commit/d0ec4a8ed9a39b4e3a35e0826b63ac6c5bc21da1
DIFF: https://github.com/llvm/llvm-project/commit/d0ec4a8ed9a39b4e3a35e0826b63ac6c5bc21da1.diff

LOG: [mlir][linalg] Add pad and hoist test pass.

Add a padding and hoisting pattern, a test pass, and tests. This patch prepares for separating padding from tiling and fusion.

Depends On D112255

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D112412

Added: 
    mlir/test/Dialect/Linalg/pad-and-hoist.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6f5ed28d9bf78..bbaabac10cdb0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -460,6 +460,47 @@ using PaddingValueComputationFunction =
 /// OpOperand shall be marked as nofold to enable packing.
 using PaddingNoFoldComputationFunction = std::function<bool(OpOperand &)>;
 
+/// Callback computing the number of enclosing loops out of which the pad
+/// tensor operation defining the given OpOperand shall be hoisted.
+using PaddingHoistComputationFunction = std::function<int64_t(OpOperand &)>;
+
+/// Bundles the optional callbacks that control padding and hoisting of Linalg
+/// operations. Every setter stores its callback and returns `*this` so calls
+/// can be chained.
+struct LinalgPaddingOptions {
+  /// Computes the padding value for an OpOperand, or returns failure to leave
+  /// it unpadded. Padding operations are only introduced when this callback is
+  /// set and does not return failure. Padding every operand guarantees the
+  /// operation is statically shaped, and thus vectorizable.
+  PaddingValueComputationFunction paddingValueComputationFunction = nullptr;
+
+  /// Decides whether the pad tensor operation defining an OpOperand is marked
+  /// nofold to enable packing. The nofold attribute is only set to true when
+  /// this callback is set and returns true; otherwise it is set to false.
+  PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr;
+
+  /// Computes how many loops the pad tensor operation defining an OpOperand is
+  /// hoisted out of. `LinalgPaddingPattern` does not hoist an operand when this
+  /// callback is unset or returns 0.
+  PaddingHoistComputationFunction paddingHoistComputationFunction = nullptr;
+
+  LinalgPaddingOptions &
+  setPaddingValueComputationFunction(PaddingValueComputationFunction fn) {
+    paddingValueComputationFunction = std::move(fn);
+    return *this;
+  }
+
+  LinalgPaddingOptions &
+  setPaddingNoFoldComputationFunction(PaddingNoFoldComputationFunction fn) {
+    paddingNoFoldComputationFunction = std::move(fn);
+    return *this;
+  }
+
+  LinalgPaddingOptions &
+  setPaddingHoistComputationFunction(PaddingHoistComputationFunction fn) {
+    paddingHoistComputationFunction = std::move(fn);
+    return *this;
+  }
+};
+
 struct LinalgTilingOptions {
   /// Computation function that returns the tile sizes for each operation.
   /// Delayed construction of constant tile sizes should occur to interoperate
@@ -650,6 +691,35 @@ struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
   }
 };
 
+///
+/// Linalg padding pattern.
+///
+/// Pads the operands of a LinalgOp with tensor semantics (see
+/// `rewriteAsPaddedOp`) and optionally hoists the resulting pad tensor
+/// operations out of enclosing loops (see `hoistPaddingOnTensors`), as
+/// configured by `LinalgPaddingOptions`.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+struct LinalgPaddingPattern : public RewritePattern {
+  // Entry point matching any operation; non-LinalgOps are rejected by
+  // `matchAndRewrite`.
+  LinalgPaddingPattern(
+      MLIRContext *context,
+      LinalgPaddingOptions options = LinalgPaddingOptions(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+  // Entry point to match a specific LinalgOp.
+  LinalgPaddingPattern(
+      StringRef opName, MLIRContext *context,
+      LinalgPaddingOptions options = LinalgPaddingOptions(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+  /// Pads (and possibly hoists) `op`. Fails if `op` is not a LinalgOp with
+  /// tensor semantics, is rejected by `filter`, or cannot be padded.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+  /// Options to control padding and hoisting.
+  LinalgPaddingOptions options;
+};
+
 struct LinalgFusionOptions {
   /// List of operands indices to use for fusion.
   llvm::SmallSet<unsigned, 1> indicesToFuse = {};

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 18c3620783f36..f81ce919a4faf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -470,6 +471,64 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
   return success();
 }
 
+/// Linalg padding pattern.
+/// Constructs a pattern that matches any operation; `matchAndRewrite` rejects
+/// everything that is not a LinalgOp. The by-value `options` and `filter`
+/// parameters are moved into the members instead of being copied again.
+mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
+    MLIRContext *context, LinalgPaddingOptions options,
+    LinalgTransformationFilter filter, PatternBenefit benefit)
+    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
+      filter(std::move(filter)), options(std::move(options)) {}
+
+/// Constructs a pattern that only matches operations named `opName`. The
+/// by-value `options` and `filter` parameters are moved into the members
+/// instead of being copied again.
+mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
+    StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
+    LinalgTransformationFilter filter, PatternBenefit benefit)
+    : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)),
+      options(std::move(options)) {}
+
+/// Pads `op` and, when `options.paddingHoistComputationFunction` is set, tries
+/// to hoist the pad tensor operations feeding the padded clone. Hoisting is
+/// best-effort: an operand whose hoisting fails keeps its unhoisted padding.
+LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  // Bail out unless `op` is a tensor-semantics LinalgOp accepted by the
+  // transformation filter.
+  auto linalgOp = dyn_cast<LinalgOp>(op);
+  if (!linalgOp || !linalgOp.hasTensorSemantics() ||
+      failed(filter.checkAndNotify(rewriter, op)))
+    return failure();
+
+  // Build the padded clone; `paddedResults` later replace the results of `op`.
+  LinalgOp paddedOp;
+  FailureOr<SmallVector<Value>> paddedResults = rewriteAsPaddedOp(
+      rewriter, linalgOp, options.paddingValueComputationFunction,
+      options.paddingNoFoldComputationFunction, paddedOp);
+  if (failed(paddedResults))
+    return failure();
+
+  // Query the requested hoisting depth of every input and output operand
+  // before any hoisting modifies the IR.
+  SmallVector<int64_t> hoistDepths;
+  const PaddingHoistComputationFunction &depthFn =
+      options.paddingHoistComputationFunction;
+  if (depthFn) {
+    for (OpOperand *operand : linalgOp.getInputAndOutputOperands())
+      hoistDepths.push_back(depthFn(*operand));
+  }
+
+  // Hoist the pad tensor producers of the padded clone's operands. Operands
+  // with a depth of zero or without a pad tensor producer are left alone.
+  for (size_t idx = 0, e = hoistDepths.size(); idx != e; ++idx) {
+    int64_t numLoops = hoistDepths[idx];
+    auto padOp =
+        paddedOp->getOpOperand(idx).get().getDefiningOp<PadTensorOp>();
+    if (numLoops == 0 || !padOp)
+      continue;
+    PadTensorOp hoistedPadOp;
+    FailureOr<Value> hoistedValue =
+        hoistPaddingOnTensors(padOp, numLoops, hoistedPadOp);
+    if (succeeded(hoistedValue))
+      rewriter.replaceOp(padOp, hoistedValue.getValue());
+  }
+
+  // Replace `op` with the padded results and update the transformation marker
+  // on the padded clone.
+  rewriter.replaceOp(op, paddedResults.getValue());
+  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
+  return success();
+}
+
 /// Linalg generic interchange pattern.
 mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
     MLIRContext *context, ArrayRef<unsigned> interchangeVector,

diff  --git a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
new file mode 100644
index 0000000000000..93e2bf5f189d2
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-paddings=1,1,0 hoist-paddings=2,1,0" -cse -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-paddings=1,1,0 hoist-paddings=4,3,0" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-DOUBLE
+
+// CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0) -> (5, -d0 + 24)>
+// CHECK-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<(d0) -> (8, -d0 + 12)>
+// CHECK-DAG: #[[DIV6:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 6)>
+#map0 = affine_map<(d0) -> (5, -d0 + 24)>
+#map1 = affine_map<(d0) -> (8, -d0 + 12)>
+
+// With hoist-paddings=2,1,0 (first RUN line), the padding of the first input
+// is hoisted out of the two innermost loops and packed for all values of IV2,
+// the padding of the second input is hoisted out of the innermost loop, and
+// the padding of the output operand stays in place.
+//      CHECK:  single_tiling
+//      CHECK-DOUBLE:  single_tiling
+
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// CHECK-SAME:    %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+func @single_tiling(%arg0: tensor<24x12xf32>,
+                    %arg1: tensor<12x25xf32>,
+                    %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  //  CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  //  CHECK-DAG: %[[C5:.*]] = arith.constant 5
+  //  CHECK-DAG: %[[C8:.*]] = arith.constant 8
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c25 = arith.constant 25 : index
+  %c24 = arith.constant 24 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %c5 = arith.constant 5 : index
+
+  //      CHECK:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %0 = scf.for %arg3 = %c0 to %c24 step %c5 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+    // Packing the first input operand for all values of IV2 (IV2x5x6).
+    //      CHECK:  = linalg.init_tensor [2, 5, 6]
+    //      CHECK:  %[[PT0:.*]] = scf.for %[[P0IV2:[0-9a-z]+]] =
+    //        CHECK:   %[[PIDX0:.*]] = affine.apply #[[DIV6]](%[[P0IV2]])
+    //        CHECK:   %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])
+    //        CHECK:   %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+    //   CHECK-SAME:                                     %[[IV0]], %[[P0IV2]]
+    //   CHECK-SAME:                                     %[[TS0]], 6
+    //        CHECK:   %[[V0:.*]] = arith.subi %[[C5]], %[[TS0]]
+    //        CHECK:   %[[T1:.*]] = linalg.pad_tensor %[[T0]] nofold {{.*}} high[%[[V0]]
+    //        CHECK:   %[[T2:.*]] = tensor.insert_slice %[[T1]] into %{{.*}}[%[[PIDX0]], 0, 0]
+    //        CHECK:   scf.yield %[[T2]]
+
+    //      CHECK:  scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %1 = scf.for %arg5 = %c0 to %c25 step %c7 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+
+      // Packing the second input operand for all values of IV2 (IV2x6x8).
+      //      CHECK:  = linalg.init_tensor [2, 6, 8]
+      //      CHECK:  %[[PT1:.*]] = scf.for %[[P1IV2:[0-9a-z]+]] =
+      //        CHECK:   %[[PIDX1:.*]] = affine.apply #[[DIV6]](%[[P1IV2]])
+      //        CHECK:   %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]])
+      //        CHECK:   %[[T3:.*]] = tensor.extract_slice %[[ARG1]]
+      //   CHECK-SAME:                                     %[[P1IV2]], %[[IV1]]
+      //   CHECK-SAME:                                     6, %[[TS1]]
+      //        CHECK:   %[[V1:.*]] = arith.subi %[[C8]], %[[TS1]]
+      //        CHECK:   %[[T4:.*]] = linalg.pad_tensor %[[T3]] nofold {{.*}} high[%[[C0]], %[[V1]]
+      //        CHECK:   %[[T5:.*]] = tensor.insert_slice %[[T4]] into %{{.*}}[%[[PIDX1]], 0, 0]
+      //        CHECK:   scf.yield %[[T5]]
+
+      //      CHECK:  scf.for %[[IV2:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG4:.*]] =
+      %2 = scf.for %arg7 = %c0 to %c12 step %c6 iter_args(%arg8 = %arg6) -> (tensor<24x25xf32>) {
+        %3 = affine.min #map0(%arg3)
+        // Index the packed operands.
+        //    CHECK-DAG:   %[[IDX:.*]] = affine.apply #[[DIV6]](%[[IV2]])
+        //    CHECK-DAG:   %[[T6:.*]] = tensor.extract_slice %[[PT0]][%[[IDX]]
+        //    CHECK-DAG:   %[[T7:.*]] = tensor.extract_slice %[[PT1]][%[[IDX]]
+        %4 = tensor.extract_slice %arg0[%arg3, %arg7] [%3, 6] [1, 1] : tensor<24x12xf32> to tensor<?x6xf32>
+        %5 = affine.min #map1(%arg5)
+        %6 = tensor.extract_slice %arg1[%arg7, %arg5] [6, %5] [1, 1] : tensor<12x25xf32> to tensor<6x?xf32>
+
+        // Pad the output operand without setting the nofold attribute.
+        //    CHECK-DAG:   %[[T8:.*]] = tensor.extract_slice %[[ARG4]][%[[IV0]], %[[IV1]]
+        //        CHECK:   %[[T9:.*]] = linalg.pad_tensor %[[T8]] low
+        %7 = tensor.extract_slice %arg8[%arg3, %arg5] [%3, %5] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
+
+        // Check matmul uses the packed input operands and the padded output operand.
+        //        CHECK:   = linalg.matmul ins(%[[T6]], %[[T7]]{{.*}} outs(%[[T9]]
+        %8 = linalg.matmul {__internal_linalg_transform__ = "pad"} ins(%4, %6 : tensor<?x6xf32>, tensor<6x?xf32>) outs(%7 : tensor<?x?xf32>) -> tensor<?x?xf32>
+        %9 = tensor.insert_slice %8 into %arg8[%arg3, %arg5] [%3, %5] [1, 1] : tensor<?x?xf32> into tensor<24x25xf32>
+        scf.yield %9 : tensor<24x25xf32>
+      }
+      scf.yield %2 : tensor<24x25xf32>
+    }
+    scf.yield %1 : tensor<24x25xf32>
+  }
+  return %0 : tensor<24x25xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (15, -d0 + 24)>
+#map1 = affine_map<(d0) -> (16, -d0 + 25)>
+#map2 = affine_map<(d0, d1) -> (5, -d0 + d1)>
+#map3 = affine_map<(d0, d1) -> (d0 + d1)>
+#map4 = affine_map<(d0, d1) -> (6, -d0 + d1)>
+
+// Two levels of tiling produce four nested loops. Under the first RUN line only
+// the function name is verified. Under the second RUN line
+// (hoist-paddings=4,3,0), the padding of the first input is hoisted out of all
+// four loops, the padding of the second input is hoisted out of the three
+// innermost loops (ending up inside the outermost loop), and the output
+// padding is not hoisted.
+//      CHECK:  double_tiling
+//      CHECK-DOUBLE:  double_tiling
+
+// CHECK-DOUBLE-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-DOUBLE-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// CHECK-DOUBLE-SAME:    %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+func @double_tiling(%arg0: tensor<24x12xf32>,
+                    %arg1: tensor<12x25xf32>,
+                    %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %c15 = arith.constant 15 : index
+  %c16 = arith.constant 16 : index
+  %c24 = arith.constant 24 : index
+  %c25 = arith.constant 25 : index
+  %c0 = arith.constant 0 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+
+  // Packing the first input operand.
+  //    CHECK-DOUBLE:  = linalg.init_tensor
+  //    CHECK-DOUBLE:  = linalg.pad_tensor {{.*}} nofold
+
+  //    CHECK-DOUBLE:  scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %0 = scf.for %arg3 = %c0 to %c24 step %c15 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+    // Packing the second input operand.
+    //    CHECK-DOUBLE:  = linalg.init_tensor
+    //    CHECK-DOUBLE:  = linalg.pad_tensor {{.*}} nofold
+
+    //    CHECK-DOUBLE:  scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %1 = scf.for %arg5 = %c0 to %c25 step %c16 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+      %2 = affine.min #map0(%arg3)
+      %3 = affine.min #map1(%arg5)
+      %4 = tensor.extract_slice %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
+
+      //    CHECK-DOUBLE:  scf.for %[[IV2:[0-9a-zA-Z]*]] =
+      %5 = scf.for %arg7 = %c0 to %2 step %c5 iter_args(%arg8 = %4) -> (tensor<?x?xf32>) {
+
+        //    CHECK-DOUBLE:  scf.for %[[IV3:[0-9a-zA-Z]*]] =
+        %7 = scf.for %arg9 = %c0 to %3 step %c6 iter_args(%arg10 = %arg8) -> (tensor<?x?xf32>) {
+          %8 = affine.min #map2(%arg7, %2)
+          %9 = affine.apply #map3(%arg7, %arg3)
+          %10 = tensor.extract_slice %arg0[%9, 0] [%8, 12] [1, 1] : tensor<24x12xf32> to tensor<?x12xf32>
+          %11 = affine.min #map4(%arg9, %3)
+          %12 = affine.apply #map3(%arg9, %arg5)
+          %13 = tensor.extract_slice %arg1[0, %12] [12, %11] [1, 1] : tensor<12x25xf32> to tensor<12x?xf32>
+          %14 = affine.min #map2(%arg7, %2)
+          %15 = affine.min #map4(%arg9, %3)
+          %16 = tensor.extract_slice %arg10[%arg7, %arg9] [%14, %15] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+
+          // Pad the output operand and perform the multiplication.
+          //        CHECK-DOUBLE:   = linalg.pad_tensor
+          //        CHECK-DOUBLE:   = linalg.matmul
+          %17 = linalg.matmul {__internal_linalg_transform__ = "pad"} ins(%10, %13 : tensor<?x12xf32>, tensor<12x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
+          %18 = tensor.insert_slice %17 into %arg10[%arg7, %arg9] [%14, %15] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+          scf.yield %18 : tensor<?x?xf32>
+        }
+        scf.yield %7 : tensor<?x?xf32>
+      }
+      %6 = tensor.insert_slice %5 into %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor<?x?xf32> into tensor<24x25xf32>
+      scf.yield %6 : tensor<24x25xf32>
+    }
+    scf.yield %1 : tensor<24x25xf32>
+  }
+  return %0 : tensor<24x25xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 78e76ca9ea311..25525fb851d24 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -96,6 +96,9 @@ struct TestLinalgTransforms
   Option<int> testHoistPadding{*this, "test-hoist-padding",
                                llvm::cl::desc("Test hoist padding"),
                                llvm::cl::init(0)};
+  Option<bool> testPadPattern{*this, "test-pad-pattern",
+                              llvm::cl::desc("Test pad pattern"),
+                              llvm::cl::init(false)};
   Option<bool> testTransformPadTensor{
       *this, "test-transform-pad-tensor",
       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
@@ -117,6 +120,14 @@ struct TestLinalgTransforms
       *this, "nofold-operands",
       llvm::cl::desc("Operands to set nofold when test-tile-pattern"),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+  ListOption<int64_t> packPaddings{
+      *this, "pack-paddings",
+      llvm::cl::desc("Operand packing flags when test-pad-pattern"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+  ListOption<int64_t> hoistPaddings{
+      *this, "hoist-paddings",
+      llvm::cl::desc("Operand hoisting depths when test-pad-pattern"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   ListOption<int64_t> peeledLoops{
       *this, "peeled-loops",
       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
@@ -637,6 +648,30 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType,
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
 }
 
+/// Registers `LinalgPaddingPattern` for operations tagged "pad" and applies it
+/// greedily to `funcOp`. For operand number `i`, `packPaddings[i]` (nonzero
+/// means nofold/pack) and `hoistPaddings[i]` (number of loops to hoist) select
+/// the behavior. Operands beyond the end of a list are neither packed nor
+/// hoisted.
+static void applyPadPattern(FuncOp funcOp, ArrayRef<int64_t> packPaddings,
+                            ArrayRef<int64_t> hoistPaddings) {
+  MLIRContext *context = funcOp.getContext();
+  RewritePatternSet padPattern(context);
+  auto linalgPaddingOptions = linalg::LinalgPaddingOptions();
+  // The callbacks are stored in std::function members and copied into the
+  // pattern, so capture the cheap ArrayRefs by value rather than capturing the
+  // parameters by reference. The referenced data must stay alive while the
+  // pattern runs, which holds because it is applied before returning.
+  auto packFunc = [packPaddings](OpOperand &opOperand) -> bool {
+    unsigned operandNumber = opOperand.getOperandNumber();
+    return operandNumber < packPaddings.size() &&
+           packPaddings[operandNumber] != 0;
+  };
+  auto hoistingFunc = [hoistPaddings](OpOperand &opOperand) -> int64_t {
+    unsigned operandNumber = opOperand.getOperandNumber();
+    return operandNumber < hoistPaddings.size() ? hoistPaddings[operandNumber]
+                                                : 0;
+  };
+  linalgPaddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
+  linalgPaddingOptions.setPaddingNoFoldComputationFunction(packFunc);
+  linalgPaddingOptions.setPaddingHoistComputationFunction(hoistingFunc);
+  padPattern.add<LinalgPaddingPattern>(
+      context, linalgPaddingOptions,
+      LinalgTransformationFilter(Identifier::get("pad", context)));
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(padPattern));
+}
+
 static void applyInterchangePattern(FuncOp funcOp,
                                     ArrayRef<unsigned> interchangeVector) {
   MLIRContext *context = funcOp.getContext();
@@ -780,6 +815,8 @@ void TestLinalgTransforms::runOnFunction() {
       }
     });
   }
+  if (testPadPattern)
+    return applyPadPattern(getFunction(), packPaddings, hoistPaddings);
   if (testInterchangePattern.hasValue())
     return applyInterchangePattern(getFunction(), testInterchangePattern);
 }


        


More information about the Mlir-commits mailing list