[Mlir-commits] [mlir] d0ec4a8 - [mlir][linalg] Add pad and hoist test pass.
Tobias Gysi
llvmlistbot at llvm.org
Fri Oct 29 08:09:27 PDT 2021
Author: Tobias Gysi
Date: 2021-10-29T15:08:16Z
New Revision: d0ec4a8ed9a39b4e3a35e0826b63ac6c5bc21da1
URL: https://github.com/llvm/llvm-project/commit/d0ec4a8ed9a39b4e3a35e0826b63ac6c5bc21da1
DIFF: https://github.com/llvm/llvm-project/commit/d0ec4a8ed9a39b4e3a35e0826b63ac6c5bc21da1.diff
LOG: [mlir][linalg] Add pad and hoist test pass.
Add a padding and hoisting pattern, a test pass, and tests. The patch prepares the split of tiling/fusion and padding.
Depends On D112255
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D112412
Added:
mlir/test/Dialect/Linalg/pad-and-hoist.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
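
In brief: the new LinalgPaddingPattern pads the operands of a linalg op and then hoists the resulting pad_tensor ops out of enclosing loops, all driven by a new LinalgPaddingOptions struct. A minimal sketch of wiring the pattern into a pass, mirroring the applyPadPattern helper added below (getNeutralOfLinalgOp is the padding-value helper the test pass uses; the "pad" marker string is illustrative):

    // Sketch only: mirrors applyPadPattern in TestLinalgTransforms.cpp.
    static void applyPadAndHoist(FuncOp funcOp) {
      MLIRContext *context = funcOp.getContext();
      linalg::LinalgPaddingOptions options;
      // Padding is only introduced when this callback is set.
      options.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
      RewritePatternSet patterns(context);
      patterns.add<linalg::LinalgPaddingPattern>(
          context, options,
          linalg::LinalgTransformationFilter(Identifier::get("pad", context)));
      (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
    }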
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6f5ed28d9bf78..bbaabac10cdb0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -460,6 +460,47 @@ using PaddingValueComputationFunction =
/// OpOperand shall be marked as nofold to enable packing.
using PaddingNoFoldComputationFunction = std::function<bool(OpOperand &)>;
+/// Callback returning the number of loops to hoist the pad tensor operation
+/// defining the given OpOperand.
+using PaddingHoistComputationFunction = std::function<int64_t(OpOperand &)>;
+
+struct LinalgPaddingOptions {
+ /// Callback returning the padding value to use for a given OpOperand or
+ /// failure for no padding. Padding operations are introduced if
+ /// `paddingValueComputationFunction` is set and does not return failure.
+ /// Padding all operands guarantees the operation is statically shaped and
+ /// thus can be vectorized.
+ PaddingValueComputationFunction paddingValueComputationFunction = nullptr;
+
+ LinalgPaddingOptions &
+ setPaddingValueComputationFunction(PaddingValueComputationFunction fun) {
+ paddingValueComputationFunction = std::move(fun);
+ return *this;
+ }
+
+ /// Callback returning true if the pad tensor operation defining the given
+ /// OpOperand shall be marked as nofold to enable packing. A padding operation
+ /// is only marked nofold if `paddingNoFoldComputationFunction` is set and
+ /// returns true. Otherwise, the nofold attribute is set to false.
+ PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr;
+
+ LinalgPaddingOptions &
+ setPaddingNoFoldComputationFunction(PaddingNoFoldComputationFunction fun) {
+ paddingNoFoldComputationFunction = std::move(fun);
+ return *this;
+ }
+
+ /// Callback returning the number of loops to hoist the pad tensor operation
+ /// defining the given OpOperand.
+ PaddingHoistComputationFunction paddingHoistComputationFunction = nullptr;
+
+ LinalgPaddingOptions &
+ setPaddingHoistComputationFunction(PaddingHoistComputationFunction fun) {
+ paddingHoistComputationFunction = std::move(fun);
+ return *this;
+ }
+};
+
struct LinalgTilingOptions {
/// Computation function that returns the tile sizes for each operation.
/// Delayed construction of constant tile sizes should occur to interoperate
@@ -650,6 +691,35 @@ struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
}
};
+///
+/// Linalg padding pattern.
+///
+/// Apply the `padding` transformation as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `padding` for more details.
+struct LinalgPaddingPattern : public RewritePattern {
+ // Entry point to match any LinalgOp OpInterface.
+ LinalgPaddingPattern(
+ MLIRContext *context,
+ LinalgPaddingOptions options = LinalgPaddingOptions(),
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+ // Entry point to match a specific LinalgOp.
+ LinalgPaddingPattern(
+ StringRef opName, MLIRContext *context,
+ LinalgPaddingOptions options = LinalgPaddingOptions(),
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1);
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override;
+
+private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ LinalgTransformationFilter filter;
+ /// Options to control padding and hoisting.
+ LinalgPaddingOptions options;
+};
+
struct LinalgFusionOptions {
/// List of operands indices to use for fusion.
llvm::SmallSet<unsigned, 1> indicesToFuse = {};
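
All three LinalgPaddingOptions callbacks above are queried once per OpOperand, and each setter returns *this, so configurations compose fluently. A sketch for a matmul, where operands 0 and 1 are the inputs and operand 2 is the output (the nofold flags and hoisting depths below are illustrative, chosen to match the test's pack-paddings=1,1,0 hoist-paddings=2,1,0 configuration):

    // Sketch only: per-operand padding configuration for a matmul.
    linalg::LinalgPaddingOptions options =
        linalg::LinalgPaddingOptions()
            .setPaddingValueComputationFunction(getNeutralOfLinalgOp)
            // nofold the inputs so their pad_tensor ops survive folding
            // and can be packed; fold away the output padding.
            .setPaddingNoFoldComputationFunction([](OpOperand &opOperand) {
              return opOperand.getOperandNumber() < 2;
            })
            // Hoist the first input's padding out of two loops and the
            // second input's out of one; leave the output padding alone.
            .setPaddingHoistComputationFunction(
                [](OpOperand &opOperand) -> int64_t {
                  int64_t depths[3] = {2, 1, 0};
                  unsigned idx = opOperand.getOperandNumber();
                  return idx < 3 ? depths[idx] : 0;
                });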
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 18c3620783f36..f81ce919a4faf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -470,6 +471,64 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
return success();
}
+/// Linalg padding pattern.
+mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
+ MLIRContext *context, LinalgPaddingOptions options,
+ LinalgTransformationFilter filter, PatternBenefit benefit)
+ : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
+ options(options) {}
+
+mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
+ StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
+ LinalgTransformationFilter filter, PatternBenefit benefit)
+ : RewritePattern(opName, benefit, context, {}), filter(filter),
+ options(options) {}
+
+LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
+ Operation *op, PatternRewriter &rewriter) const {
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+ if (!linalgOp)
+ return failure();
+ if (!linalgOp.hasTensorSemantics())
+ return failure();
+ if (failed(filter.checkAndNotify(rewriter, op)))
+ return failure();
+
+ // Pad the operation.
+ LinalgOp paddedOp;
+ FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
+ rewriter, linalgOp, options.paddingValueComputationFunction,
+ options.paddingNoFoldComputationFunction, paddedOp);
+ if (failed(newResults))
+ return failure();
+
+ // Compute the desired hoisting depths.
+ SmallVector<int64_t> depths;
+ if (options.paddingHoistComputationFunction) {
+ for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
+ depths.push_back(options.paddingHoistComputationFunction(*opOperand));
+ }
+
+ // Hoist the padding.
+ for (auto en : enumerate(depths)) {
+ OpOperand &opOperand = paddedOp->getOpOperand(en.index());
+ auto padTensorOp = opOperand.get().getDefiningOp<PadTensorOp>();
+ if (!padTensorOp || en.value() == 0)
+ continue;
+ PadTensorOp hoistedOp;
+ FailureOr<Value> newResult =
+ hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp);
+ if (failed(newResult))
+ continue;
+ rewriter.replaceOp(padTensorOp, newResult.getValue());
+ }
+
+ // Replace the original operation to pad.
+ rewriter.replaceOp(op, newResults.getValue());
+ filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
+ return success();
+}
+
/// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
MLIRContext *context, ArrayRef<unsigned> interchangeVector,
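
The rewrite above composes two existing entry points: rewriteAsPaddedOp introduces the pad_tensor ops, and hoistPaddingOnTensors (from HoistPadding.h) moves a single pad_tensor op out of a given number of enclosing loops. A sketch of the hoisting step in isolation, assuming a PadTensorOp handle padOp and a rewriter are in scope (the depth of 2 is illustrative):

    // Sketch only: hoist one pad_tensor op out of two enclosing loops.
    int64_t numLoops = 2;
    PadTensorOp hoistedOp;
    FailureOr<Value> newResult =
        hoistPaddingOnTensors(padOp, numLoops, hoistedOp);
    if (succeeded(newResult))
      rewriter.replaceOp(padOp, newResult.getValue());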
diff --git a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
new file mode 100644
index 0000000000000..93e2bf5f189d2
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-paddings=1,1,0 hoist-paddings=2,1,0" -cse -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-paddings=1,1,0 hoist-paddings=4,3,0" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-DOUBLE
+
+// CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0) -> (5, -d0 + 24)>
+// CHECK-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<(d0) -> (8, -d0 + 12)>
+// CHECK-DAG: #[[DIV6:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 6)>
+#map0 = affine_map<(d0) -> (5, -d0 + 24)>
+#map1 = affine_map<(d0) -> (8, -d0 + 12)>
+#map2 = affine_map<(d0) -> (7, -d0 + 25)>
+
+// CHECK: single_tiling
+// CHECK-DOUBLE: single_tiling
+
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+func @single_tiling(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C5:.*]] = arith.constant 5
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8
+ %c0 = arith.constant 0 : index
+ %c12 = arith.constant 12 : index
+ %c25 = arith.constant 25 : index
+ %c24 = arith.constant 24 : index
+ %c6 = arith.constant 6 : index
+ %c7 = arith.constant 7 : index
+ %c5 = arith.constant 5 : index
+
+ // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ %0 = scf.for %arg3 = %c0 to %c24 step %c5 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+ // Packing the first input operand for all values of IV2 (IV2x5x6).
+ // CHECK: = linalg.init_tensor [2, 5, 6]
+ // CHECK: %[[PT0:.*]] = scf.for %[[P0IV2:[0-9a-z]+]] =
+ // CHECK: %[[PIDX0:.*]] = affine.apply #[[DIV6]](%[[P0IV2]])
+ // CHECK: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+ // CHECK-SAME: %[[IV0]], %[[P0IV2]]
+ // CHECK-SAME: %[[TS0]], 6
+ // CHECK: %[[V0:.*]] = arith.subi %[[C5]], %[[TS0]]
+ // CHECK: %[[T1:.*]] = linalg.pad_tensor %[[T0]] nofold {{.*}} high[%[[V0]]
+ // CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1:.*]] into %{{.*}}[%[[PIDX0]], 0, 0]
+ // CHECK: scf.yield %[[T2:.*]]
+
+ // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ %1 = scf.for %arg5 = %c0 to %c25 step %c7 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+
+ // Packing the second input operand for all values of IV2 (IV2x6x8).
+ // CHECK: = linalg.init_tensor [2, 6, 8]
+ // CHECK: %[[PT1:.*]] = scf.for %[[P1IV2:[0-9a-z]+]] =
+ // CHECK: %[[PIDX1:.*]] = affine.apply #[[DIV6]](%[[P1IV2]])
+ // CHECK: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]])
+ // CHECK: %[[T3:.*]] = tensor.extract_slice %[[ARG1]]
+ // CHECK-SAME: %[[P1IV2]], %[[IV1]]
+ // CHECK-SAME: 6, %[[TS1]]
+ // CHECK: %[[V1:.*]] = arith.subi %[[C8]], %[[TS1]]
+ // CHECK: %[[T4:.*]] = linalg.pad_tensor %[[T3]] nofold {{.*}} high[%[[C0]], %[[V1]]
+ // CHECK: %[[T5:.*]] = tensor.insert_slice %[[T4:.*]] into %{{.*}}[%[[PIDX1]], 0, 0]
+ // CHECK: scf.yield %[[T5:.*]]
+
+ // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG4:.*]] =
+ %2 = scf.for %arg7 = %c0 to %c12 step %c6 iter_args(%arg8 = %arg6) -> (tensor<24x25xf32>) {
+ %3 = affine.min #map0(%arg3)
+ // Index the packed operands.
+ // CHECK-DAG: %[[IDX:.*]] = affine.apply #[[DIV6]](%[[IV2]])
+ // CHECK-DAG: %[[T6:.*]] = tensor.extract_slice %[[PT0]][%[[IDX]]
+ // CHECK-DAG: %[[T7:.*]] = tensor.extract_slice %[[PT1]][%[[IDX]]
+ %4 = tensor.extract_slice %arg0[%arg3, %arg7] [%3, 6] [1, 1] : tensor<24x12xf32> to tensor<?x6xf32>
+ %5 = affine.min #map1(%arg5)
+ %6 = tensor.extract_slice %arg1[%arg7, %arg5] [6, %5] [1, 1] : tensor<12x25xf32> to tensor<6x?xf32>
+
+ // Pad the output operand without setting the nofold attribute.
+ // CHECK-DAG: %[[T8:.*]] = tensor.extract_slice %[[ARG4]][%[[IV0]], %[[IV1]]
+ // CHECK: %[[T9:.*]] = linalg.pad_tensor %[[T8]] low
+ %7 = tensor.extract_slice %arg8[%arg3, %arg5] [%3, %5] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
+
+ // Check matmul uses the packed input operands and the padded output operand.
+ // CHECK: = linalg.matmul ins(%[[T6]], %[[T7]]{{.*}} outs(%[[T9]]
+ %8 = linalg.matmul {__internal_linalg_transform__ = "pad"} ins(%4, %6 : tensor<?x6xf32>, tensor<6x?xf32>) outs(%7 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %9 = tensor.insert_slice %8 into %arg8[%arg3, %arg5] [%3, %5] [1, 1] : tensor<?x?xf32> into tensor<24x25xf32>
+ scf.yield %9 : tensor<24x25xf32>
+ }
+ scf.yield %2 : tensor<24x25xf32>
+ }
+ scf.yield %1 : tensor<24x25xf32>
+ }
+ return %0 : tensor<24x25xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (15, -d0 + 24)>
+#map1 = affine_map<(d0) -> (16, -d0 + 25)>
+#map2 = affine_map<(d0, d1) -> (5, -d0 + d1)>
+#map3 = affine_map<(d0, d1) -> (d0 + d1)>
+#map4 = affine_map<(d0, d1) -> (6, -d0 + d1)>
+
+// CHECK: double_tiling
+// CHECK-DOUBLE: double_tiling
+
+// CHECK-DOUBLE-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-DOUBLE-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// CHECK-DOUBLE-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+func @double_tiling(%arg0: tensor<24x12xf32>,
+ %arg1: tensor<12x25xf32>,
+ %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ %c15 = arith.constant 15 : index
+ %c16 = arith.constant 16 : index
+ %c24 = arith.constant 24 : index
+ %c25 = arith.constant 25 : index
+ %c0 = arith.constant 0 : index
+ %c5 = arith.constant 5 : index
+ %c6 = arith.constant 6 : index
+
+ // Packing the first input operand.
+ // CHECK-DOUBLE: = linalg.init_tensor
+ // CHECK-DOUBLE: = linalg.pad_tensor {{.*}} nofold
+
+ // CHECK-DOUBLE: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+ %0 = scf.for %arg3 = %c0 to %c24 step %c15 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+ // Packing the second input operand.
+ // CHECK-DOUBLE: = linalg.init_tensor
+ // CHECK-DOUBLE: = linalg.pad_tensor {{.*}} nofold
+
+ // CHECK-DOUBLE: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+ %1 = scf.for %arg5 = %c0 to %c25 step %c16 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+ %2 = affine.min #map0(%arg3)
+ %3 = affine.min #map1(%arg5)
+ %4 = tensor.extract_slice %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
+
+ // CHECK-DOUBLE: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+ %5 = scf.for %arg7 = %c0 to %2 step %c5 iter_args(%arg8 = %4) -> (tensor<?x?xf32>) {
+
+ // CHECK-DOUBLE: scf.for %[[IV3:[0-9a-zA-Z]*]] =
+ %7 = scf.for %arg9 = %c0 to %3 step %c6 iter_args(%arg10 = %arg8) -> (tensor<?x?xf32>) {
+ %8 = affine.min #map2(%arg7, %2)
+ %9 = affine.apply #map3(%arg7, %arg3)
+ %10 = tensor.extract_slice %arg0[%9, 0] [%8, 12] [1, 1] : tensor<24x12xf32> to tensor<?x12xf32>
+ %11 = affine.min #map4(%arg9, %3)
+ %12 = affine.apply #map3(%arg9, %arg5)
+ %13 = tensor.extract_slice %arg1[0, %12] [12, %11] [1, 1] : tensor<12x25xf32> to tensor<12x?xf32>
+ %14 = affine.min #map2(%arg7, %2)
+ %15 = affine.min #map4(%arg9, %3)
+ %16 = tensor.extract_slice %arg10[%arg7, %arg9] [%14, %15] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+
+ // Pad the output operand and perform the multiplication.
+ // CHECK-DOUBLE: = linalg.pad_tensor
+ // CHECK-DOUBLE: = linalg.matmul
+ %17 = linalg.matmul {__internal_linalg_transform__ = "pad"} ins(%10, %13 : tensor<?x12xf32>, tensor<12x?xf32>) outs(%16 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %18 = tensor.insert_slice %17 into %arg10[%arg7, %arg9] [%14, %15] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+ scf.yield %18 : tensor<?x?xf32>
+ }
+ scf.yield %7 : tensor<?x?xf32>
+ }
+ %6 = tensor.insert_slice %5 into %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor<?x?xf32> into tensor<24x25xf32>
+ scf.yield %6 : tensor<24x25xf32>
+ }
+ scf.yield %1 : tensor<24x25xf32>
+ }
+ return %0 : tensor<24x25xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 78e76ca9ea311..25525fb851d24 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -96,6 +96,9 @@ struct TestLinalgTransforms
Option<int> testHoistPadding{*this, "test-hoist-padding",
llvm::cl::desc("Test hoist padding"),
llvm::cl::init(0)};
+ Option<bool> testPadPattern{*this, "test-pad-pattern",
+ llvm::cl::desc("Test pad pattern"),
+ llvm::cl::init(false)};
Option<bool> testTransformPadTensor{
*this, "test-transform-pad-tensor",
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
@@ -117,6 +120,14 @@ struct TestLinalgTransforms
*this, "nofold-operands",
llvm::cl::desc("Operands to set nofold when test-tile-pattern"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+ ListOption<int64_t> packPaddings{
+ *this, "pack-paddings",
+ llvm::cl::desc("Operand packing flags when test-pad-pattern"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+ ListOption<int64_t> hoistPaddings{
+ *this, "hoist-paddings",
+ llvm::cl::desc("Operand hoisting depths when test-pad-pattern"),
+ llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
ListOption<int64_t> peeledLoops{
*this, "peeled-loops",
llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
@@ -637,6 +648,30 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType,
(void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}
+static void applyPadPattern(FuncOp funcOp, ArrayRef<int64_t> packPaddings,
+ ArrayRef<int64_t> hoistPaddings) {
+ MLIRContext *context = funcOp.getContext();
+ RewritePatternSet padPattern(context);
+ auto linalgPaddingOptions = linalg::LinalgPaddingOptions();
+ auto packFunc = [&](OpOperand &opOperand) {
+ return opOperand.getOperandNumber() < packPaddings.size()
+ ? packPaddings[opOperand.getOperandNumber()]
+ : false;
+ };
+ auto hoistingFunc = [&](OpOperand &opOperand) {
+ return opOperand.getOperandNumber() < hoistPaddings.size()
+ ? hoistPaddings[opOperand.getOperandNumber()]
+ : 0;
+ };
+ linalgPaddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
+ linalgPaddingOptions.setPaddingNoFoldComputationFunction(packFunc);
+ linalgPaddingOptions.setPaddingHoistComputationFunction(hoistingFunc);
+ padPattern.add<LinalgPaddingPattern>(
+ context, linalgPaddingOptions,
+ LinalgTransformationFilter(Identifier::get("pad", context)));
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(padPattern));
+}
+
static void applyInterchangePattern(FuncOp funcOp,
ArrayRef<unsigned> interchangeVector) {
MLIRContext *context = funcOp.getContext();
@@ -780,6 +815,8 @@ void TestLinalgTransforms::runOnFunction() {
}
});
}
+ if (testPadPattern)
+ return applyPadPattern(getFunction(), packPaddings, hoistPaddings);
if (testInterchangePattern.hasValue())
return applyInterchangePattern(getFunction(), testInterchangePattern);
}