[Mlir-commits] [mlir] [mlir][Transforms] Add a PadTilingInterface transformation and hook i… (PR #144991)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 20 00:32:29 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Nicolas Vasilache (nicolasvasilache)
<details>
<summary>Changes</summary>
…t up to the transform dialect
This revision revisits the padding transformation from first principles and prepares it to work more generally with TilingInterface.
Compared to transform.structured.pad, it has the following differences:
- no support for nofold, copy-back, transpose and hoisting: these were carried by the padding op in the very early days of StructuredOps and have since been separated out as independent transformations that compose.
- no conflated static bounding box derivation attempts: pad_tiling_interface composes more naturally with or without tiling.
- properly derives the padding size on outputs where multiple dimensions contribute: this is not supported in transform.structured.pad.
- geared towards supporting TilingInterface once the proper control mechanisms are supported through a PadSizeComputationFunction (supports LinalgOp by default).
This will gradually replace transform.structured.pad as it is fleshed out and tested more comprehensively.
---
Patch is 35.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144991.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+76)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+76-3)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+161)
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp (+322)
- (added) mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir (+73)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6f6df350f1ba6..c00151dea54a6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1186,6 +1186,82 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
}];
}
+//===----------------------------------------------------------------------===//
+// PadTilingInterfaceOp
+//===----------------------------------------------------------------------===//
+
+def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interface",
+    [FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Pads the operations pointed to by the target handle using the options
+    provided as operation attributes. The operation returns a handle to the
+    padded operation and to the padding operation ("tensor.pad").
+
+    #### Return modes
+
+    This operation produces a silenceable failure if a target does not
+    implement TilingInterface or is not a Linalg op (the only ops supported
+    for now). In the future, this operation will support all
+    TilingInterfaceOps.
+
+    This operation produces a definite failure if a `padding_values` entry is
+    not a typed attribute matching the element type of the corresponding
+    operand.
+
+    If all the operations referred to by the `target` handle pad properly, the
+    transform succeeds and the returned handles point to the padded operations
+    and to the created "tensor.pad" operations. Otherwise the transform
+    produces a silenceable failure.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$target,
+         DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
+         Variadic<TransformAnyParamTypeOrAnyHandle>:$padding_sizes,
+         DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
+            $static_padding_sizes,
+         DefaultValuedAttr<UnitAttr, "false">:$pad_to_multiple_of);
+  let results = (outs TransformHandleTypeInterface:$padded,
+                      TransformHandleTypeInterface:$pad);
+
+  let assemblyFormat = [{
+    $target
+    `to`
+    (`padding_sizes` custom<DynamicIndexList>($padding_sizes, $static_padding_sizes)^)?
+    (`pad_to_multiple_of` $pad_to_multiple_of^)?
+    attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+
+  let builders = [
+    // Builder for a transform::PadTilingInterfaceOp with automatic inference
+    // of padding value. Warning: this will set the value 0 for the inferred
+    // elemental type without taking the op into account and thus only work for
+    // the add/mul ring at the moment.
+    // TODO: support other operations (e.g. min, max etc).
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$paddingDimensions,
+                   CArg<"ArrayRef<int64_t>", "{}">:$staticPaddingSizes,
+                   CArg<"bool", "false">:$padToMultipleOf)>,
+    // Same as above, but the padding sizes may mix static integers and
+    // dynamic transform values.
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$paddingDimensions,
+                   "ArrayRef<OpFoldResult>":$mixedPaddingSizes,
+                   CArg<"bool", "false">:$padToMultipleOf)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns a mix of dynamic `padding_sizes` and static `static_padding_sizes`.
+    SmallVector<OpFoldResult> getMixedPaddingSizes();
+
+    ::mlir::DiagnosedSilenceableFailure apply(
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::transform::TransformResults &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
//===----------------------------------------------------------------------===//
// HoistPadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 147a2907f52e4..59b7fdeef10b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -347,6 +348,34 @@ struct LinalgPaddingOptions {
}
};
+/// Options controlling how a TilingInterface op is padded. Each setter
+/// returns `*this` so that calls can be chained.
+struct PadTilingInterfaceOptions {
+  /// A padding value for every operand.
+  SmallVector<Attribute> paddingValues;
+  PadTilingInterfaceOptions &setPaddingValues(ArrayRef<Attribute> pv) {
+    paddingValues.assign(pv.begin(), pv.end());
+    return *this;
+  }
+  /// A list of iterator dimensions to pad.
+  SmallVector<int64_t> paddingDimensions;
+  PadTilingInterfaceOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
+    paddingDimensions.assign(pd.begin(), pd.end());
+    return *this;
+  }
+  /// A list of iterator dimensions sizes to pad to.
+  SmallVector<OpFoldResult> paddingSizes;
+  PadTilingInterfaceOptions &setPaddingSizes(ArrayRef<OpFoldResult> m) {
+    paddingSizes.assign(m.begin(), m.end());
+    return *this;
+  }
+  /// Pad iterator `paddingDimensions[i]` to next multiple of `paddingSizes[i]`
+  /// if true. Otherwise pad to `paddingSizes[i]`.
+  /// Initialized explicitly so that a default-constructed options object never
+  /// exposes an indeterminate value to readers of this flag.
+  bool padToMultipleOf = false;
+  PadTilingInterfaceOptions &setPadToMultipleOf(bool b) {
+    padToMultipleOf = b;
+    return *this;
+  }
+};
+
/// Callback function type used to perform the allocation for the promoted
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
/// smallest constant value for the size of the buffer needed for each
@@ -542,9 +571,9 @@ SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
/// where relevant.
void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
-/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
-/// to a static bounding box. The original `opToPad` is cloned and operates on
-/// the padded tensors.
+/// Pad the iterator dimensions `options.paddingDimensions` of all `opToPad`
+/// operands to a static bounding box. The original `opToPad` is cloned and
+/// operates on the padded tensors.
///
/// * "options.padToMultipleOf" indicates that each padding dimension should be
/// padded to the specified multiple.
@@ -561,6 +590,50 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
SmallVector<Value> &replacements,
SmallVector<tensor::PadOp> &padOps);
+/// Helper function to compute the padded shape of the given value `v` of
+/// `RankedTensorType` given:
+/// - the `indexingSizes` as a list of OpFoldResult.
+/// - an `indexingMap` that encodes how the padded shape varies with
+/// increases in `indexingSizes`.
+/// The implementation iteratively combines increases from contributing
+/// dimensions using affine.apply operations.
+/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
+/// provides a gentle portability path for Linalg-like ops with affine maps.
+/// In the future, more general interfaces can be devised to encode similar
+/// shape evolutions and map between an op and its operands.
+SmallVector<OpFoldResult>
+computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+ AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
+ const PadTilingInterfaceOptions &options);
+
+using PadSizeComputationFunction =
+ std::function<FailureOr<SmallVector<OpFoldResult>>(
+ RewriterBase &, OpOperand &, ArrayRef<Range>,
+ const PadTilingInterfaceOptions &)>;
+
+/// Specific helper for Linalg ops.
+FailureOr<SmallVector<OpFoldResult>>
+computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
+ ArrayRef<Range> iterationDomain,
+ const PadTilingInterfaceOptions &options);
+
+/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
+///
+/// * "options.paddingSizes" indicates that each padding dimension should be
+/// padded to the specified padding size.
+/// * "options.padToMultipleOf" indicates that each padding dimension should
+/// instead be padded to the next multiple of its padding size.
+/// * Use "options.paddingValues" to set the padding value of the created
+/// tensor::PadOp.
+/// * The padded op is returned on success; the created tensor::PadOps are
+/// returned in `padOps`.
+
+FailureOr<TilingInterface>
+rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
+ const PadTilingInterfaceOptions &constOptions,
+ SmallVector<tensor::PadOp> &padOps,
+ PadSizeComputationFunction computePaddingSizeFun =
+ &computeLinalgPaddedShape);
+
namespace detail {
/// Helper struct to hold the results of building a packing loop nest.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d78c8847f8843..bc8ff8e55bf0f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -45,6 +45,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
#include <type_traits>
using namespace mlir;
@@ -2155,6 +2156,166 @@ LogicalResult transform::PadOp::verify() {
return success();
}
+//===---------------------------------------------------------------------===//
+// PadTilingInterfaceOp
+//===---------------------------------------------------------------------===//
+
+/// Builder taking only static padding sizes. Both results are typed as
+/// `transform::AnyOpType`, no padding values are provided and no dynamic
+/// padding size operands are attached.
+void transform::PadTilingInterfaceOp::build(OpBuilder &b,
+ OperationState &result,
+ Value target,
+ ArrayRef<int64_t> paddingDimensions,
+ ArrayRef<int64_t> paddingSizes,
+ bool padToMultipleOf) {
+ auto resultType = transform::AnyOpType::get(b.getContext());
+ return build(/*builder=*/b,
+ /*result=*/result,
+ /*types=*/TypeRange{resultType, resultType},
+ /*target=*/target,
+ /*paddingValues=*/ArrayAttr(), // let inference handle this
+ /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+ /*paddingSizes=*/ValueRange{},
+ // An empty static size list is passed as a null attribute.
+ /*staticPaddingSizes=*/
+ (paddingSizes.empty() ? DenseI64ArrayAttr()
+ : b.getDenseI64ArrayAttr(paddingSizes)),
+ // The flag is a unit attribute: present when true, null otherwise.
+ /*padToMultipleOf=*/
+ padToMultipleOf ? b.getUnitAttr() : nullptr);
+}
+
+/// Builder taking mixed static/dynamic padding sizes. `mixedPaddingSizes` is
+/// split into dynamic values and static integers with
+/// `dispatchIndexOpFoldResults` before being forwarded. Both results are typed
+/// as `transform::AnyOpType` and no padding values are provided.
+void transform::PadTilingInterfaceOp::build(
+ OpBuilder &b, OperationState &result, Value target,
+ ArrayRef<int64_t> paddingDimensions,
+ ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
+ auto resultType = transform::AnyOpType::get(b.getContext());
+ SmallVector<int64_t> staticPaddingSizes;
+ SmallVector<Value> dynamicPaddingSizes;
+ dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
+ staticPaddingSizes);
+ return build(/*builder=*/b,
+ /*result=*/result,
+ /*types=*/TypeRange{resultType, resultType},
+ /*target=*/target,
+ /*paddingValues=*/ArrayAttr(), // let inference handle this
+ /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+ /*paddingSizes=*/dynamicPaddingSizes,
+ /*staticPaddingSizes=*/staticPaddingSizes,
+ /*padToMultipleOf=*/padToMultipleOf);
+}
+
+/// Declares the memory effects of this transform op: the `target` handle is
+/// consumed, the `padding_sizes` operands are only read, every op result is a
+/// produced handle, and the payload IR is modified.
+void transform::PadTilingInterfaceOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getPaddingSizesMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
+/// Returns the padding sizes as a single list of OpFoldResult, obtained by
+/// combining the `static_padding_sizes` attribute with the dynamic
+/// `padding_sizes` operands through `getMixedValues`.
+SmallVector<OpFoldResult>
+transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
+ Builder b(getContext());
+ return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
+}
+
+/// Pads every payload op associated with the `target` handle.
+///
+/// For each payload op this:
+///   1. checks that it implements TilingInterface and is a LinalgOp, emitting
+///      a silenceable failure otherwise;
+///   2. converts the `padding_values` attributes into typed attributes of the
+///      corresponding operand element type, emitting a definite failure on
+///      mismatch;
+///   3. calls `rewriteAsPaddedOp` with the options assembled from the op
+///      attributes, emitting a silenceable failure if the rewrite fails.
+///
+/// On success, the `padded` result is set to the padded ops and the `pad`
+/// result to the tensor::PadOps collected from `rewriteAsPaddedOp`.
+DiagnosedSilenceableFailure
+transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  SmallVector<Operation *> paddedOps, padOps;
+
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    auto targetOp = dyn_cast<TilingInterface>(target);
+    if (!targetOp) {
+      auto diag = emitSilenceableError() << "expected TilingInterface target";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
+    // map / C++ APIs to compute the effect of padding on operands.
+    if (!isa<LinalgOp>(targetOp.getOperation())) {
+      auto diag = emitSilenceableError() << "only LinalgOp supported atm";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Convert the padding values to attributes.
+    // Note: llvm::zip stops at the shorter of the two ranges.
+    SmallVector<Attribute> paddingValues;
+    for (auto const &it :
+         llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
+      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+      if (!attr) {
+        emitOpError("expects padding values to be typed attributes");
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      Type elementType = getElementTypeOrSelf(std::get<1>(it));
+      // Try to parse string attributes to obtain an attribute of element type.
+      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
+        auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
+            stringAttr, getContext(), elementType,
+            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
+        if (!parsedAttr || parsedAttr.getType() != elementType) {
+          auto diag = this->emitOpError("expects a padding that parses to ")
+                      << elementType << ", got " << std::get<0>(it);
+          diag.attachNote(targetOp.getLoc()) << "when applied to this op";
+          return DiagnosedSilenceableFailure::definiteFailure();
+        }
+        paddingValues.push_back(parsedAttr);
+        continue;
+      }
+      // Otherwise, add the attribute directly.
+      if (attr.getType() != elementType) {
+        auto diag = this->emitOpError("expects a padding value of type ")
+                    << elementType << ", got " << attr;
+        diag.attachNote(targetOp.getLoc()) << "when applied to this op";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      paddingValues.push_back(attr);
+    }
+
+    // Set options.
+    PadTilingInterfaceOptions options;
+    options.setPaddingValues(paddingValues)
+        .setPaddingDimensions(
+            extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()))
+        .setPaddingSizes(getMixedPaddingSizes())
+        .setPadToMultipleOf(getPadToMultipleOf());
+
+    // Apply padding. `targetOp` already is a TilingInterface, no cast needed.
+    SmallVector<tensor::PadOp> newPadOps;
+    FailureOr<TilingInterface> maybePaddedOp =
+        rewriteAsPaddedOp(rewriter, targetOp, options, newPadOps);
+    if (failed(maybePaddedOp)) {
+      auto diag = emitSilenceableError() << "failed to pad op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Set transform results.
+    paddedOps.push_back(maybePaddedOp->getOperation());
+    padOps.append(newPadOps.begin(), newPadOps.end());
+  }
+
+  results.set(cast<OpResult>(getPadded()), paddedOps);
+  results.set(cast<OpResult>(getPad()), padOps);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Verifies that `padding_dimensions` holds no negative entry and that there
+/// is exactly one padding size (static or dynamic) per padding dimension.
+LogicalResult transform::PadTilingInterfaceOp::verify() {
+  SmallVector<int64_t> paddingDimensions =
+      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
+  // Dimension 0 is accepted: only negative entries are rejected, so the
+  // diagnostic says "non-negative" rather than "positive".
+  if (any_of(paddingDimensions,
+             [](int64_t paddingDimension) { return paddingDimension < 0; })) {
+    return emitOpError() << "expects padding_dimensions to contain "
+                            "non-negative integers, found "
+                         << getPaddingDimensions();
+  }
+  if (getMixedPaddingSizes().size() != paddingDimensions.size()) {
+    return emitOpError() << "expects as many multiples as padding_dimensions";
+  }
+  return success();
+}
+
//===---------------------------------------------------------------------===//
// HoistPadOp
//===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 881d9fcb4f52e..69e6fdabf9a58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
BlockPackMatmul.cpp
PackAndUnpackPatterns.cpp
Padding.cpp
+ PadTilingInterface.cpp
Promotion.cpp
RuntimeOpVerification.cpp
Specialize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
new file mode 100644
index 0000000000000..a9d7bc64f2a6b
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -0,0 +1,322 @@
+//===- PadTilingInterface.cpp - Padding of TilingInterface ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+
+#define DEBUG_TYPE "pad-tiling-interface"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::tensor;
+
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
+/// Form a "full-rank" padding specification so that the application is easy.
+/// Expand the (possibly partial) padding specification in `options` into a
+/// map with one entry per iteration dimension of `indexingSizes`.
+///
+/// Dimensions listed in `options.paddingDimensions` map to the matching entry
+/// of `options.paddingSizes`. Every other dimension maps to:
+///   - the index attribute 1 when `options.padToMultipleOf` is set;
+///   - `indexingSizes[dim]` otherwise.
+static llvm::SmallDenseMap<int64_t, OpFoldResult>
+getDimsToSize(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
+              const PadTilingInterfaceOptions &options) {
+  llvm::SmallDenseMap<int64_t, OpFoldResult> dimsToSize;
+
+  // Seed the map with the explicitly requested (dimension, size) pairs.
+  for (const auto &[dim, size] :
+       llvm::zip_equal(options.paddingDimensions, options.paddingSizes))
+    dimsToSize[dim] = size;
+
+  // Fill in every dimension that was not explicitly requested.
+  const int64_t numDims = indexingSizes.size();
+  for (int64_t dim = 0; dim < numDims; ++dim) {
+    if (dimsToSize.count(dim))
+      continue;
+    dimsToSize[dim] =
+        options.padToMultipleOf ? b.getIndexAttr(1) : indexingSizes[dim];
+  }
+
+  // Dump the completed specification when debugging is enabled.
+  for (int64_t dim = 0; dim < numDims; ++dim) {
+    LLVM_DEBUG(DBGS() << "----idx: " << dim << " : " << dimsToSize[dim]
+                      << "\n");
+  }
+  return dimsToSize;
+}
+
+/// Compute the padded shape of the given value `v` of `RankedTensorType` given
+/// - `indexingSizes` a list of OpFoldResult.
+/// - an `indexingMap` that encodes how the shape of varies with increases
+/// in `indexingSizes`.
+/// The `indexingMap` encodes how the shape of varies w...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/144991
More information about the Mlir-commits
mailing list