[llvm-branch-commits] [mlir] 3747eb9 - [mlir][Linalg] Add a padding option to Linalg tiling
Nicolas Vasilache via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Jan 25 01:26:04 PST 2021
Author: Nicolas Vasilache
Date: 2021-01-25T09:17:30Z
New Revision: 3747eb9c85b3393aa00ad12e9e7ef31ffec8bd4c
URL: https://github.com/llvm/llvm-project/commit/3747eb9c85b3393aa00ad12e9e7ef31ffec8bd4c
DIFF: https://github.com/llvm/llvm-project/commit/3747eb9c85b3393aa00ad12e9e7ef31ffec8bd4c.diff
LOG: [mlir][Linalg] Add a padding option to Linalg tiling
This revision allows the base Linalg tiling pattern to optionally pad operands to
a constant bounding shape.
When requested, a simple analysis similar to buffer promotion is performed.
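Clients opt in by providing a padding-value callback on the tiling options. A
minimal sketch, mirroring the test pattern added in this revision
(`getNeutralOfLinalgOp` is the test helper defined below, which returns the zero
of the element type):

  auto options = linalg::LinalgTilingOptions()
                     .setTileSizes({2, 3, 4})
                     .setPaddingValueComputationFunction(getNeutralOfLinalgOp);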
A temporary `linalg.simple_pad` op is added to model padding and connect the
pieces. It will be replaced by a more fully fledged `linalg.pad_tensor` op when
one becomes available.
In the meantime, this temporary op exhibits the properties required of a proper
pad op in order to compose correctly with transformations.
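For illustration, a dynamically shaped tile is padded to its static bounding
shape roughly as follows (a sketch using the temporary op's syntax; the value
names are illustrative, see the roundtrip test added below):

  %static = linalg.simple_pad %dynamic pad %cst
      : tensor<?x4x?xf32> to tensor<8x4x8xf32> pad f32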
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D95149
Added:
mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index ae9f81d043f5..9ea1bc5a3587 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -475,6 +475,38 @@ def Linalg_SliceOp : Linalg_Op<"slice", [
let hasFolder = 1;
}
+def Linalg_SimplePadOp : Linalg_Op<"simple_pad", [NoSideEffect]> {
+ let summary = "TODO: replace with pad_tensors when ready.";
+
+ let description = [{
+ `linalg.simple_pad` is a temporary placeholder for padding and packing on tensors.
+ Its semantics are to pad a partially dynamic tensor to a fully static tensor
+ where the static sizes are assumed to be greater than the dynamic sizes. The
+ op performs "high" padding (i.e. it adds trailing padding values until the
+ desired size is met).
+ }];
+
+ let arguments = (ins AnyRankedTensor:$tensor, AnyType:$padding);
+ let results = (outs AnyRankedTensor:$result);
+
+ // TODO: verify all-static result, some dynamic input, static shapes match,
+ // element types match, ranks match, etc. Use pad_tensors when ready but for
+ // now just let it be fully specified by traits.
+ let verifier = ?;
+
+ let extraClassDeclaration = [{
+ RankedTensorType getSourceType() {
+ return tensor().getType().cast<RankedTensorType>(); }
+ RankedTensorType getResultType() {
+ return getResult().getType().cast<RankedTensorType>(); }
+ }];
+
+ let assemblyFormat = [{
+ $tensor `pad` $padding attr-dict `:`
+ type($tensor) `to` type($result) `pad` type($padding)
+ }];
+}
+
def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
Arguments<(ins Variadic<AnyType>:$values)> {
let summary = "Linalg yield operation";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 611ab6867372..f359992e5ff1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -345,6 +345,9 @@ enum class LinalgTilingLoopType {
using TileSizeComputationFunction =
std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
+using PaddingValueComputationFunction =
+ std::function<Value(OpBuilder &, Operation *)>;
+
struct LinalgTilingOptions {
/// Computation function that returns the tile sizes for each operation.
/// Delayed construction of constant tile sizes should occur to interoperate
@@ -393,6 +396,18 @@ struct LinalgTilingOptions {
distribution = std::move(distributionOptions);
return *this;
}
+
+ /// Computation function that returns a padding value to use when padding to
+ /// force static sizes. When `paddingValueComputationFunction` is set, padding
+ /// operations are introduced that guarantee the underlying op is statically
+ /// shaped and can thus be vectorized.
+ PaddingValueComputationFunction paddingValueComputationFunction = nullptr;
+
+ LinalgTilingOptions &
+ setPaddingValueComputationFunction(PaddingValueComputationFunction fun) {
+ paddingValueComputationFunction = std::move(fun);
+ return *this;
+ }
};
/// Canonicalization patterns relevant to apply after tiling patterns. These are
@@ -403,6 +418,11 @@ getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx);
+/// Base pattern that applies the tiling transformation specified by `options`.
+/// Abort and return failure in two cases:
+/// 1. if the tiling specification is invalid and tiling fails to occur.
+/// 2. if tiling occurs but `options.paddingValueComputationFunction` is set
+/// and some operand shape cannot be bounded statically.
struct LinalgBaseTilingPattern : public RewritePattern {
// Entry point to match any LinalgOp OpInterface.
LinalgBaseTilingPattern(LinalgTilingOptions options,
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index d5f44c3e63da..2b3a054338ab 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -14,6 +14,7 @@
#define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
index 62c371b2f97d..6c72b47f2ac3 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -108,6 +108,28 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
return $_op.sizes();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return a vector of all the static or dynamic sizes of the op.
+ }],
+ /*retTy=*/"SmallVector<OpFoldResult, 4>",
+ /*methodName=*/"getMixedSizes",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ SmallVector<OpFoldResult, 4> res;
+ std::array<unsigned, 3> ranks = $_op.getArrayAttrRanks();
+ unsigned numDynamic = 0;
+ unsigned count = ranks[getOffsetOperandGroupPosition()];
+ for (unsigned idx = 0; idx < count; ++idx) {
+ if (isDynamicSize(idx))
+ res.push_back($_op.sizes()[numDynamic++]);
+ else
+ res.push_back($_op.static_sizes()[idx]);
+ }
+ return res;
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the dynamic stride operands.
@@ -359,6 +381,9 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
];
let extraClassDeclaration = [{
+ static unsigned getOffsetOperandGroupPosition() { return 0; }
+ static unsigned getSizeOperandGroupPosition() { return 1; }
+ static unsigned getStrideOperandGroupPosition() { return 2; }
static StringRef getStaticOffsetsAttrName() {
return "static_offsets";
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 283ff20f611b..a76d70c8cd5f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -25,6 +25,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
@@ -105,6 +106,118 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
return *this;
}
+/// Try to compute a static bounding box for `operand`.
+/// Return success if either:
+/// 1. The operand is already statically shaped; `result` is left unchanged.
+/// 2. The operand is (partially) dynamic; `result` is the result of a freshly
+/// created SimplePadOp.
+/// Return failure if the operand cannot be padded to a static shape.
+static LogicalResult padOperandToSmallestStaticBoundingBox(
+ PatternRewriter &rewriter, linalg::LinalgOp opToPad, Value operand,
+ const LinalgTilingOptions &options, Value &result) {
+ auto tensorType = operand.getType().cast<RankedTensorType>();
+ // Already static shape, no need to pad.
+ if (tensorType.hasStaticShape())
+ return success();
+ auto subtensor = operand.getDefiningOp<SubTensorOp>();
+ // Not a subtensor, cannot construct a static bounding box.
+ if (!subtensor)
+ return failure();
+ SmallVector<int64_t> staticSizes;
+ staticSizes.reserve(tensorType.getRank());
+ auto shapedOp =
+ cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation());
+ for (auto size : shapedOp.getMixedSizes()) {
+ auto indexAttr = size.is<Attribute>()
+ ? size.get<Attribute>().dyn_cast<IntegerAttr>()
+ : linalg::getSmallestBoundingIndex(size.get<Value>());
+ // SmallestBoundingIndex must exist for all sizes.
+ // For now return an error if we can't find it.
+ if (!indexAttr)
+ return rewriter.notifyMatchFailure(
+ opToPad, "No constant bounding box can be found for padding");
+ staticSizes.push_back(indexAttr.getInt());
+ }
+ Value pad = options.paddingValueComputationFunction(rewriter, opToPad);
+ auto staticTensorType =
+ RankedTensorType::get(staticSizes, tensorType.getElementType());
+ result = rewriter.create<linalg::SimplePadOp>(opToPad->getLoc(),
+ staticTensorType, operand, pad);
+ return success();
+}
+
+// Try to create a static bounding box around each operand of `res.op`.
+// If successful, `res.op` is rewritten in static form with padded operands.
+// `res.op` is updated to the cloned static form of the op on success.
+static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
+ TiledLinalgOp &res,
+ const LinalgTilingOptions &options) {
+ LinalgOp opToPad = res.op;
+ Location loc = opToPad->getLoc();
+
+ // If the op is fully static, it does not need padding.
+ // TODO: there are cases where we may still want to pad to larger sizes.
+ if (llvm::all_of(opToPad.getShapedOperands(), [](Value v) {
+ return v.getType().cast<RankedTensorType>().hasStaticShape();
+ }))
+ return success();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ // Set IP after op because we also take the dims of the original output.
+ rewriter.setInsertionPointAfter(opToPad);
+ // Make a copy of the shaped operands and update it.
+ SmallVector<Value> operands = opToPad.getShapedOperands();
+ for (Value &v : operands) {
+ Value paddedOperand;
+ // If padding was requested but the shape cannot be bounded statically then
+ // the pattern fails to apply.
+ if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, v,
+ options, paddedOperand))) {
+ return failure();
+ }
+ // Update v if we indeed got a padded operand.
+ v = paddedOperand ? paddedOperand : v;
+ }
+
+ // Clone `opToPad` to operate on the statically padded shapes.
+ auto resultTensorTypes =
+ ValueRange(operands).take_back(opToPad.getNumOutputs()).getTypes();
+ ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
+ operands.append(otherOperands.begin(), otherOperands.end());
+ linalg::LinalgOp paddedOp =
+ opToPad.clone(rewriter, loc, resultTensorTypes, operands);
+
+ // Recover the subtensor out of the new static results. This keeps the
+ // original linalg op around because it uses the dims of the original results.
+ // This later folds away.
+ SmallVector<Value> paddedSubviewResults;
+ paddedSubviewResults.reserve(opToPad->getNumResults());
+ Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ llvm::SetVector<Operation *> newUsersOfOpToPad;
+ for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
+ auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
+ SmallVector<Value> offsets(rank, zero);
+ auto sizes = llvm::to_vector<4>(
+ llvm::map_range(llvm::seq<unsigned>(0, rank), [&](unsigned d) -> Value {
+ auto dimOp = rewriter.create<DimOp>(loc, std::get<0>(it), d);
+ newUsersOfOpToPad.insert(dimOp);
+ return dimOp;
+ }));
+ SmallVector<Value> strides(rank, one);
+ paddedSubviewResults.push_back(rewriter.create<SubTensorOp>(
+ loc, std::get<1>(it), offsets, sizes, strides));
+ }
+ // Replace the transient `opToPad` locally, except for uses that we just
+ // created for the purpose of extracting the dims.
+ rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
+ return !newUsersOfOpToPad.contains(opOp.getOwner());
+ });
+
+ res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
+ return success();
+}
+
/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
StringRef opName, MLIRContext *context, LinalgTilingOptions options,
@@ -130,11 +243,34 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
if (!res)
return failure();
- // Return relevant information to derived pattern.
- result = *res;
+ // Set up an RAII guard to return properly.
+ bool succeeded = true;
+ LinalgOp tiledOp = res->op;
+ auto guard = llvm::make_scope_exit([&]() {
+ if (!succeeded)
+ return;
+ // Return relevant information to derived pattern.
+ result = *res;
+ // Replace marker on both tiledOp and tiledAndPaddedOp, if necessary.
+ marker.replaceLinalgMarker(rewriter, tiledOp);
+ if (tiledOp != res->op)
+ marker.replaceLinalgMarker(rewriter, res->op);
+ });
+
+ // Consider padding on the fly only if the op has tensor semantics.
+ if (!options.paddingValueComputationFunction ||
+ !linalgOp.hasTensorSemantics())
+ return success();
+
+ // Try to pad on the fly by rewriting res->op as a padded op.
+ if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
+ // Clear the flag so the RAII guard does not propagate TiledLinalgOp to `result`.
+ succeeded = false;
+ return failure();
+ }
- // New marker if specified.
- marker.replaceLinalgMarker(rewriter, res->op.getOperation());
+ // Do not perform replacement of `linalgOp`; let the derived patterns
+ // do this as they see fit, from the resulting TiledLinalgOp.
return success();
}
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 45dd0fd0086a..b8671cfe48fe 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1411,13 +1411,20 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
return Value{*dynExtents};
}
+ // The size at the given index is now known to be a dynamic size.
+ unsigned unsignedIndex = index.getValue().getZExtValue();
+
+ if (auto subtensor = dyn_cast_or_null<SubTensorOp>(definingOp)) {
+ assert(subtensor.isDynamicSize(unsignedIndex) &&
+ "Expected dynamic subtensor size");
+ return subtensor.getDynamicSize(unsignedIndex);
+ }
+
// Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
auto memrefType = argTy.dyn_cast<MemRefType>();
if (!memrefType)
return {};
- // The size at the given index is now known to be a dynamic size of a memref.
- unsigned unsignedIndex = index.getValue().getZExtValue();
if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
return *(alloc.getDynamicSizes().begin() +
memrefType.getDynamicDimIndex(unsignedIndex));
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 44743eaedc8c..6dc0768bc2e3 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -814,3 +814,13 @@ func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+
+// -----
+
+// TODO: this op should disappear once pad_tensors is available and connected.
+// CHECK-LABEL: func @simple_pad
+func @simple_pad(%0: tensor<?x4x?xf32>, %pad: f32) {
+// CHECK: linalg.simple_pad %{{.+}} pad %{{.+}}: tensor<?x4x?xf32> to tensor<8x4x8xf32>
+ %1 = linalg.simple_pad %0 pad %pad: tensor<?x4x?xf32> to tensor<8x4x8xf32> pad f32
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
new file mode 100644
index 000000000000..e4121083e240
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-and-pad-pattern -canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @matmul_tensors(
+// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @matmul_tensors(
+ %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[sTA:.*]] = subtensor %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTB:.*]] = subtensor %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTC:.*]] = subtensor %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+
+// Dynamic op has been canonicalized away.
+// CHECK-NOT: linalg.matmul {{.*}} tensor<?x?xf32>
+
+// Padding injects static information.
+// CHECK: %[[pA:.*]] = linalg.simple_pad %[[sTA]] pad %{{.*}} : tensor<?x?xf32> to tensor<2x4xf32> pad f32
+// CHECK: %[[pB:.*]] = linalg.simple_pad %[[sTB]] pad %{{.*}} : tensor<?x?xf32> to tensor<4x3xf32> pad f32
+// CHECK: %[[pC:.*]] = linalg.simple_pad %[[sTC]] pad %{{.*}} : tensor<?x?xf32> to tensor<2x3xf32> pad f32
+// CHECK: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xf32>, tensor<4x3xf32>)
+// CHECK-SAME: outs(%[[pC]] : tensor<2x3xf32>) -> tensor<2x3xf32>
+// CHECK: %[[sTD:.*]] = subtensor %[[pD]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<2x3xf32> to tensor<?x?xf32>
+// CHECK: %[[TD:.*]] = subtensor_insert %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
+// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
+// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
+ %0 = linalg.matmul {__internal_linalg_transform__ = "tile-and-pad"}
+ ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+
+// CHECK: return %[[TD0]] : tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index c2b4c7b9c821..87f81dbbf1fd 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -79,6 +79,9 @@ struct TestLinalgTransforms
*this, "test-affine-min-scf-canonicalization-patterns",
llvm::cl::desc("Test affine-min + scf canonicalization patterns."),
llvm::cl::init(false)};
+ Option<bool> testTileAndPadPattern{
+ *this, "test-tile-and-pad-pattern",
+ llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
};
} // end anonymous namespace
@@ -487,6 +490,27 @@ static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
applyOpPatternsAndFold(minOp, frozenPatterns);
});
}
+
+// For now, just assume the padding value is the zero of the element type.
+// In the future, it should be the neutral element for the given type and op.
+static Value getNeutralOfLinalgOp(OpBuilder &b, Operation *op) {
+ auto t = op->getResult(0).getType().cast<ShapedType>().getElementType();
+ return b.create<ConstantOp>(op->getLoc(), t, b.getZeroAttr(t));
+}
+
+static void applyTileAndPadPattern(FuncOp funcOp) {
+ MLIRContext *context = funcOp.getContext();
+ OwningRewritePatternList tilingPattern;
+ auto linalgTilingOptions =
+ linalg::LinalgTilingOptions()
+ .setTileSizes({2, 3, 4})
+ .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
+ tilingPattern.insert<linalg::LinalgTilingPattern<linalg::MatmulOp>>(
+ context, linalgTilingOptions,
+ linalg::LinalgMarker(Identifier::get("tile-and-pad", context)));
+ applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
+}
+
/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() {
auto lambda = [&](void *) {
@@ -520,6 +544,8 @@ void TestLinalgTransforms::runOnFunction() {
return applyLinalgToVectorPatterns(getFunction());
if (testAffineMinSCFCanonicalizationPatterns)
return applyAffineMinSCFCanonicalizationPatterns(getFunction());
+ if (testTileAndPadPattern)
+ return applyTileAndPadPattern(getFunction());
}
namespace mlir {