[Mlir-commits] [mlir] [mlir][Transforms] Add a PadTilingInterface transformation and hook i… (PR #144991)

Nicolas Vasilache llvmlistbot at llvm.org
Fri Jun 20 00:32:01 PDT 2025


https://github.com/nicolasvasilache created https://github.com/llvm/llvm-project/pull/144991

…t up to the transform dialect

This revision revisits the padding transformation from first principles and prepares it to work more generally with TilingInterface.

Compared to transform.structured.pad, it has the following differences:
  - no support for nofold, copy-back, transpose and hoisting: these were carried by the padding op in the very early days of StructuredOps and have since been separated out as independent transformations that compose.
  - no conflated static bounding box derivation attempts: pad_tiling_interface composes more naturally with or without tiling.
  - properly derives the padding size on outputs where multiple dimensions contribute: this is not supported in transform.structured.pad.
  - geared towards supporting TilingInterface once the proper control mechanisms are supported through a PadSizeComputationFunction (supports LinalgOp by default).

This will gradually replace transform.structured.pad as it is fleshed out and tested more comprehensively.
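
To make the padding-size derivation described above concrete, here is a minimal standalone C++ sketch of the arithmetic, restricted to static sizes. It is an illustration under stated assumptions, not the code in this patch: PadOptionsSketch, roundUpToMultiple, paddedIteratorSize and paddedOperandDim are hypothetical names, and an operand dimension is assumed to be indexed by a plain sum of iterator dimensions (e.g. d0 + d1). The patch itself operates on OpFoldResult values and emits affine.apply ops instead of computing integers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, static-only stand-in for PadTilingInterfaceOptions.
struct PadOptionsSketch {
  std::vector<int64_t> paddingDimensions; // iterator dimensions to pad
  std::vector<int64_t> paddingSizes;      // one size per padding dimension
  bool padToMultipleOf = false;           // sizes are multiples if true
};

// Round `size` up to the next multiple of `multiple` (assumes multiple > 0).
int64_t roundUpToMultiple(int64_t size, int64_t multiple) {
  return (size + multiple - 1) / multiple * multiple;
}

// Padded size of iterator dimension `dim` whose current size is `size`.
// Dimensions that are not listed in `paddingDimensions` keep their size,
// which matches padding to a multiple of 1 or to the size itself.
int64_t paddedIteratorSize(const PadOptionsSketch &options, int64_t dim,
                           int64_t size) {
  for (std::size_t i = 0; i < options.paddingDimensions.size(); ++i) {
    if (options.paddingDimensions[i] != dim)
      continue;
    return options.padToMultipleOf
               ? roundUpToMultiple(size, options.paddingSizes[i])
               : options.paddingSizes[i];
  }
  return size;
}

// Padded size of an operand dimension indexed by the sum of the iterator
// dimensions in `contributingDims`: each contribution is padded on its own
// and the padded terms are summed.
int64_t paddedOperandDim(const PadOptionsSketch &options,
                         const std::vector<int64_t> &iteratorSizes,
                         const std::vector<int64_t> &contributingDims) {
  int64_t sum = 0;
  for (int64_t dim : contributingDims)
    sum += paddedIteratorSize(options, dim, iteratorSizes[dim]);
  return sum;
}
```

For example, with iterator sizes {5, 7, 20}, padding dimensions {0, 2} and padding sizes {8, 16} in pad_to_multiple_of mode, dimension 0 pads to 8, dimension 2 pads to 32, and an operand dimension indexed by d0 + d1 pads to 8 + 7 = 15.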

>From d7d67e9f9b5641f14894bd64540cef678771613d Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Wed, 18 Jun 2025 16:39:14 +0200
Subject: [PATCH] [mlir][Transforms] Add a PadTilingInterface transformation
 and hook it up to the transform dialect

This revision revisits the padding transformation from first principles and prepares it to work more generally with TilingInterface.

Compared to transform.structured.pad, it has the following differences:
  - no support for nofold, copy-back, transpose and hoisting: these were carried by the padding op in the very early days of StructuredOps
    and have since been separated out as independent transformations that compose.
  - no conflated static bounding box derivation attempts: pad_tiling_interface composes more naturally with or without tiling.
  - properly derives the padding size on outputs where multiple dimensions contribute: this is not supported in transform.structured.pad.
  - geared towards supporting TilingInterface once the proper control mechanisms are supported through a PadSizeComputationFunction (supports LinalgOp by default).

This will gradually replace transform.structured.pad as it is fleshed out and tested more comprehensively.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  76 +++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  79 ++++-
 .../TransformOps/LinalgTransformOps.cpp       | 161 +++++++++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/PadTilingInterface.cpp  | 322 ++++++++++++++++++
 .../transform-op-pad-tiling-interface.mlir    |  73 ++++
 6 files changed, 709 insertions(+), 3 deletions(-)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
 create mode 100644 mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6f6df350f1ba6..c00151dea54a6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1186,6 +1186,82 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PadTilingInterfaceOp
+//===----------------------------------------------------------------------===//
+
+def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interface",
+    [FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Pads the operations pointed to by the target handle using the options
+    provided as operation attributes. The operation returns a handle to the
+    padded operation and to the padding operation ("tensor.pad").
+
+    #### Return modes
+
+    This operation ignores non-Linalg ops and drops them in the return.
+    In the future, this operation will support all TilingInterfaceOps.
+
+    This operation may produce a definite failure if the padding fails for any
+    reason.
+
+    If all the operations referred to by the `target` handle pad properly, the
+    transform succeeds. Otherwise the transform produces a silenceable failure.
+    The return handle points to only the subset of successfully produced
+    padded operations, which can be empty.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$target,
+         DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
+         Variadic<TransformAnyParamTypeOrAnyHandle>:$padding_sizes,
+         DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
+            $static_padding_sizes,
+         DefaultValuedAttr<UnitAttr, "false">:$pad_to_multiple_of);
+  let results = (outs TransformHandleTypeInterface:$padded,
+                      TransformHandleTypeInterface:$pad);
+
+  let assemblyFormat = [{
+    $target
+    `to`
+    (`padding_sizes` custom<DynamicIndexList>($padding_sizes, $static_padding_sizes)^)?
+    (`pad_to_multiple_of` $pad_to_multiple_of^)?
+    attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+
+  let builders = [
+    // Builder for a transform::PadTilingInterfaceOp with automatic inference
+    // of the padding value. Warning: this sets the value 0 for the inferred
+    // elemental type without taking the op into account and thus only works
+    // for the add/mul ring at the moment.
+    // TODO: support other operations (e.g. min, max etc).
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$paddingDimensions,
+                   CArg<"ArrayRef<int64_t>", "{}">:$staticPaddingSizes,
+                   CArg<"bool", "false">:$padToMultipleOf)>,
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$paddingDimensions,
+                   "ArrayRef<OpFoldResult>":$mixedPadPaddingSizes,
+                   CArg<"bool", "false">:$padToMultipleOf)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns a mix of dynamic `padding_sizes` and static `static_padding_sizes`.
+    SmallVector<OpFoldResult> getMixedPaddingSizes();
+
+    ::mlir::DiagnosedSilenceableFailure apply(
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::transform::TransformResults &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // HoistPadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 147a2907f52e4..59b7fdeef10b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -347,6 +348,34 @@ struct LinalgPaddingOptions {
   }
 };
 
+struct PadTilingInterfaceOptions {
+  /// A padding value for every operand.
+  SmallVector<Attribute> paddingValues;
+  PadTilingInterfaceOptions &setPaddingValues(ArrayRef<Attribute> pv) {
+    paddingValues.assign(pv.begin(), pv.end());
+    return *this;
+  }
+  /// A list of iterator dimensions to pad.
+  SmallVector<int64_t> paddingDimensions;
+  PadTilingInterfaceOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
+    paddingDimensions.assign(pd.begin(), pd.end());
+    return *this;
+  }
+  /// A list of iterator dimension sizes to pad to.
+  SmallVector<OpFoldResult> paddingSizes;
+  PadTilingInterfaceOptions &setPaddingSizes(ArrayRef<OpFoldResult> m) {
+    paddingSizes.assign(m.begin(), m.end());
+    return *this;
+  }
+  /// Pad iterator `paddingDimensions[i]` to next multiple of `paddingSizes[i]`
+  /// if true. Otherwise pad to `paddingSizes[i]`.
+  bool padToMultipleOf = false;
+  PadTilingInterfaceOptions &setPadToMultipleOf(bool b) {
+    padToMultipleOf = b;
+    return *this;
+  }
+};
+
 /// Callback function type used to perform the allocation for the promoted
 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
 /// smallest constant value for the size of the buffer needed for each
@@ -542,9 +571,9 @@ SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
 /// where relevant.
 void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
 
-/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
-/// to a static bounding box. The original `opToPad` is cloned and operates on
-/// the padded tensors.
+/// Pad the iterator dimensions `options.paddingDimensions` of all `opToPad`
+/// operands to a static bounding box. The original `opToPad` is cloned and
+/// operates on the padded tensors.
 ///
 /// * "options.padToMultipleOf" indicates that each padding dimension should be
 ///   padded to the specified multiple.
@@ -561,6 +590,50 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                                 SmallVector<Value> &replacements,
                                 SmallVector<tensor::PadOp> &padOps);
 
+/// Helper function to compute the padded shape of the given value `v` of
+/// `RankedTensorType` given:
+///   - the `indexingSizes` as a list of OpFoldResult.
+///   - an `indexingMap` that encodes how the padded shape varies with
+///     increases in `indexingSizes`.
+/// The implementation iteratively combines increases from contributing
+/// dimensions using
+/// affine.apply operations.
+/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
+/// provides a gentle portability path for Linalg-like ops with affine maps.
+/// In the future, more general interfaces can be devised to encode similar
+/// shape evolutions and map between an op and its operands.
+SmallVector<OpFoldResult>
+computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+                   AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
+                   const PadTilingInterfaceOptions &options);
+
+using PadSizeComputationFunction =
+    std::function<FailureOr<SmallVector<OpFoldResult>>(
+        RewriterBase &, OpOperand &, ArrayRef<Range>,
+        const PadTilingInterfaceOptions &)>;
+
+/// Specific helper for Linalg ops.
+FailureOr<SmallVector<OpFoldResult>>
+computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
+                         ArrayRef<Range> iterationDomain,
+                         const PadTilingInterfaceOptions &options);
+
+/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
+///
+/// * "options.paddingSizes" indicates that each padding dimension should be
+///   padded to the specified padding size.
+/// * "options.padToMultipleOf" indicates that the paddingSizes should be
+///   interpreted as multiples to round each padded dimension up to.
+/// * Use "options.paddingValues" to set the padding value of the created
+///   tensor::PadOp.
+/// * The padded op is returned on success and the created tensor::PadOps
+///   are appended to `padOps`.
+FailureOr<TilingInterface>
+rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
+                  const PadTilingInterfaceOptions &constOptions,
+                  SmallVector<tensor::PadOp> &padOps,
+                  PadSizeComputationFunction computePaddingSizeFun =
+                      &computeLinalgPaddedShape);
+
 namespace detail {
 
 /// Helper struct to hold the results of building a packing loop nest.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d78c8847f8843..bc8ff8e55bf0f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -45,6 +45,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include <type_traits>
 
 using namespace mlir;
@@ -2155,6 +2156,166 @@ LogicalResult transform::PadOp::verify() {
   return success();
 }
 
+//===---------------------------------------------------------------------===//
+// PadTilingInterfaceOp
+//===---------------------------------------------------------------------===//
+
+void transform::PadTilingInterfaceOp::build(OpBuilder &b,
+                                            OperationState &result,
+                                            Value target,
+                                            ArrayRef<int64_t> paddingDimensions,
+                                            ArrayRef<int64_t> paddingSizes,
+                                            bool padToMultipleOf) {
+  auto resultType = transform::AnyOpType::get(b.getContext());
+  return build(/*builder=*/b,
+               /*result=*/result,
+               /*types=*/TypeRange{resultType, resultType},
+               /*target=*/target,
+               /*paddingValues=*/ArrayAttr(), // let inference handle this
+               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+               /*paddingSizes=*/ValueRange{},
+               /*staticPaddingSizes=*/
+               (paddingSizes.empty() ? DenseI64ArrayAttr()
+                                     : b.getDenseI64ArrayAttr(paddingSizes)),
+               /*padToMultipleOf=*/
+               padToMultipleOf ? b.getUnitAttr() : nullptr);
+}
+
+void transform::PadTilingInterfaceOp::build(
+    OpBuilder &b, OperationState &result, Value target,
+    ArrayRef<int64_t> paddingDimensions,
+    ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
+  auto resultType = transform::AnyOpType::get(b.getContext());
+  SmallVector<int64_t> staticPaddingSizes;
+  SmallVector<Value> dynamicPaddingSizes;
+  dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
+                             staticPaddingSizes);
+  return build(/*builder=*/b,
+               /*result=*/result,
+               /*types=*/TypeRange{resultType, resultType},
+               /*target=*/target,
+               /*paddingValues=*/ArrayAttr(), // let inference handle this
+               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+               /*paddingSizes=*/dynamicPaddingSizes,
+               /*staticPaddingSizes=*/staticPaddingSizes,
+               /*padToMultipleOf=*/padToMultipleOf);
+}
+
+void transform::PadTilingInterfaceOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getPaddingSizesMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
+SmallVector<OpFoldResult>
+transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
+  Builder b(getContext());
+  return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
+}
+
+DiagnosedSilenceableFailure
+transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  SmallVector<Operation *> paddedOps, padOps;
+
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    auto targetOp = dyn_cast<TilingInterface>(target);
+    if (!targetOp) {
+      auto diag = emitSilenceableError() << "expected TilingInterface target";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
+    // map / C++ APIs to compute the effect of padding on operands.
+    if (!isa<LinalgOp>(targetOp.getOperation())) {
+      auto diag = emitSilenceableError() << "only LinalgOp supported atm";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Convert the padding values to attributes.
+    SmallVector<Attribute> paddingValues;
+    for (auto const &it :
+         llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
+      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+      if (!attr) {
+        emitOpError("expects padding values to be typed attributes");
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      Type elementType = getElementTypeOrSelf(std::get<1>(it));
+      // Try to parse string attributes to obtain an attribute of element type.
+      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
+        auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
+            stringAttr, getContext(), elementType,
+            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
+        if (!parsedAttr || parsedAttr.getType() != elementType) {
+          auto diag = this->emitOpError("expects a padding that parses to ")
+                      << elementType << ", got " << std::get<0>(it);
+          diag.attachNote(targetOp.getLoc()) << "when applied to this op";
+          return DiagnosedSilenceableFailure::definiteFailure();
+        }
+        paddingValues.push_back(parsedAttr);
+        continue;
+      }
+      // Otherwise, add the attribute directly.
+      if (attr.getType() != elementType) {
+        auto diag = this->emitOpError("expects a padding value of type ")
+                    << elementType << ", got " << attr;
+        diag.attachNote(targetOp.getLoc()) << "when applied to this op";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+      paddingValues.push_back(attr);
+    }
+
+    // Set options.
+    TilingInterface paddedOp;
+    PadTilingInterfaceOptions options;
+    options.setPaddingValues(paddingValues)
+        .setPaddingDimensions(
+            extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()))
+        .setPaddingSizes(getMixedPaddingSizes())
+        .setPadToMultipleOf(getPadToMultipleOf());
+
+    // Apply padding.
+    SmallVector<tensor::PadOp> newPadOps;
+    FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
+        rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
+        newPadOps);
+    if (failed(maybePaddedOp)) {
+      auto diag = emitSilenceableError() << "failed to pad op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Set transform results.
+    paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
+    padOps.append(newPadOps.begin(), newPadOps.end());
+  }
+
+  results.set(cast<OpResult>(getPadded()), paddedOps);
+  results.set(cast<OpResult>(getPad()), padOps);
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::PadTilingInterfaceOp::verify() {
+  SmallVector<int64_t> paddingDimensions =
+      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
+  if (any_of(paddingDimensions,
+             [](int64_t paddingDimension) { return paddingDimension < 0; })) {
+    return emitOpError() << "expects padding_dimensions to contain non-negative "
+                            "integers, found "
+                         << getPaddingDimensions();
+  }
+  if (getMixedPaddingSizes().size() != paddingDimensions.size()) {
+    return emitOpError() << "expects as many multiples as padding_dimensions";
+  }
+  return success();
+}
+
 //===---------------------------------------------------------------------===//
 // HoistPadOp
 //===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 881d9fcb4f52e..69e6fdabf9a58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   BlockPackMatmul.cpp
   PackAndUnpackPatterns.cpp
   Padding.cpp
+  PadTilingInterface.cpp
   Promotion.cpp
   RuntimeOpVerification.cpp
   Specialize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
new file mode 100644
index 0000000000000..a9d7bc64f2a6b
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -0,0 +1,322 @@
+//===- PadTilingInterface.cpp - Padding of TilingInterface ops ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+
+#define DEBUG_TYPE "pad-tiling-interface"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::tensor;
+
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
+/// Form a "full-rank" padding specification so that the application is easy.
+static llvm::SmallDenseMap<int64_t, OpFoldResult>
+getDimsToSize(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
+              const PadTilingInterfaceOptions &options) {
+  llvm::SmallDenseMap<int64_t, OpFoldResult> dimsToSize;
+  for (const auto &[paddingDim, paddingSize] :
+       llvm::zip_equal(options.paddingDimensions, options.paddingSizes)) {
+    dimsToSize[paddingDim] = paddingSize;
+  }
+  // Complete the padding specification to specify all dimensions.
+  for (int64_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
+    if (dimsToSize.find(idx) != dimsToSize.end())
+      continue;
+    // If a dimension is not specified, either complete with:
+    //   - 1 if we are padding to the next multiple of.
+    //   - indexingSizes[idx] otherwise
+    dimsToSize[idx] =
+        options.padToMultipleOf ? b.getIndexAttr(1) : indexingSizes[idx];
+  }
+  for (int64_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
+    LLVM_DEBUG(DBGS() << "----idx: " << idx << " : " << dimsToSize[idx]
+                      << "\n");
+  }
+  return dimsToSize;
+}
+
+/// Compute the padded shape of the given value `v` of `RankedTensorType` given
+///   - `indexingSizes` a list of OpFoldResult.
+///   - an `indexingMap` that encodes how the shape of `v` varies with
+///     increases in `indexingSizes`.
+/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
+/// provides a gentle portability path for Linalg-like ops with affine maps.
+/// The implementation below iteratively combines increases from contributing
+/// dimensions using affine.apply operations.
+/// In the future, more general interfaces can be devised to encode similar
+/// shape evolutions and map between an op and its operands.
+SmallVector<OpFoldResult> linalg::computePaddedShape(
+    RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
+    const PadTilingInterfaceOptions &options) {
+  Location loc = v.getLoc();
+  SmallVector<OpFoldResult> paddedShape;
+  auto tensorType = cast<RankedTensorType>(v.getType());
+  paddedShape.resize_for_overwrite(tensorType.getRank());
+  assert(tensorType.getRank() == indexingMap.getNumResults() &&
+         "expect the number of results of the affine map to match the tensor "
+         "rank");
+
+  // "Full-rank" padding specification.
+  llvm::SmallDenseMap<int64_t, OpFoldResult> dimsToSize =
+      getDimsToSize(rewriter, indexingSizes, options);
+
+  // For each dimension in the operand's shape, iterate over indexingSizes and
+  // add the contributions of the padding dimensions it depends on.
+  for (const auto &enResults : enumerate(indexingMap.getResults())) {
+    int64_t resultIndex = enResults.index();
+    AffineMap partialIndexingMap = indexingMap.getSubMap(
+        ArrayRef<unsigned>{static_cast<unsigned>(resultIndex)});
+
+    LLVM_DEBUG(DBGS() << "----resultIndex: " << resultIndex
+                      << " with partialIndexingMap: " << partialIndexingMap
+                      << "\n");
+
+    // Find all padding dimensions that contribute to this operand dimension
+    // and compute the padded term contribution to the final padded shape.
+    SmallVector<OpFoldResult> terms;
+    for (const auto &[paddingDim, paddingSize] : dimsToSize) {
+      LLVM_DEBUG(DBGS() << "------try apply padding of dim: " << paddingDim
+                        << " to: " << paddingSize << "\n");
+      if (!enResults.value().isFunctionOfDim(paddingDim))
+        continue;
+
+      LLVM_DEBUG(DBGS() << "------apply padding of dim: " << paddingDim
+                        << " to: " << paddingSize << "\n");
+
+      // Project non-'paddingDim' dimensions and compress the result.
+      llvm::SmallBitVector projectedDims(partialIndexingMap.getNumDims(), true);
+      projectedDims.flip(paddingDim);
+      AffineMap projectedMap =
+          mlir::projectDims(partialIndexingMap, projectedDims,
+                            /*compressDims=*/true);
+
+      // If padding to the next multiple, compose with ceilDiv(d0, s0) * s0.
+      if (options.padToMultipleOf) {
+        AffineExpr d0, s0;
+        bindDims(rewriter.getContext(), d0);
+        bindSymbols(rewriter.getContext(), s0);
+        AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
+        AffineMap composedMap = projectedMap.compose(ceilMap);
+        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, composedMap,
+            {indexingSizes[paddingDim], paddingSize});
+        terms.push_back(paddingDimOfr);
+      } else {
+        // Otherwise just set to paddingSize.
+        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, projectedMap, paddingSize);
+        terms.push_back(paddingDimOfr);
+      }
+
+      LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
+    }
+
+    // If there are no terms, just return the dim.
+    if (terms.empty()) {
+      paddedShape[resultIndex] =
+          createFoldedDimOp(rewriter, loc, v, resultIndex);
+      continue;
+    }
+
+    // Sum individual terms' contributions.
+    SmallVector<AffineExpr> dims(terms.size());
+    bindDimsList(rewriter.getContext(), MutableArrayRef{dims});
+    AffineExpr sumExpr = dims.front();
+    for (unsigned i = 1; i < dims.size(); ++i)
+      sumExpr = sumExpr + dims[i];
+    OpFoldResult paddedDimOfr =
+        affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
+    paddedShape[resultIndex] = paddedDimOfr;
+  }
+
+  return paddedShape;
+}
+
+FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
+    RewriterBase &rewriter, OpOperand &operandToPad,
+    ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
+  auto linalgOp = llvm::dyn_cast<LinalgOp>(operandToPad.getOwner());
+  if (!linalgOp)
+    return failure();
+
+  // clang-format off
+  assert(llvm::all_of(iterationDomain, [&rewriter](Range r) {
+    return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) &&
+    r.stride == OpFoldResult(rewriter.getIndexAttr(1));
+  }) && "expected 0-offset 1-stride loop ranges");
+  // clang-format on
+  SmallVector<OpFoldResult> loopUpperBounds;
+  loopUpperBounds.reserve(iterationDomain.size());
+  for (const Range &range : iterationDomain)
+    loopUpperBounds.push_back(range.size);
+
+  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&operandToPad);
+  return computePaddedShape(
+      rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
+      indexingMap, loopUpperBounds, options);
+}
+
+/// Pad a single operand to `paddedShape` using `paddingValueAttr` as the
+/// padding value.
+static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
+                        TypedValue<RankedTensorType> v,
+                        ArrayRef<OpFoldResult> paddedShape,
+                        Attribute paddingValueAttr) {
+  Value paddingValue;
+  if (auto complexTy =
+          dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
+    auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
+    paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
+                                                        complexTy, complexAttr);
+  } else {
+    paddingValue = rewriter.create<arith::ConstantOp>(
+        opToPad.getLoc(), cast<TypedAttr>(paddingValueAttr));
+  }
+
+  // Pad the operand to the bounding box defined by `paddedShape`.
+  SmallVector<int64_t> tensorShape;
+  SmallVector<Value> dynDims;
+  for (OpFoldResult ofr : paddedShape) {
+    std::optional<int64_t> cst = getConstantIntValue(ofr);
+    tensorShape.push_back(cst.has_value() ? *cst : ShapedType::kDynamic);
+    if (!cst.has_value())
+      dynDims.push_back(ofr.dyn_cast<Value>());
+  }
+  // TODO: use dispatchIndexOpFoldResults(paddedShape, dynDims, paddedShape);
+
+  auto paddedTensorType =
+      RankedTensorType::get(tensorShape, getElementTypeOrSelf(v));
+  LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
+                    << paddedTensorType << "\n");
+  return makeComposedPadHighOp(rewriter, opToPad.getLoc(), paddedTensorType, v,
+                               paddingValue, /*nofold=*/false, dynDims);
+}
+
+FailureOr<TilingInterface>
+linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
+                          const PadTilingInterfaceOptions &constOptions,
+                          SmallVector<tensor::PadOp> &padOps,
+                          PadSizeComputationFunction computePaddingSizeFun) {
+  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
+  assert(constOptions.paddingSizes.size() ==
+             constOptions.paddingDimensions.size() &&
+         "invalid number of elements in paddingSizes");
+
+  Location loc = opToPad.getLoc();
+  PadTilingInterfaceOptions options(constOptions);
+  // Allow inference of pad values if they are not explicitly specified.
+  // TODO: be mindful about the value depending on the actual operation.
+  if (options.paddingValues.empty()) {
+    SmallVector<Type> types(opToPad->getOperandTypes());
+    llvm::append_range(types, opToPad->getResultTypes());
+    for (Type t : types) {
+      options.paddingValues.push_back(
+          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+    }
+  }
+
+  if (llvm::any_of(opToPad->getOperands(),
+                   [](Value v) { return isa<MemRefType>(v.getType()); })) {
+    return rewriter.notifyMatchFailure(opToPad,
+                                       "expected operation on tensors");
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // Set IP after opToPad because we also take the dims of opToPad's output.
+  rewriter.setInsertionPointAfter(opToPad);
+
+  // 1. Get the loopUpperBounds from the TilingInterface.
+  SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);
+
+  // 2. For each operand.
+  SmallVector<Value> newOperands;
+  newOperands.reserve(opToPad->getNumOperands());
+  for (OpOperand &opOperand : opToPad->getOpOperands()) {
+    LLVM_DEBUG(DBGS() << "--start padding oprd: " << opOperand.get() << "\n");
+    // 2.a. Compute padded shape.
+    FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
+        computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
+    if (failed(maybePaddedShape)) {
+      return rewriter.notifyMatchFailure(opToPad, "could not pad op");
+    }
+
+    // 2.b. Expect proper `paddingValues`.
+    // TODO: we may want to allow garbage padding in the future, in which case
+    // we would just not assert.
+    assert(opOperand.getOperandNumber() < options.paddingValues.size() &&
+           "--no padding value specified");
+    Attribute paddingValueAttr =
+        options.paddingValues[opOperand.getOperandNumber()];
+
+    // 2.c. Perform actual padding.
+    Value paddedOperand = padOperand(
+        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(opOperand.get()),
+        *maybePaddedShape, paddingValueAttr);
+    LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
+
+    // 2.d. Record the padded operand and its tensor.pad op, if any.
+    newOperands.push_back(paddedOperand);
+    if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
+      padOps.push_back(padOp);
+  }
+
+  // 3. Form the resulting tensor::ExtractSliceOp.
+  ReifiedRankedShapedTypeDims reifiedResultShapes;
+  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
+    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
+    return rewriter.notifyMatchFailure(opToPad,
+                                       "failed to reify result shapes");
+  }
+  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+         "expected same number of results");
+
+  // Clone `opToPad` to operate on the statically padded shapes.
+  auto resultTensorTypes =
+      ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
+  // clone **should** properly notify the rewriter.
+  TilingInterface paddedOp =
+      clone(rewriter, opToPad, resultTensorTypes, newOperands);
+  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
+
+  // Recover the slice out of the new static results. This keeps the original
+  // opToPad around because it uses the dims of the original results.
+  SmallVector<Value> paddedSubtensorResults;
+  paddedSubtensorResults.reserve(opToPad->getNumResults());
+  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
+    Value paddedResult = en.value();
+    int64_t resultNumber = en.index();
+    int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
+        loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+        strides));
+  }
+
+  rewriter.replaceOp(opToPad, paddedSubtensorResults);
+
+  return paddedOp;
+}
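
The rewrite above follows a pad / compute / extract_slice pattern: each operand is padded high, the op is cloned onto the padded operands, and the original-sized result is sliced back out. A minimal standalone sketch of that round trip, assuming a toy row-major `Matrix` type and zero padding (the names `Matrix`, `padHigh`, `sliceLeading`, `matmul` and `paddedMatmul` are hypothetical and not part of the patch or of MLIR):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy row-major matrix; a stand-in for a ranked tensor, not an MLIR type.
struct Matrix {
  std::size_t rows, cols;
  std::vector<float> data;
  Matrix(std::size_t r, std::size_t c, float init = 0.0f)
      : rows(r), cols(c), data(r * c, init) {}
  float &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  float at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Analogue of a high-only tensor.pad: grow `m` to rows x cols and fill the
// newly created elements with `padValue`.
Matrix padHigh(const Matrix &m, std::size_t rows, std::size_t cols,
               float padValue) {
  assert(rows >= m.rows && cols >= m.cols && "padding cannot shrink");
  Matrix out(rows, cols, padValue);
  for (std::size_t i = 0; i < m.rows; ++i)
    for (std::size_t j = 0; j < m.cols; ++j)
      out.at(i, j) = m.at(i, j);
  return out;
}

// Analogue of tensor.extract_slice with zero offsets and unit strides.
Matrix sliceLeading(const Matrix &m, std::size_t rows, std::size_t cols) {
  assert(rows <= m.rows && cols <= m.cols && "slice out of bounds");
  Matrix out(rows, cols);
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      out.at(i, j) = m.at(i, j);
  return out;
}

// Returns c + a * b, loosely mirroring linalg.matmul with an init operand.
Matrix matmul(const Matrix &a, const Matrix &b, Matrix c) {
  assert(a.cols == b.rows && c.rows == a.rows && c.cols == b.cols);
  for (std::size_t i = 0; i < a.rows; ++i)
    for (std::size_t j = 0; j < b.cols; ++j)
      for (std::size_t k = 0; k < a.cols; ++k)
        c.at(i, j) += a.at(i, k) * b.at(k, j);
  return c;
}

// Pad the row dimension of the LHS and of the init operand, compute on the
// padded shapes, then slice the original-sized result back out.
Matrix paddedMatmul(const Matrix &a, const Matrix &b, const Matrix &c,
                    std::size_t paddedRows) {
  Matrix paddedA = padHigh(a, paddedRows, a.cols, 0.0f);
  Matrix paddedC = padHigh(c, paddedRows, c.cols, 0.0f);
  Matrix paddedResult = matmul(paddedA, b, paddedC);
  return sliceLeading(paddedResult, c.rows, c.cols);
}
```

In this sketch, padding only the row dimension leaves the first `c.rows` rows of the result unchanged, so `paddedMatmul(a, b, c, 8)` on a 5-row problem should agree with `matmul(a, b, c)`. That loosely mirrors the `pad_lhs` test, where a tile of 5 rows is padded to 8.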
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
new file mode 100644
index 0000000000000..ee9d2473f81fd
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt --transform-interpreter -canonicalize -split-input-file --verify-diagnostics %s | FileCheck %s
+
+//     CHECK-LABEL: pad_lhs
+func.func @pad_lhs(
+  %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
+     -> tensor<24x25xf32>
+{
+  //      CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>)
+  //      CHECK:   tensor.pad %{{.*}} 
+  //      CHECK:     : tensor<?x12xf32> to tensor<8x12xf32>
+  //      CHECK:   tensor.pad %{{.*}} 
+  //      CHECK:     : tensor<?x25xf32> to tensor<8x25xf32>
+  //      CHECK:   linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x12xf32>, tensor<12x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
+  //      CHECK:   tensor.extract_slice %{{.*}}[0, 0] [%{{.*}}, 25] [1, 1]
+  //      CHECK:     : tensor<8x25xf32> to tensor<?x25xf32>
+  //      CHECK:   tensor.insert_slice %{{.*}} into %{{.*}}[%{{.*}}, 0] [%{{.*}}, 25] [1, 1]
+  // CHECK-SAME:     : tensor<?x25xf32> into tensor<24x25xf32>
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+
+    // Tile to 5 then pad to 8 (supposedly to better hit vector ops).
+    %matmul_l1, %loops_l1 = transform.structured.tile_using_for %matmul tile_sizes [5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %matmul_padded, %_ = transform.structured.pad_tiling_interface %matmul_l1 to padding_sizes [8] {
+      padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+      padding_dimensions=[0]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d0 + d1)>
+module {
+
+// CHECK-LABEL: @generic
+// CHECK-SAME:      %[[T0:.*]]: tensor<7x5xf32>,
+// CHECK-SAME:      %[[T1:.*]]: tensor<7x11x12xf32>)
+  func.func @generic(%arg0: tensor<7x5xf32>, %arg1: tensor<7x11x12xf32>) -> tensor<7x11x12xf32> {
+
+  //  CHECK-DAG: %[[CST:.*]] = arith.constant 0.
+
+  //      CHECK: %[[PAD0:.*]] = tensor.pad %[[T0]] low[0, 0] high[1, 0]
+  //      CHECK:   : tensor<7x5xf32> to tensor<8x5xf32>
+  //      CHECK: %[[PAD1:.*]] = tensor.pad %[[T1]] low[0, 0, 0] high[1, 3, 1] {
+  //      CHECK:   : tensor<7x11x12xf32> to tensor<8x14x13xf32>
+  // CHECK-NEXT: linalg.generic
+  //      CHECK: tensor.extract_slice %{{.*}}[0, 0, 0] [7, 11, 12] [1, 1, 1] : tensor<8x14x13xf32> to tensor<7x11x12xf32>
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<7x5xf32>) outs(%arg1 : tensor<7x11x12xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      linalg.yield %in : f32
+    } -> tensor<7x11x12xf32>
+    return %0 : tensor<7x11x12xf32>
+  }
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 14] {
+        padding_dimensions = [0, 2], 
+        padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
+      } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield 
+    }
+  }
+}
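
The padded shapes checked in the `@generic` test can be reproduced with a small arithmetic model. This is only a sketch and not the patch's `computePaddingSizeFun`. It assumes each padded iteration dimension is rounded up to a multiple of its padding size (for the sizes in this test that coincides with padding to exactly that size). It also assumes each indexing-map result is a sum of iteration dimensions with unit coefficients, whose extent is the sum of the contributing extents. The helper names `roundUp`, `SumOfDims`, `padIterationDomain` and `paddedOperandShape` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Round `size` up to the next multiple of `multiple` (both positive).
int64_t roundUp(int64_t size, int64_t multiple) {
  return (size + multiple - 1) / multiple * multiple;
}

// One result of a simplified indexing map: a sum of iteration dimensions with
// unit coefficients, e.g. {0, 1} encodes d0 + d1.
using SumOfDims = std::vector<unsigned>;

// Pad the selected iteration dimensions; all other dimensions keep their size.
std::vector<int64_t>
padIterationDomain(std::vector<int64_t> domain,
                   const std::vector<unsigned> &paddingDimensions,
                   const std::vector<int64_t> &paddingSizes) {
  assert(paddingDimensions.size() == paddingSizes.size());
  for (std::size_t i = 0; i < paddingDimensions.size(); ++i) {
    unsigned dim = paddingDimensions[i];
    assert(dim < domain.size() && "padding dimension out of range");
    domain[dim] = roundUp(domain[dim], paddingSizes[i]);
  }
  return domain;
}

// Derive an operand shape from the padded iteration domain. Each result
// extent is the sum of the extents of the dimensions that contribute to it.
std::vector<int64_t>
paddedOperandShape(const std::vector<SumOfDims> &indexingMap,
                   const std::vector<int64_t> &paddedDomain) {
  std::vector<int64_t> shape;
  shape.reserve(indexingMap.size());
  for (const SumOfDims &result : indexingMap) {
    int64_t extent = 0;
    for (unsigned dim : result)
      extent += paddedDomain[dim];
    shape.push_back(extent);
  }
  return shape;
}
```

With the iteration domain (7, 5, 11), `padding_dimensions = [0, 2]` and `padding_sizes [8, 14]`, the padded domain in this model is (8, 5, 14). The map `(d0, d1)` then gives 8x5, and `(d0, d2, d0 + d1)` gives 8x14x13, matching the `tensor.pad` result types in the CHECK lines.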


