[Mlir-commits] [mlir] [mlir][Transforms] Add a PadTilingInterface transformation and hook i… (PR #144991)

Kunwar Grover llvmlistbot at llvm.org
Fri Jun 20 01:32:16 PDT 2025


================
@@ -2155,6 +2156,166 @@ LogicalResult transform::PadOp::verify() {
   return success();
 }
 
+//===---------------------------------------------------------------------===//
+// PadTilingInterfaceOp
+//===---------------------------------------------------------------------===//
+
+/// Builds a `PadTilingInterfaceOp` whose padding sizes are all static.
+///
+/// \param b                 Builder used to create the attributes and types.
+/// \param result            Operation state filled in by the delegated builder.
+/// \param target            Handle to the payload ops to pad.
+/// \param paddingDimensions Dimensions to pad, forwarded as an I64 array
+///                          attribute.
+/// \param paddingSizes      Static padding sizes. When empty, a null attribute
+///                          is forwarded rather than an empty dense array.
+/// \param padToMultipleOf   When true, a unit attribute is forwarded for the
+///                          flag; otherwise a null attribute.
+///
+/// Both results are typed as `transform::AnyOpType`.
+void transform::PadTilingInterfaceOp::build(OpBuilder &b,
+                                            OperationState &result,
+                                            Value target,
+                                            ArrayRef<int64_t> paddingDimensions,
+                                            ArrayRef<int64_t> paddingSizes,
+                                            bool padToMultipleOf) {
+  auto resultType = transform::AnyOpType::get(b.getContext());
+  // This overload has no dynamic sizes: the size operand list stays empty and
+  // the sizes are carried by the static attribute only.
+  build(/*builder=*/b,
+        /*result=*/result,
+        /*types=*/TypeRange{resultType, resultType},
+        /*target=*/target,
+        /*paddingValues=*/ArrayAttr(), // let inference handle this
+        /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+        /*paddingSizes=*/ValueRange{},
+        /*staticPaddingSizes=*/
+        (paddingSizes.empty() ? DenseI64ArrayAttr()
+                              : b.getDenseI64ArrayAttr(paddingSizes)),
+        /*padToMultipleOf=*/
+        padToMultipleOf ? b.getUnitAttr() : nullptr);
+}
+
+/// Builds a `PadTilingInterfaceOp` from mixed (static and/or dynamic) padding
+/// sizes.
+///
+/// \param b                 Builder used to create the attributes and types.
+/// \param result            Operation state filled in by the delegated builder.
+/// \param target            Handle to the payload ops to pad.
+/// \param paddingDimensions Dimensions to pad, forwarded as an I64 array
+///                          attribute.
+/// \param mixedPaddingSizes Padding sizes given as `OpFoldResult`s; they are
+///                          split into dynamic values and static integers
+///                          before being forwarded.
+/// \param padToMultipleOf   Forwarded as-is to the delegated builder.
+///
+/// Both results are typed as `transform::AnyOpType`.
+void transform::PadTilingInterfaceOp::build(
+    OpBuilder &b, OperationState &result, Value target,
+    ArrayRef<int64_t> paddingDimensions,
+    ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
+  auto resultType = transform::AnyOpType::get(b.getContext());
+  // Separate the mixed sizes into the operand list (dynamic) and the
+  // attribute payload (static) expected by the delegated builder.
+  SmallVector<int64_t> staticPaddingSizes;
+  SmallVector<Value> dynamicPaddingSizes;
+  dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
+                             staticPaddingSizes);
+  build(/*builder=*/b,
+        /*result=*/result,
+        /*types=*/TypeRange{resultType, resultType},
+        /*target=*/target,
+        /*paddingValues=*/ArrayAttr(), // let inference handle this
+        /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+        /*paddingSizes=*/dynamicPaddingSizes,
+        /*staticPaddingSizes=*/staticPaddingSizes,
+        /*padToMultipleOf=*/padToMultipleOf);
+}
+
+/// Appends to `effects` the side effects this op declares on its operand
+/// handles, its result handles and the payload IR.
+void transform::PadTilingInterfaceOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // Operand handles: the target is consumed, whereas the dynamic padding-size
+  // handles are only read.
+  consumesHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getPaddingSizesMutable(), effects);
+
+  // Result handles are produced by this op, and the payload IR is modified.
+  Operation *thisOp = getOperation();
+  producesHandle(thisOp->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
+/// Returns the padding sizes as a single list of `OpFoldResult`s, obtained by
+/// combining the static padding sizes with the dynamic padding-size operands.
+SmallVector<OpFoldResult>
+transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
+  ArrayRef<int64_t> staticSizes = getStaticPaddingSizes();
+  ValueRange dynamicSizes = getPaddingSizes();
+  Builder builder(getContext());
+  return getMixedValues(staticSizes, dynamicSizes, builder);
+}
+
+DiagnosedSilenceableFailure
+transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
+                                       transform::TransformResults &results,
+                                       transform::TransformState &state) {
+  SmallVector<Operation *> paddedOps, padOps;
+
+  for (Operation *target : state.getPayloadOps(getTarget())) {
+    auto targetOp = dyn_cast<TilingInterface>(target);
+    if (!targetOp) {
+      auto diag = emitSilenceableError() << "expected TilingInterface target";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
+    // map / C++ APIs to compute the effect of padding on operands.
+    if (!isa<LinalgOp>(targetOp.getOperation())) {
+      auto diag = emitSilenceableError() << "only LinalgOp supported atm";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Convert the padding values to attributes.
+    SmallVector<Attribute> paddingValues;
+    for (auto const &it :
+         llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
----------------
Groverkss wrote:

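nit: use structured bindings in the loop header instead of unpacking the zipped tuple through `it`:

for (auto [attr, elementType] :
     llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {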

https://github.com/llvm/llvm-project/pull/144991


More information about the Mlir-commits mailing list