[Mlir-commits] [mlir] [mlir][Transforms] Add a PadTilingInterface transformation and hook i… (PR #144991)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 20 00:32:29 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Nicolas Vasilache (nicolasvasilache)

<details>
<summary>Changes</summary>

Add a PadTilingInterface transformation and hook it up to the transform dialect.

This revision revisits the padding transformation from first principles and prepares it to work more generally with TilingInterface.

Compared to transform.structured.pad, it has the following differences:
  - no support for nofold, copy-back, transpose and hoisting: these were carried by the padding op in the very early days of StructuredOps and have since been separated out as independent transformations that compose.
  - no conflated static bounding box derivation attempts: pad_tiling_interface composes more naturally with or without tiling.
  - properly derives the padding size on outputs where multiple dimensions contribute: this is not supported in transform.structured.pad.
  - geared towards supporting TilingInterface once the proper control mechanisms are supported through a PadSizeComputationFunction (supports LinalgOp by default).

This will gradually replace transform.structured.pad as it is fleshed out and tested more comprehensively.

---

Patch is 35.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144991.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+76) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+76-3) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+161) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp (+322) 
- (added) mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir (+73) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6f6df350f1ba6..c00151dea54a6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1186,6 +1186,82 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PadTilingInterfaceOp
+//===----------------------------------------------------------------------===//
+
+def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interface",
+    [FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Pads the operations pointed to by the target handle using the options
+    provided as operation attributes. The operation returns a handle to the
+    padded operation and to the padding operation ("tensor.pad").
+
+    #### Return modes
+
+    This operation produces a silenceable failure if a target op does not
+    implement TilingInterface, is not a Linalg op, or fails to pad. Only
+    Linalg ops are supported for now; in the future, this operation will
+    support all TilingInterfaceOps.
+
+    This operation produces a definite failure if a padding value is not a
+    typed attribute of, or does not parse to, the operand element type.
+
+    If all the operations referred to by the `target` handle pad properly, the
+    transform succeeds.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$target,
+         DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
+         Variadic<TransformAnyParamTypeOrAnyHandle>:$padding_sizes,
+         DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
+            $static_padding_sizes,
+         DefaultValuedAttr<UnitAttr, "false">:$pad_to_multiple_of);
+  let results = (outs TransformHandleTypeInterface:$padded,
+                      TransformHandleTypeInterface:$pad);
+
+  let assemblyFormat = [{
+    $target
+    `to`
+    (`padding_sizes` custom<DynamicIndexList>($padding_sizes, $static_padding_sizes)^)?
+    (`pad_to_multiple_of` $pad_to_multiple_of^)?
+    attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+
+  let builders = [
+    // Builders for a transform::PadTilingInterfaceOp with automatic inference
+    // of the padding value. Warning: this will set the value 0 for the
+    // inferred elemental type without taking the op into account and thus
+    // only works for the add/mul ring at the moment.
+    // TODO: support other operations (e.g. min, max etc).
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$paddingDimensions,
+                   CArg<"ArrayRef<int64_t>", "{}">:$staticPaddingSizes,
+                   CArg<"bool", "false">:$padToMultipleOf)>,
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$paddingDimensions,
+                   "ArrayRef<OpFoldResult>":$mixedPaddingSizes,
+                   CArg<"bool", "false">:$padToMultipleOf)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns a mix of dynamic `padding_sizes` and static `static_padding_sizes`.
+    SmallVector<OpFoldResult> getMixedPaddingSizes();
+
+    ::mlir::DiagnosedSilenceableFailure apply(
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::transform::TransformResults &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // HoistPadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 147a2907f52e4..59b7fdeef10b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -347,6 +348,34 @@ struct LinalgPaddingOptions {
   }
 };
 
+struct PadTilingInterfaceOptions {
+  /// A padding value for every operand.
+  SmallVector<Attribute> paddingValues;
+  PadTilingInterfaceOptions &setPaddingValues(ArrayRef<Attribute> pv) {
+    paddingValues.assign(pv.begin(), pv.end());
+    return *this;
+  }
+  /// A list of iterator dimensions to pad.
+  SmallVector<int64_t> paddingDimensions;
+  PadTilingInterfaceOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
+    paddingDimensions.assign(pd.begin(), pd.end());
+    return *this;
+  }
+  /// A list of iterator dimension sizes to pad to.
+  SmallVector<OpFoldResult> paddingSizes;
+  PadTilingInterfaceOptions &setPaddingSizes(ArrayRef<OpFoldResult> m) {
+    paddingSizes.assign(m.begin(), m.end());
+    return *this;
+  }
+  /// Pad iterator `paddingDimensions[i]` to next multiple of `paddingSizes[i]`
+  /// if true. Otherwise pad to `paddingSizes[i]`.
+  bool padToMultipleOf = false; // initialized: never read as indeterminate
+  PadTilingInterfaceOptions &setPadToMultipleOf(bool b) {
+    padToMultipleOf = b;
+    return *this;
+  }
+};
+
 /// Callback function type used to perform the allocation for the promoted
 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
 /// smallest constant value for the size of the buffer needed for each
@@ -542,9 +571,9 @@ SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
 /// where relevant.
 void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
 
-/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
-/// to a static bounding box. The original `opToPad` is cloned and operates on
-/// the padded tensors.
+/// Pad the iterator dimensions `options.paddingDimensions` of all `opToPad`
+/// operands to a static bounding box. The original `opToPad` is cloned and
+/// operates on the padded tensors.
 ///
 /// * "options.padToMultipleOf" indicates that each padding dimension should be
 ///   padded to the specified multiple.
@@ -561,6 +590,50 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                                 SmallVector<Value> &replacements,
                                 SmallVector<tensor::PadOp> &padOps);
 
+/// Helper function to compute the padded shape of the given value `v` of
+/// `RankedTensorType` given:
+///   - the `indexingSizes` as a list of OpFoldResult.
+///   - an `indexingMap` that encodes how the padded shape varies with
+///     increases in `indexingSizes`.
+/// The implementation iteratively combines increases from contributing
+/// dimensions using affine.apply operations.
+/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
+/// provides a gentle portability path for Linalg-like ops with affine maps.
+/// In the future, more general interfaces can be devised to encode similar
+/// shape evolutions and map between an op and its operands.
+SmallVector<OpFoldResult>
+computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+                   AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
+                   const PadTilingInterfaceOptions &options);
+
+using PadSizeComputationFunction =
+    std::function<FailureOr<SmallVector<OpFoldResult>>(
+        RewriterBase &, OpOperand &, ArrayRef<Range>,
+        const PadTilingInterfaceOptions &)>;
+
+/// Specific helper for Linalg ops (default callback of rewriteAsPaddedOp).
+FailureOr<SmallVector<OpFoldResult>>
+computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
+                         ArrayRef<Range> iterationDomain,
+                         const PadTilingInterfaceOptions &options);
+
+/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
+///
+/// * "options.paddingSizes" indicates that each padding dimension should be
+///   padded to the specified padding size.
+/// * "options.padToMultipleOf" indicates that each padding dimension should
+///   instead be padded to the next multiple of its padding size.
+/// * Use "options.paddingValues" to set the padding value of the created
+///   tensor::PadOp.
+/// * The padded op is returned on success; the created tensor::PadOp ops
+///   are expected to be appended to `padOps` (implementation not shown here).
+FailureOr<TilingInterface>
+rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
+                  const PadTilingInterfaceOptions &constOptions,
+                  SmallVector<tensor::PadOp> &padOps,
+                  PadSizeComputationFunction computePaddingSizeFun =
+                      &computeLinalgPaddedShape);
+
 namespace detail {
 
 /// Helper struct to hold the results of building a packing loop nest.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d78c8847f8843..bc8ff8e55bf0f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -45,6 +45,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include <type_traits>
 
 using namespace mlir;
@@ -2155,6 +2156,166 @@ LogicalResult transform::PadOp::verify() {
   return success();
 }
 
+//===---------------------------------------------------------------------===//
+// PadTilingInterfaceOp
+//===---------------------------------------------------------------------===//
+
+void transform::PadTilingInterfaceOp::build(OpBuilder &b,
+                                            OperationState &result,
+                                            Value target,
+                                            ArrayRef<int64_t> paddingDimensions,
+                                            ArrayRef<int64_t> paddingSizes,
+                                            bool padToMultipleOf) {
+  // Both results (`padded` and `pad`) use the same AnyOpType handle type.
+  auto anyOpType = transform::AnyOpType::get(b.getContext());
+  // A null attribute is forwarded when no static padding size is provided.
+  DenseI64ArrayAttr staticSizesAttr;
+  if (!paddingSizes.empty())
+    staticSizesAttr = b.getDenseI64ArrayAttr(paddingSizes);
+  // Likewise, the unit attribute is only materialized when requested.
+  UnitAttr multipleOfAttr = padToMultipleOf ? b.getUnitAttr() : UnitAttr();
+  build(b, result, /*types=*/TypeRange{anyOpType, anyOpType}, target,
+        /*paddingValues=*/ArrayAttr(), // let inference handle this
+        /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+        /*paddingSizes=*/ValueRange{}, /*staticPaddingSizes=*/staticSizesAttr,
+        /*padToMultipleOf=*/multipleOfAttr);
+}
+
+void transform::PadTilingInterfaceOp::build(
+    OpBuilder &b, OperationState &result, Value target,
+    ArrayRef<int64_t> paddingDimensions,
+    ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
+  // Separate the mixed sizes into dynamic SSA values and static integers,
+  // matching the op's `padding_sizes` / `static_padding_sizes` arguments.
+  SmallVector<Value> dynamicSizes;
+  SmallVector<int64_t> staticSizes;
+  dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicSizes, staticSizes);
+
+  // Both results (`padded` and `pad`) use the same AnyOpType handle type.
+  auto anyOpType = transform::AnyOpType::get(b.getContext());
+  build(b, result, /*types=*/TypeRange{anyOpType, anyOpType}, target,
+        /*paddingValues=*/ArrayAttr(), // let inference handle this
+        /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+        /*paddingSizes=*/dynamicSizes,
+        /*staticPaddingSizes=*/staticSizes,
+        /*padToMultipleOf=*/padToMultipleOf);
+}
+
+void transform::PadTilingInterfaceOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);        // `target` is consumed.
+  onlyReadsHandle(getPaddingSizesMutable(), effects); // Sizes are only read.
+  producesHandle(getOperation()->getOpResults(), effects); // `padded`, `pad`.
+  modifiesPayload(effects); // Registers a payload-modifying effect.
+}
+
+SmallVector<OpFoldResult>
+transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
+  Builder builder(getContext()); // handed to getMixedValues below
+  return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), builder);
+}
+
+/// Pads each payload op associated with `target` and stops at the first op
+/// that is unsupported, has ill-typed padding values, or fails to pad.
+DiagnosedSilenceableFailure
+transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  SmallVector<Operation *> paddedOps, padOps;
+
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    auto targetOp = dyn_cast<TilingInterface>(target);
+    if (!targetOp) {
+      auto diag = emitSilenceableError() << "expected TilingInterface target";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
+    // map / C++ APIs to compute the effect of padding on operands.
+    if (!isa<LinalgOp>(targetOp.getOperation())) {
+      auto diag = emitSilenceableError() << "only LinalgOp supported atm";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Convert the padding values to attributes (zip stops at shorter range).
+    SmallVector<Attribute> paddingValues;
+    for (const auto &[paddingValue, operandType] :
+         llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
+      auto attr = dyn_cast<TypedAttr>(paddingValue);
+      if (!attr) {
+        emitOpError("expects padding values to be typed attributes");
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      Type elementType = getElementTypeOrSelf(operandType);
+      // Try to parse string attributes to obtain an attribute of element type.
+      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
+        auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
+            stringAttr, getContext(), elementType,
+            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
+        if (!parsedAttr || parsedAttr.getType() != elementType) {
+          auto diag = this->emitOpError("expects a padding that parses to ")
+                      << elementType << ", got " << paddingValue;
+          diag.attachNote(targetOp.getLoc()) << "when applied to this op";
+          return DiagnosedSilenceableFailure::definiteFailure();
+        }
+        paddingValues.push_back(parsedAttr);
+        continue;
+      }
+      // Otherwise, add the attribute directly.
+      if (attr.getType() != elementType) {
+        auto diag = this->emitOpError("expects a padding value of type ")
+                    << elementType << ", got " << attr;
+        diag.attachNote(targetOp.getLoc()) << "when applied to this op";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      paddingValues.push_back(attr);
+    }
+
+    // Set options.
+    PadTilingInterfaceOptions options;
+    options.setPaddingValues(paddingValues)
+        .setPaddingDimensions(
+            extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()))
+        .setPaddingSizes(getMixedPaddingSizes())
+        .setPadToMultipleOf(getPadToMultipleOf());
+
+    // Apply padding.
+    SmallVector<tensor::PadOp> newPadOps;
+    FailureOr<TilingInterface> maybePaddedOp =
+        rewriteAsPaddedOp(rewriter, targetOp, options, newPadOps);
+    if (failed(maybePaddedOp)) {
+      auto diag = emitSilenceableError() << "failed to pad op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Set transform results.
+    paddedOps.push_back(maybePaddedOp->getOperation());
+    padOps.append(newPadOps.begin(), newPadOps.end());
+  }
+
+  results.set(cast<OpResult>(getPadded()), paddedOps);
+  results.set(cast<OpResult>(getPad()), padOps);
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::PadTilingInterfaceOp::verify() {
+  // Negative iterator dimensions are rejected.
+  SmallVector<int64_t> dims =
+      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
+  bool hasNegativeDim = llvm::any_of(dims, [](int64_t dim) { return dim < 0; });
+  if (hasNegativeDim)
+    return emitOpError() << "expects padding_dimensions to contain positive "
+                            "integers, found "
+                         << getPaddingDimensions();
+  // One (static or dynamic) padding size is required per padding dimension.
+  if (getMixedPaddingSizes().size() != dims.size())
+    return emitOpError() << "expects as many multiples as padding_dimensions";
+  return success();
+}
+
 //===---------------------------------------------------------------------===//
 // HoistPadOp
 //===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 881d9fcb4f52e..69e6fdabf9a58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   BlockPackMatmul.cpp
   PackAndUnpackPatterns.cpp
   Padding.cpp
+  PadTilingInterface.cpp
   Promotion.cpp
   RuntimeOpVerification.cpp
   Specialize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
new file mode 100644
index 0000000000000..a9d7bc64f2a6b
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -0,0 +1,322 @@
+//===- PadTilingInterface.cpp - Padding of TilingInterface ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+
+#define DEBUG_TYPE "pad-tiling-interface"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::tensor;
+
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
+/// Form a padding specification that covers every dimension ("full-rank").
+static llvm::SmallDenseMap<int64_t, OpFoldResult>
+getDimsToSize(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
+              const PadTilingInterfaceOptions &options) {
+  const int64_t numDims = indexingSizes.size();
+  llvm::SmallDenseMap<int64_t, OpFoldResult> sizePerDim;
+  // Start from the explicitly requested (dimension, size) pairs.
+  for (const auto &[dim, size] :
+       llvm::zip_equal(options.paddingDimensions, options.paddingSizes))
+    sizePerDim[dim] = size;
+  // Fill in every dimension that was left unspecified with either:
+  //   - 1 if we are padding to the next multiple of.
+  //   - indexingSizes[dim] otherwise.
+  for (int64_t dim = 0; dim < numDims; ++dim) {
+    if (sizePerDim.count(dim) != 0)
+      continue;
+    sizePerDim[dim] =
+        options.padToMultipleOf ? b.getIndexAttr(1) : indexingSizes[dim];
+  }
+  // Dump of the completed specification, emitted through LLVM_DEBUG.
+  for (int64_t dim = 0; dim < numDims; ++dim)
+    LLVM_DEBUG(DBGS() << "----idx: " << dim << " : " << sizePerDim[dim]
+                      << "\n");
+  return sizePerDim;
+}
+
+/// Compute the padded shape of the given value `v` of `RankedTensorType` given
+///   - `indexingSizes` a list of OpFoldResult.
+///   - an `indexingMap` that encodes how the shape of varies with increases
+///     in `indexingSizes`.
+/// The `indexingMap` encodes how the shape of varies w...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/144991


More information about the Mlir-commits mailing list