[Mlir-commits] [mlir] 97c1a24 - [mlir][linalg] Add option to pad dynamic dims to `linalg::rewriteAsPaddedOp` (#144354)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 19 02:47:48 PDT 2025
Author: Fabian Mora
Date: 2025-06-19T11:47:44+02:00
New Revision: 97c1a2444574b32dd7a283c53be248c5dbbf62e9
URL: https://github.com/llvm/llvm-project/commit/97c1a2444574b32dd7a283c53be248c5dbbf62e9
DIFF: https://github.com/llvm/llvm-project/commit/97c1a2444574b32dd7a283c53be248c5dbbf62e9.diff
LOG: [mlir][linalg] Add option to pad dynamic dims to `linalg::rewriteAsPaddedOp` (#144354)
This patch makes the following changes:
- Add a `ValueRange typeDynDims` argument to
`linalg::makeComposedPadHighOp`, allowing a tensor with dynamic
dimensions to be padded using `tensor::createPadHighOp`.
- Add a `DenseMap<std::pair<unsigned, unsigned>, OpFoldResult>
sizeToPadTo;` option to `LinalgPaddingOptions`. This option allows
setting the size to use when padding a dimension of an operand, so that
operands can be padded even when they don't have a constant upper
bounding box. If the value is not provided, then the constant upper
bound is used by default.
- Add a `use_prescribed_tensor_shapes` option to
`transform.structured.pad`. If set, the sizes of dynamic dimensions are
taken from `tensor.dim` when computing the size of the padded dim,
instead of computing the constant upper bound.
- This patch also changes the behavior for computing the padded shape
in `linalg::rewriteAsPaddedOp`, by using the newly added options in
`LinalgPaddingOptions`.
- Finally it adds tests verifying the behavior.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/transform-op-pad.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..6f6df350f1ba6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1134,7 +1134,8 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
DefaultValuedAttr<
TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
"{}">:$transpose_paddings,
- DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op);
+ DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op,
+ DefaultValuedAttr<UnitAttr, "false">:$use_prescribed_tensor_shapes);
let results = (outs TransformHandleTypeInterface:$padded,
TransformHandleTypeInterface:$pad,
TransformHandleTypeInterface:$copy);
@@ -1142,6 +1143,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
let assemblyFormat = [{
$target
(`pad_to_multiple_of` custom<DynamicIndexList>($pad_to_multiple_of, $static_pad_to_multiple_of)^)?
+ (`use_prescribed_tensor_shapes` $use_prescribed_tensor_shapes^)?
attr-dict
`:` functional-type(operands, results)
}];
@@ -1159,13 +1161,15 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
CArg<"ArrayRef<int64_t>", "{}">:$staticPadToMultipleOf,
CArg<"ArrayRef<int64_t>", "{}">:$nofoldFlags,
CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
- CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>,
+ CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp,
+ CArg<"bool", "false">:$usePrescribedTensorShapes)>,
OpBuilder<(ins "Value":$target,
"ArrayRef<int64_t>":$paddingDimensions,
"ArrayRef<OpFoldResult>":$mixedPadToMultipleOf,
CArg<"ArrayRef<int64_t>", "{}">:$nofoldFlags,
CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
- CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>
+ CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp,
+ CArg<"bool", "false">:$usePrescribedTensorShapes)>
];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..147a2907f52e4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -295,6 +295,23 @@ struct LinalgPaddingOptions {
padToMultipleOf.emplace(m.begin(), m.end());
return *this;
}
+ /// A mapping between an operand and shape dim, and a size for a padding
+ /// dimension. Each size is expected to be greater or equal than the
+ /// corresponding shape dim. If no value is provided then the constant upper
+ /// bound will be used.
+ DenseMap<std::pair<unsigned, unsigned>, OpFoldResult> sizeToPadTo;
+ LinalgPaddingOptions &setSizeToPadTo(unsigned operandIndex, unsigned dimIndex,
+ OpFoldResult size) {
+ assert(size && "expected non-null size");
+ sizeToPadTo[{operandIndex, dimIndex}] = size;
+ return *this;
+ }
+ /// Given the operand index and shape dim it returns the size to pad to.
+ OpFoldResult getSizeToPadTo(unsigned operandIndex, unsigned dimIndex) const {
+ return sizeToPadTo.lookup_or(
+ std::pair<unsigned, unsigned>(operandIndex, dimIndex), nullptr);
+ }
+
/// A flag for every operand to mark the PadOp as nofold which enables
/// packing for statically shaped operands.
SmallVector<bool> nofoldFlags;
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 80aa034d2199d..fc151d02ceef6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -71,12 +71,14 @@ bool isParallelIterator(utils::IteratorType iteratorType);
/// Check if iterator type has "reduction" semantics.
bool isReductionIterator(utils::IteratorType iteratorType);
-/// Create a tensor::PadOp that pads `source` to the size of the statically
-/// sized `type` whose static sizes are assumed to be greater than the dynamic
-/// `source` size. The padding introduces trailing `pad` values until the
-/// target size is met. If `source` is defined by one or more LinalgOps that
-/// have been padded with the same value and sizes, return their padded result
-/// instead of creating a tensor::PadOp.
+/// Create a tensor::PadOp that pads `source` to the shape of `type` whose sizes
+/// are assumed to be greater than the dynamic `source` size. If `typeDynDims`
+/// is specified, then it must contain the sizes of all the dynamic dimensions
+/// in order of appearance in `type`, otherwise the function will pad those
+/// values to `0`. The padding introduces trailing `pad` values until the target
+/// size is met. If `source` is defined by one or more LinalgOps that have been
+/// padded with the same value and sizes, return their padded result instead of
+/// creating a tensor::PadOp.
///
/// Example:
/// ```
@@ -91,7 +93,8 @@ bool isReductionIterator(utils::IteratorType iteratorType);
/// %4 = tensor.pad %3 low[0, 0] high[...] { tensor.yield %other_cst }
/// ```
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
- Value source, Value pad, bool nofold);
+ Value source, Value padding, bool nofold,
+ ValueRange typeDynDims = std::nullopt);
/// Returns GenericOp that copies an n-D memref. Unlike the current
/// implementation of memref::CopyOp, this op can further tile, lower to loops
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..d78c8847f8843 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1907,7 +1907,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
ArrayRef<int64_t> padToMultipleOf,
ArrayRef<int64_t> nofoldFlags,
ArrayRef<Attribute> transposePaddings,
- StringRef copyBackOp) {
+ StringRef copyBackOp,
+ bool usePrescribedTensorShapes) {
auto resultType = transform::AnyOpType::get(b.getContext());
return build(/*builder=*/b,
/*result=*/result,
@@ -1922,7 +1923,9 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
: b.getDenseI64ArrayAttr(padToMultipleOf)),
/*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
/*transposePaddings=*/b.getArrayAttr(transposePaddings),
- /*copyBackOp=*/b.getStringAttr(copyBackOp));
+ /*copyBackOp=*/b.getStringAttr(copyBackOp),
+ /*usePrescribedTensorShapes=*/
+ usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
}
void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
@@ -1930,7 +1933,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedPadToMultipleOf,
ArrayRef<int64_t> nofoldFlags,
ArrayRef<Attribute> transposePaddings,
- StringRef copyBackOp) {
+ StringRef copyBackOp,
+ bool usePrescribedTensorShapes) {
auto resultType = transform::AnyOpType::get(b.getContext());
SmallVector<int64_t> staticPadToMultipleOf;
SmallVector<Value> dynamicPadToMultipleOf;
@@ -1946,7 +1950,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
/*padToMultipleOf=*/staticPadToMultipleOf,
/*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
/*transposePaddings=*/b.getArrayAttr(transposePaddings),
- /*copyBackOp=*/b.getStringAttr(copyBackOp));
+ /*copyBackOp=*/copyBackOp,
+ /*usePrescribedTensorShapes=*/usePrescribedTensorShapes);
}
void PadOp::getEffects(
@@ -2051,11 +2056,34 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
} else {
llvm_unreachable("unsupported copy_back op");
}
+ // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
+ bool irChanged = false;
+ if (getUsePrescribedTensorShapes() &&
+ linalgTarget.hasPureTensorSemantics()) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(linalgTarget);
+ for (OpOperand &operand : linalgTarget->getOpOperands()) {
+ for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
+ if (!ShapedType::isDynamic(dim))
+ continue;
+ options.setSizeToPadTo(operand.getOperandNumber(), i,
+ tensor::getMixedSize(rewriter,
+ operand.get().getLoc(),
+ operand.get(), i));
+ irChanged = true;
+ }
+ }
+ }
SmallVector<Value> replacements;
SmallVector<tensor::PadOp> newPadOps;
if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
replacements, newPadOps))) {
+ if (irChanged) {
+ auto diag = emitDefiniteFailure() << "failed to pad op";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
auto diag = emitSilenceableError() << "failed to pad op";
diag.attachNote(target->getLoc()) << "target op";
return diag;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 9a685f6dc96ac..dc9e11eccac4d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -22,53 +23,93 @@ using namespace mlir::linalg;
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
-/// Compute the padded shape of the given operand. The operand is padded to a
-/// static bounding box according to the specified padding options.
-static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
+namespace {
+/// Helper class for storing padding information.
+struct PaddingInfo {
+ PaddingInfo(int64_t padToMultipleOf = 1, OpFoldResult size = {})
+ : padToMultipleOf(padToMultipleOf), size(size) {}
+ /// Pad the tensor to a multiple of.
+ int64_t padToMultipleOf = 1;
+ /// The size used for padding.
+ OpFoldResult size = {};
+};
+
+/// Helper class for storing and computing the padded shape.
+struct PaddedShape {
+ /// Initializes the shape information and on success it returns whether the
+ /// shape of the operand will change. Returns failure if the operand cannot be
+ /// padded.
+ FailureOr<bool> initialize(linalg::LinalgOp opToPad, OpOperand *opOperand,
+ const LinalgPaddingOptions &options);
+
+ /// Computs the padded shape.
+ void computePadding(OpBuilder &builder, Value operand);
+
+ /// Returns the new tensor type.
+ RankedTensorType getType(Type elemTy) {
+ return RankedTensorType::get(shape, elemTy);
+ }
+
+ SmallVector<Value> dynDims;
+
+private:
+ SmallVector<int64_t> shape;
+ DenseMap<int64_t, PaddingInfo> dimToInfo;
+};
+} // namespace
+
+FailureOr<bool> PaddedShape::initialize(linalg::LinalgOp opToPad,
OpOperand *opOperand,
- const LinalgPaddingOptions &options,
- SmallVector<int64_t> &paddedShape,
- bool &alreadyHasRequestedShape) {
+ const LinalgPaddingOptions &options) {
AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
- ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
+
+ // Initialize the padded shape.
+ llvm::append_range(shape, opToPad.getShape(opOperand));
// Collect the shape dimensions that are a function of "paddingDimensions",
// along with the multiple that they should be padded to ("1" if none).
- alreadyHasRequestedShape = true;
- DenseMap<int64_t, int64_t> shapeDimToMultiple;
+ bool alreadyHasRequestedShape = true;
for (const auto &dimEn : enumerate(options.paddingDimensions)) {
for (const auto &en : enumerate(indexingMap.getResults())) {
if (en.value().isFunctionOfDim(dimEn.value())) {
+ PaddingInfo paddingInfo;
int64_t dimSize = shape[en.index()];
if (options.padToMultipleOf.has_value()) {
- shapeDimToMultiple[en.index()] =
+ paddingInfo.padToMultipleOf =
(*options.padToMultipleOf)[dimEn.index()];
} else {
- shapeDimToMultiple[en.index()] = 1;
+ paddingInfo.padToMultipleOf = 1;
}
- if (ShapedType::isDynamic(dimSize)) {
- alreadyHasRequestedShape = false;
- } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
+
+ // Check if the user provided a size in the options.
+ paddingInfo.size =
+ options.getSizeToPadTo(opOperand->getOperandNumber(), en.index());
+
+ // Set the padding info.
+ dimToInfo[en.index()] = paddingInfo;
+ if (ShapedType::isDynamic(dimSize) ||
+ dimSize % paddingInfo.padToMultipleOf != 0 ||
+ !paddingInfo.size.isNull()) {
alreadyHasRequestedShape = false;
}
}
}
}
- // Helper function to round a number up to a given multiple.
- auto ceil = [](int64_t val, int64_t multiple) {
- return ((val + multiple - 1) / multiple) * multiple;
- };
-
// Upper bound the sizes to obtain a static bounding box.
- paddedShape.assign(shape.begin(), shape.end());
for (int64_t i = 0, e = shape.size(); i < e; ++i) {
- LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
+ LLVM_DEBUG(DBGS() << "--computing un-padded size for dim " << i << "\n");
// Skip dimensions that do not require padding.
- if (!shapeDimToMultiple.contains(i)) {
+ if (!dimToInfo.contains(i)) {
LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
continue;
}
+ PaddingInfo &info = dimToInfo[i];
+ if (info.size) {
+ LLVM_DEBUG(DBGS() << "----the user provided the size: " << info.size
+ << "\n");
+ continue;
+ }
// Otherwise, try to compute a constant upper bound for the size value.
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
@@ -77,14 +118,58 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
/*dim=*/i},
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(upperBound)) {
- LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
+ LLVM_DEBUG(
+ DBGS() << "----could not compute a bounding box for padding\n");
return failure();
}
- paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
- LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
+ info.size =
+ IntegerAttr::get(IndexType::get(opToPad.getContext()), *upperBound);
+ LLVM_DEBUG(DBGS() << "----new un-padded size: " << info.size << "\n");
}
+ return alreadyHasRequestedShape;
+}
- return success();
+void PaddedShape::computePadding(OpBuilder &builder, Value operand) {
+ Location loc = operand.getLoc();
+ AffineExpr sizeSym = builder.getAffineSymbolExpr(0);
+
+ // Compute the padding for each dimension.
+ for (auto &&[i, dim] : llvm::enumerate(shape)) {
+ LLVM_DEBUG(DBGS() << "--computing padded size for dim " << i << "\n");
+
+ // Get the padding info or default info for the shape dimension.
+ PaddingInfo paddingInfo = dimToInfo.lookup(i);
+
+ // Skip dimensions that do not require padding.
+ if (paddingInfo.size.isNull()) {
+ LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
+
+ // We still need to push the size as `makeComposedPadHighOp` expects a
+ // range with all the dynamic sizes, whether they're being padded or not.
+ if (ShapedType::isDynamic(dim)) {
+ dynDims.push_back(
+ cast<Value>(tensor::getMixedSize(builder, loc, operand, i)));
+ }
+ continue;
+ }
+
+ // Compute the padded size to be a multiple of `padToMultipleOf`.
+ AffineExpr szExpr = (sizeSym).ceilDiv(paddingInfo.padToMultipleOf) *
+ paddingInfo.padToMultipleOf;
+ OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+ builder, loc, szExpr, paddingInfo.size);
+ assert(paddedSize && "invalid arguments to affine apply");
+
+ if (auto cstSzAttr = dyn_cast<Attribute>(paddedSize)) {
+ // Update the shape as the size is static.
+ dim = cast<IntegerAttr>(cstSzAttr).getValue().getZExtValue();
+ } else {
+ // Add a dynamic dimension.
+ dim = ShapedType::kDynamic;
+ dynDims.push_back(cast<Value>(paddedSize));
+ }
+ LLVM_DEBUG(DBGS() << "----new dim size: " << paddedSize << "\n");
+ }
}
/// Pad the `opOperand` in the "paddingDimensions" using the padding value and
@@ -107,20 +192,21 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
options.padToMultipleOf->size() == options.paddingDimensions.size()) &&
"invalid number of elements in padToMultipleOf");
- // Compute padded shape.
- SmallVector<int64_t> paddedShape;
- bool alreadyHasRequestedShape = false;
- if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape,
- alreadyHasRequestedShape)))
+ // Initialize the padded shape and get whether it requires padding.
+ PaddedShape shape;
+ FailureOr<bool> alreadyHasRequestedShape =
+ shape.initialize(opToPad, opOperand, options);
+ if (failed(alreadyHasRequestedShape)) {
return rewriter.notifyMatchFailure(opToPad,
"--failed to compute padded shape");
+ }
- // Return the unpadded operand if padding to a static shape is not needed and
+ // Return the un-padded operand if padding to a static shape is not needed and
// if the nofold flag is not set.
bool nofold = opOperand->getOperandNumber() < options.nofoldFlags.size()
? bool(options.nofoldFlags[opOperand->getOperandNumber()])
: false;
- if (!nofold && alreadyHasRequestedShape)
+ if (!nofold && *alreadyHasRequestedShape)
return opOperand->get();
// Fail if `paddingValues` specifies no padding value.
@@ -140,13 +226,18 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
}
+ // Computes the padded shape.
+ if (!*alreadyHasRequestedShape)
+ shape.computePadding(rewriter, opOperand->get());
+
// Pad the operand to the bounding box defined by `paddedShape`.
- auto paddedTensorType = RankedTensorType::get(
- paddedShape, getElementTypeOrSelf(opOperand->get()));
+ RankedTensorType paddedTensorType =
+ shape.getType(getElementTypeOrSelf(opOperand->get()));
LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
<< paddedTensorType);
return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType,
- opOperand->get(), paddingValue, nofold);
+ opOperand->get(), paddingValue, nofold,
+ shape.dynDims);
}
LogicalResult
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2527d90cfa2e6..209309ddb413a 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -244,11 +244,13 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
}
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
- Value source, Value pad, bool nofold) {
+ Value source, Value pad, bool nofold,
+ ValueRange typeDynDims) {
// Exit if `source` is not defined by an ExtractSliceOp.
auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp)
- return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+ return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+ typeDynDims);
// Search the `source` use-def chain for padded LinalgOps.
Value current = sliceOp.getSource();
@@ -264,24 +266,28 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
// Exit if the search fails to match a tensor::PadOp at the end of the matched
// LinalgOp sequence.
if (!padOp)
- return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+ return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+ typeDynDims);
// Exit if the padded result type does not match.
if (sliceOp.getSource().getType() != type)
- return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+ return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+ typeDynDims);
// Exit if the LinalgOps are not high padded.
if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
return getConstantIntValue(ofr) != static_cast<int64_t>(0);
}))
- return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+ return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+ typeDynDims);
// Exit if `padOpSliceOp`, which defines the slice used by
// `padOp`, is rank-reducing.
auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
if (!padOpSliceOp ||
sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
- return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+ return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+ typeDynDims);
// Exit if the sizes of the dynamic sizes of `sliceOp` do not match the size
// of the slice padded by `padOp`.
@@ -290,14 +296,16 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
[](std::tuple<OpFoldResult, OpFoldResult> it) {
return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
}))
- return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+ return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+ typeDynDims);
// Exit if the padding values do not match.
Attribute padOpPadAttr, padAttr;
Value padOpPad = padOp.getConstantPaddingValue();
if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
!matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
- return tensor::createPadHighOp(type, source, pad, nofold, loc, b);
+ return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
+ typeDynDims);
// Return the padded result if the padding values and sizes match.
return sliceOp.getSource();
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index ab2711545405e..bc684b53c9b61 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -300,7 +300,7 @@ func.func @negative_no_ub_estimate(%arg0: tensor<?x12xf32>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{ailed to pad op}}
+ // expected-error @below {{failed to pad op}}
%padded, %pad, %copy_back = transform.structured.pad %0 {
padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
// Note - attempting to pad non-static dim
@@ -313,6 +313,41 @@ module attributes {transform.with_named_sequence} {
// -----
+// Test dynamic padding using `use_prescribed_tensor_shapes`
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 7) * 7)>
+// CHECK: @use_prescribed_tensor_shapes
+// CHECK: (%[[ARG0:.*]]: tensor<?x12xf32>, %[[ARG1:.*]]: tensor<12x?xf32>
+func.func @use_prescribed_tensor_shapes(%arg0: tensor<?x12xf32>,
+ %arg1: tensor<12x?xf32>,
+ %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ // CHECK: %[[C1_0:.*]] = arith.constant 1 : index
+ // CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG1]], %[[C1_0]] : tensor<12x?xf32>
+ // CHECK: %[[PADDING:.*]] = affine.apply #[[MAP]]()[%[[DIM_0]]]
+ // CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG1]] low[0, 0] high[0, %[[PADDING]]] {
+ // CHECK: linalg.matmul ins(%[[ARG0]], %[[PADDED]] : tensor<?x12xf32>, tensor<12x?xf32>)
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x12xf32>, tensor<12x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ func.return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %padded, %pad, %copy_back = transform.structured.pad %0
+ pad_to_multiple_of [7] use_prescribed_tensor_shapes {
+ padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+ padding_dimensions=[1]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %func = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.canonicalization
+ } {apply_cse} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// Check that the padding can be applied even when the output argument of the
// linalg op is not produced by an empty op or an extract_slice op.
@@ -416,6 +451,6 @@ module attributes {transform.with_named_sequence} {
padding_dimensions=[0, 1, 2],
nofold_flags=[1, 1, 1]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
- transform.yield
+ transform.yield
}
}
More information about the Mlir-commits
mailing list