[Mlir-commits] [mlir] [mlir][transform] Add support for transform.param pad multiples in `PadOp` (PR #90755)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 1 17:54:50 PDT 2024


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/90755

From 6aa9f0db3a39a877a87245206c7562ed27f77657 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 12:20:02 -0500
Subject: [PATCH 1/5] Add support for transform.param values in `PadOp`'s
 pad_to_multiple_of

---
 .../Linalg/TransformOps/LinalgTransformOps.td |  17 ++-
 .../TransformOps/LinalgTransformOps.cpp       | 112 ++++++++++++++++--
 .../mlir/dialects/transform/structured.py     |  20 +++-
 .../test/Dialect/Linalg/transform-op-pad.mlir |   3 +-
 .../dialects/transform_structured_ext.py      |   5 +-
 5 files changed, 140 insertions(+), 17 deletions(-)
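
With this patch, pad_to_multiple_of moves out of the attribute dictionary and
into the assembly as a dynamic index list, so static multiples print inline
after the target. A minimal sketch of the resulting syntax, mirroring the
updated transform-op-pad.mlir test in this patch (%arg1 is the named
sequence's root handle):

  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
  %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [2, 2, 1] {
    padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
    padding_dimensions = [0, 1, 2],
    pack_paddings = [1, 1, 0]
  } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)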

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d0ad4ccdf031d9..16fb8f4fcc9466 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1011,7 +1011,9 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
     (ins TransformHandleTypeInterface:$target,
          DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
          DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
-         OptionalAttr<I64ArrayAttr>:$pad_to_multiple_of,
+         Variadic<TransformAnyParamTypeOrAnyHandle>:$pad_to_multiple_of,
+         DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
+                          $static_pad_to_multiple_of,
          DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
          DefaultValuedAttr<
           TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
@@ -1021,8 +1023,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
                       TransformHandleTypeInterface:$pad,
                       TransformHandleTypeInterface:$copy);
 
-  let assemblyFormat =
-    "$target attr-dict `:` functional-type(operands, results)";
+  let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 
   let builders = [
@@ -1033,7 +1034,13 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
     // TODO: support other operations (e.g. min, max etc).
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<int64_t>":$paddingDimensions,
-                   CArg<"ArrayRef<int64_t>", "{}">:$padToMultipleOf,
+                   CArg<"ArrayRef<int64_t>", "{}">:$staticPadToMultipleOf,
+                   CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
+                   CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
+                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>,
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<int64_t>":$paddingDimensions,
+                   "ArrayRef<OpFoldResult>":$mixedPadToMultipleOf,
                    CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
                    CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
                    CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>
@@ -1043,6 +1050,8 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
     /// copy_back_op attribute value indicating that no copy back is desired.
     static constexpr StringRef kCopyOpNone = "none";
 
+    SmallVector<OpFoldResult> getMixedPadToMultipleOf();
+
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::transform::TransformRewriter &rewriter,
         ::mlir::linalg::LinalgOp target,
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 156784f0e67402..dc060f4c0641cb 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1664,6 +1664,8 @@ transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
 // PadOp
 //===---------------------------------------------------------------------===//
 
+static const StringLiteral kPadToMultipleOfKeyword = "pad_to_multiple_of";
+
 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                              ArrayRef<int64_t> paddingDimensions,
                              ArrayRef<int64_t> padToMultipleOf,
@@ -1677,14 +1679,111 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                /*target=*/target,
                /*paddingValues=*/ArrayAttr(), // let inference handle this
                /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+               /*padToMultipleOf=*/ValueRange{},
                /*padToMultipleOf=*/
-               (padToMultipleOf.empty() ? ArrayAttr()
-                                        : b.getI64ArrayAttr(padToMultipleOf)),
+               (padToMultipleOf.empty()
+                    ? DenseI64ArrayAttr()
+                    : b.getDenseI64ArrayAttr(padToMultipleOf)),
+               /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
+               /*transposePaddings=*/b.getArrayAttr(transposePaddings),
+               /*copyBackOp=*/b.getStringAttr(copyBackOp));
+}
+
+void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
+                             ArrayRef<int64_t> paddingDimensions,
+                             ArrayRef<OpFoldResult> mixedPadToMultipleOf,
+                             ArrayRef<int64_t> packPaddings,
+                             ArrayRef<Attribute> transposePaddings,
+                             StringRef copyBackOp) {
+  auto resultType = transform::AnyOpType::get(b.getContext());
+  SmallVector<int64_t> staticPadToMultipleOf;
+  SmallVector<Value> dynamicPadToMultipleOf;
+  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
+                             staticPadToMultipleOf);
+  return build(/*builder=*/b,
+               /*result=*/result,
+               /*types=*/TypeRange{resultType, resultType},
+               /*target=*/target,
+               /*paddingValues=*/ArrayAttr(), // let inference handle this
+               /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
+               /*padToMultipleOf=*/dynamicPadToMultipleOf,
+               /*padToMultipleOf=*/staticPadToMultipleOf,
                /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
                /*transposePaddings=*/b.getArrayAttr(transposePaddings),
                /*copyBackOp=*/b.getStringAttr(copyBackOp));
 }
 
+SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
+  OpBuilder b(getContext());
+  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
+}
+
+ParseResult transform::PadOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  OpAsmParser::UnresolvedOperand target;
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicPadToMultipleOf;
+  DenseI64ArrayAttr padToMultipleOf;
+  FunctionType functionalType;
+  llvm::SMLoc operandLoc;
+
+  if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
+    return ParseResult::failure();
+
+  if (succeeded(parser.parseOptionalKeyword(kPadToMultipleOfKeyword))) {
+    if (failed(parseDynamicIndexList(parser, dynamicPadToMultipleOf,
+                                     padToMultipleOf)))
+      return ParseResult::failure();
+  }
+
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(functionalType) ||
+      parser.resolveOperand(target, functionalType.getInputs().front(),
+                            result.operands) ||
+      parser.resolveOperands(dynamicPadToMultipleOf,
+                             functionalType.getInputs().drop_front(),
+                             operandLoc, result.operands))
+    return ParseResult::failure();
+
+  if (padToMultipleOf)
+    result.addAttribute(getStaticPadToMultipleOfAttrName(result.name),
+                        padToMultipleOf);
+
+  result.addTypes(functionalType.getResults());
+
+  return success();
+}
+
+void transform::PadOp::print(OpAsmPrinter &p) {
+  p << ' ' << getTarget() << ' ';
+  if (!getMixedPadToMultipleOf().empty()) {
+    p << kPadToMultipleOfKeyword << ' ';
+    printDynamicIndexList(p, getOperation(), getPadToMultipleOf(),
+                          getStaticPadToMultipleOfAttr(),
+                          /*valueTypes=*/{},
+                          /*scalables=*/{}, OpAsmParser::Delimiter::Square);
+  }
+
+  OpBuilder builder((*this)->getContext());
+  SmallVector<StringRef, 6> elidedAttrs({getStaticPadToMultipleOfAttrName()});
+  if (getCopyBackOpAttr() ==
+      builder.getStringAttr(
+          bufferization::MaterializeInDestinationOp::getOperationName()))
+    elidedAttrs.push_back(getCopyBackOpAttrName());
+  if (getPackPaddingsAttr() == builder.getI64ArrayAttr({}))
+    elidedAttrs.push_back(getPackPaddingsAttrName());
+  if (getTransposePaddingsAttr() == builder.getI64ArrayAttr({}))
+    elidedAttrs.push_back(getTransposePaddingsAttrName());
+  if (getPaddingDimensionsAttr() == builder.getI64ArrayAttr({}))
+    elidedAttrs.push_back(getPaddingDimensionsAttrName());
+  if (getPaddingValuesAttr() == builder.getArrayAttr({}))
+    elidedAttrs.push_back(getPaddingValuesAttrName());
+
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          /*elidedAttrs=*/elidedAttrs);
+  p << " : ";
+  p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
+}
+
 DiagnosedSilenceableFailure
 transform::PadOp::apply(transform::TransformRewriter &rewriter,
                         transform::TransformResults &results,
@@ -1750,9 +1849,8 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
     options.paddingDimensions =
         extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
     SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
-    if (getPadToMultipleOf().has_value())
-      padToMultipleOf =
-          extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
+    if (!getStaticPadToMultipleOf().empty())
+      padToMultipleOf = llvm::to_vector(getStaticPadToMultipleOf());
     options.padToMultipleOf = padToMultipleOf;
     options.paddingValues = paddingValues;
     options.packPaddings = packPaddings;
@@ -1819,8 +1917,8 @@ LogicalResult transform::PadOp::verify() {
                             "integers, found "
                          << getPaddingDimensions();
   }
-  if (getPadToMultipleOf().has_value()) {
-    if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
+  if (!getMixedPadToMultipleOf().empty()) {
+    if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
       return emitOpError() << "expects as many multiples as padding_dimensions";
     }
   }
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index d7b41c0bd2207d..81bbd6ffb3d403 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -373,10 +373,11 @@ class PadOp(PadOp):
     def __init__(
         self,
         target: Union[Operation, OpView, Value],
+        pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
         *,
         padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
         padding_dimensions: OptionalIntList = None,
-        pad_to_multiple_of: OptionalIntList = None,
+        static_pad_to_multiple_of: OptionalIntList = None,
         pack_paddings: OptionalIntList = None,
         transpose_paddings: Optional[
             Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
@@ -385,6 +386,20 @@ def __init__(
         loc=None,
         ip=None,
     ):
+        if (
+            static_pad_to_multiple_of is None
+            and pad_to_multiple_of is None
+        ):
+            dynamic_pad_to_multiple_of = []
+        elif static_pad_to_multiple_of is None:
+            (
+                dynamic_pad_to_multiple_of,
+                static_pad_to_multiple_of,
+                _,
+            ) = _dispatch_dynamic_index_list(pad_to_multiple_of)
+        else:
+            dynamic_pad_to_multiple_of = pad_to_multiple_of
+
         transpose_paddings = _get_int_array_array_attr(transpose_paddings)
 
         any_op_type = transform.AnyOpType.get()
@@ -393,9 +408,10 @@ def __init__(
             any_op_type,
             any_op_type,
             target,
+            pad_to_multiple_of=dynamic_pad_to_multiple_of,
             padding_values=padding_values,
             padding_dimensions=padding_dimensions,
-            pad_to_multiple_of=pad_to_multiple_of,
+            static_pad_to_multiple_of=static_pad_to_multiple_of,
             pack_paddings=pack_paddings,
             transpose_paddings=transpose_paddings,
             copy_back_op=copy_back_op,
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index d27276cda49dc4..f82d4500090c5a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -73,10 +73,9 @@ func.func @pad_to_multiple(%arg0: tensor<24x12xf32>,
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %padded, %pad, %copy_back = transform.structured.pad %0 {
+    %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [2, 2, 1] {
       padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
       padding_dimensions=[0, 1, 2],
-      pad_to_multiple_of=[2, 2, 1],
       pack_paddings=[1, 1, 0]
     } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
     transform.yield
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 91ecd0fc38e174..418b1216df0532 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -315,9 +315,10 @@ def testPadOpNoArgs(target):
 def testPadOpArgs(target):
     structured.PadOp(
         target,
+        [],
         padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
         padding_dimensions=Attribute.parse("[1]"),
-        pad_to_multiple_of=[128],
+        static_pad_to_multiple_of=[128],
         pack_paddings=[0],
         transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
         copy_back_op="linalg.copy",
@@ -325,9 +326,9 @@ def testPadOpArgs(target):
     # CHECK-LABEL: TEST: testPadOpArgs
     # CHECK: transform.sequence
     # CHECK: transform.structured.pad
+    # CHECK-DAG: pad_to_multiple_of [128]
     # CHECK-DAG: copy_back_op = "linalg.copy"
     # CHECK-DAG: pack_paddings = [0]
-    # CHECK-DAG: pad_to_multiple_of = [128]
     # CHECK-DAG: padding_dimensions = [1]
     # CHECK-DAG: padding_values = [4.200000e+01 : f32, "0"]
     # CHECK-DAG: transpose_paddings = {{\[}}[1, 0], [0, 1]]

From a1a7b170cff06461421b98d08b0942b60694351f Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 13:57:34 -0500
Subject: [PATCH 2/5] Fix for param pad_to_multiple_of

---
 .../Linalg/TransformOps/LinalgTransformOps.td | 13 ++--
 .../TransformOps/LinalgTransformOps.cpp       | 62 ++++++++++++++++++-
 .../test/Dialect/Linalg/transform-op-pad.mlir | 36 +++++++++++
 3 files changed, 101 insertions(+), 10 deletions(-)
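
This patch makes the !transform.param operand case work end to end: each
param entry must carry exactly one integer, and each param operand's type is
appended to the functional type. A sketch of the parametrized form, mirroring
the new @parametrized_pad_to_multiple test added below:

  %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
  %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [%c2, 2, 1] {
    padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
    padding_dimensions = [0, 1, 2],
    pack_paddings = [1, 1, 0]
  } : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)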

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 16fb8f4fcc9466..ada7f7666d5f60 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -978,8 +978,8 @@ def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
 //===----------------------------------------------------------------------===//
 
 def PadOp : Op<Transform_Dialect, "structured.pad",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     DeclareOpInterfaceMethods<TransformOpInterface>,
+    [FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
     Pads the operations pointed to by the target handle using the options
@@ -1052,11 +1052,10 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
 
     SmallVector<OpFoldResult> getMixedPadToMultipleOf();
 
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
+    ::mlir::DiagnosedSilenceableFailure apply(
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::transform::TransformResults &results,
+      ::mlir::transform::TransformState &state);
   }];
 }
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dc060f4c0641cb..c68bc4c1f025ac 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1713,6 +1713,16 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                /*copyBackOp=*/b.getStringAttr(copyBackOp));
 }
 
+void PadOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  onlyReadsHandle(getPadToMultipleOf(), effects);
+  producesHandle(getPadded(), effects);
+  producesHandle(getPad(), effects);
+  producesHandle(getCopy(), effects);
+  modifiesPayload(effects);
+}
+
 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
   OpBuilder b(getContext());
   return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
@@ -1848,9 +1858,55 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
     LinalgPaddingOptions options;
     options.paddingDimensions =
         extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
-    SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
-    if (!getStaticPadToMultipleOf().empty())
-      padToMultipleOf = llvm::to_vector(getStaticPadToMultipleOf());
+
+    SmallVector<int64_t> padToMultipleOf;
+    for (OpFoldResult sz : getMixedPadToMultipleOf()) {
+      if (sz.is<Attribute>()) {
+        auto attr = sz.get<Attribute>();
+        padToMultipleOf.push_back(cast<IntegerAttr>(attr).getInt());
+        continue;
+      } else if (sz.is<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
+        ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
+        if (params.size() != 1)
+          return emitSilenceableFailure(getLoc()) << "expected a single param";
+        padToMultipleOf.push_back(
+            cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+        continue;
+      }
+
+      auto szPayloads = state.getPayloadOps(sz.get<Value>());
+      if (!llvm::hasSingleElement(szPayloads)) {
+        auto diag = this->emitOpError("requires pad_to_multiple_of handle that "
+                                      "is mapped to 1 payload op");
+        diag.attachNote(sz.get<Value>().getLoc())
+            << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+
+      Operation *szPayloadOp = *szPayloads.begin();
+      if (szPayloadOp->getNumResults() != 1 ||
+          !szPayloadOp->getResult(0).getType().isIndex()) {
+        auto diag = this->emitOpError(
+            "requires vector pad_to_multiple_of op with 1 index result");
+        diag.attachNote(szPayloadOp->getLoc())
+            << "pad_to_multiple_of payload op";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+
+      IntegerAttr attr;
+      if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
+        auto diag = this->emitOpError("requires constant pad_to_multiple_of");
+        diag.attachNote(szPayloadOp->getLoc())
+            << "pad_to_multiple_of payload op";
+        return DiagnosedSilenceableFailure::definiteFailure();
+      }
+
+      padToMultipleOf.push_back(attr.getInt());
+    }
+    if (padToMultipleOf.empty())
+      padToMultipleOf =
+          SmallVector<int64_t>(options.paddingDimensions.size(), 1);
+
     options.padToMultipleOf = padToMultipleOf;
     options.paddingValues = paddingValues;
     options.packPaddings = packPaddings;
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index f82d4500090c5a..47bb5ddf4afc3e 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -86,6 +86,42 @@ module attributes {transform.with_named_sequence} {
 
 #map = affine_map<()[s0] -> (-s0 + 12, 7)>
 
+// CHECK-LABEL: @parametrized_pad_to_multiple
+func.func @parametrized_pad_to_multiple(%arg0: tensor<24x12xf32>,
+                                        %arg1: tensor<12x25xf32>,
+                                        %arg2: tensor<24x25xf32>,
+                                        %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+  %0 = affine.min #map()[%iv2]
+  %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+  %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+  //      CHECK: linalg.matmul
+  // CHECK-SAME:     ins(%{{.*}}, %{{.*}} : tensor<4x7xf32>, tensor<7x6xf32>)
+  // CHECK-SAME:     outs(%{{.*}} : tensor<4x6xf32>)
+  %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+  %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+  func.return %5 : tensor<24x25xf32>
+}
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
+    %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [%c2, 2, 1] {
+      padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+      padding_dimensions=[0, 1, 2],
+      pack_paddings=[1, 1, 0]
+    } : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+
 // CHECK-LABEL: @static_sizes_output_divisible_on_empty_op
 func.func @static_sizes_output_divisible_on_empty_op(%arg0: tensor<24x12xf32>,
     %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>, %iv0: index,

From 5159eeb464c0e6858056cb877194d2293af8c81c Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 13:58:54 -0500
Subject: [PATCH 3/5] format

---
 mlir/python/mlir/dialects/transform/structured.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 81bbd6ffb3d403..4f4a0e598df7d3 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -386,10 +386,7 @@ def __init__(
         loc=None,
         ip=None,
     ):
-        if (
-            static_pad_to_multiple_of is None
-            and pad_to_multiple_of is None
-        ):
+        if static_pad_to_multiple_of is None and pad_to_multiple_of is None:
             dynamic_pad_to_multiple_of = []
         elif static_pad_to_multiple_of is None:
             (

From 23ef22ab40f4494aa93d6034f4c38714a9ec4a05 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 14:11:43 -0500
Subject: [PATCH 4/5] cleanup diagnostic messages

---
 .../Linalg/TransformOps/LinalgTransformOps.cpp | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c68bc4c1f025ac..c3963aee828565 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1860,6 +1860,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
         extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
 
     SmallVector<int64_t> padToMultipleOf;
+    // TODO: This should probably be a common utility function.
     for (OpFoldResult sz : getMixedPadToMultipleOf()) {
       if (sz.is<Attribute>()) {
         auto attr = sz.get<Attribute>();
@@ -1876,8 +1877,9 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
 
       auto szPayloads = state.getPayloadOps(sz.get<Value>());
       if (!llvm::hasSingleElement(szPayloads)) {
-        auto diag = this->emitOpError("requires pad_to_multiple_of handle that "
-                                      "is mapped to 1 payload op");
+        auto diag = this->emitOpError()
+                    << "requires " << kPadToMultipleOfKeyword
+                    << " handle that is mapped to 1 payload op";
         diag.attachNote(sz.get<Value>().getLoc())
             << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
         return DiagnosedSilenceableFailure::definiteFailure();
@@ -1886,18 +1888,20 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
       Operation *szPayloadOp = *szPayloads.begin();
       if (szPayloadOp->getNumResults() != 1 ||
           !szPayloadOp->getResult(0).getType().isIndex()) {
-        auto diag = this->emitOpError(
-            "requires vector pad_to_multiple_of op with 1 index result");
+        auto diag = this->emitOpError()
+                    << "requires " << kPadToMultipleOfKeyword
+                    << " to be result of op with 1 index result";
         diag.attachNote(szPayloadOp->getLoc())
-            << "pad_to_multiple_of payload op";
+            << kPadToMultipleOfKeyword << " payload op";
         return DiagnosedSilenceableFailure::definiteFailure();
       }
 
       IntegerAttr attr;
       if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
-        auto diag = this->emitOpError("requires constant pad_to_multiple_of");
+        auto diag = this->emitOpError()
+                    << "requires constant " << kPadToMultipleOfKeyword;
         diag.attachNote(szPayloadOp->getLoc())
-            << "pad_to_multiple_of payload op";
+            << kPadToMultipleOfKeyword << " payload op";
         return DiagnosedSilenceableFailure::definiteFailure();
       }
 

From 2dade05508512e565feccac381b185e749b3f735 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Wed, 1 May 2024 19:36:40 -0500
Subject: [PATCH 5/5] refactor param/handle reification

---
 .../TransformOps/LinalgTransformOps.cpp       | 143 +++++++-----------
 1 file changed, 54 insertions(+), 89 deletions(-)
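
The extracted reifyMixedParamAndHandleResults helper accepts each entry in
one of three forms: a static integer attribute, a !transform.param carrying a
single integer, or an op handle mapped to exactly one payload op with a
single constant index result. A hedged sketch showing all three forms on one
pad op; the arith.constant match standing in for the handle case is
illustrative only and assumes the payload contains exactly one such op:

  %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
  // Handle form: must resolve to one payload op with one constant index result.
  %mult = transform.structured.match ops{["arith.constant"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
  %padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [%c2, %mult, 1] {
    padding_dimensions = [0, 1, 2]
  } : (!transform.any_op, !transform.param<i64>, !transform.any_op)
      -> (!transform.any_op, !transform.any_op, !transform.any_op)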

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c3963aee828565..01d4d2a033830c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -171,6 +171,50 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
   return DiagnosedSilenceableFailure::success();
 }
 
+static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
+    TransformState &state, TransformOpInterface &transformOp,
+    const SmallVectorImpl<OpFoldResult> &mixedResults,
+    SmallVectorImpl<int64_t> &reified) {
+  for (OpFoldResult paramOrHandle : mixedResults) {
+    if (isa<Attribute>(paramOrHandle)) {
+      reified.push_back(
+          cast<IntegerAttr>(paramOrHandle.get<Attribute>()).getInt());
+      continue;
+    } else if (isa<Value>(paramOrHandle) &&
+               isa<ParamType>(paramOrHandle.get<Value>().getType())) {
+      ArrayRef<Attribute> params = state.getParams(paramOrHandle.get<Value>());
+      if (params.size() != 1)
+        return transformOp.emitDefiniteFailure() << "expected a single param";
+      reified.push_back(
+          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+      continue;
+    }
+
+    auto paramOrHandlePayloads =
+        state.getPayloadOps(paramOrHandle.get<Value>());
+    if (!llvm::hasSingleElement(paramOrHandlePayloads))
+      return transformOp.emitDefiniteFailure()
+             << "requires param or handle that is mapped to 1 payload op";
+
+    Operation *paramOrHandlePayloadOp = *paramOrHandlePayloads.begin();
+    if (paramOrHandlePayloadOp->getNumResults() != 1 ||
+        !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
+      return transformOp.emitDefiniteFailure()
+             << "requires param or handle to be result of op with 1 index "
+                "result";
+    }
+
+    IntegerAttr attr;
+    if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
+      return transformOp.emitDefiniteFailure()
+             << "requires param or handle to be the result of a constant like "
+                "op";
+
+    reified.push_back(attr.getInt());
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
@@ -1798,6 +1842,7 @@ DiagnosedSilenceableFailure
 transform::PadOp::apply(transform::TransformRewriter &rewriter,
                         transform::TransformResults &results,
                         transform::TransformState &state) {
+  auto transformOp = cast<TransformOpInterface>(getOperation());
   SmallVector<Operation *> paddedOps, padOps, copyBackOps;
 
   for (Operation *target : state.getPayloadOps(getTarget())) {
@@ -1860,53 +1905,10 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
         extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
 
     SmallVector<int64_t> padToMultipleOf;
-    // TODO: This should probably be a common utility function.
-    for (OpFoldResult sz : getMixedPadToMultipleOf()) {
-      if (sz.is<Attribute>()) {
-        auto attr = sz.get<Attribute>();
-        padToMultipleOf.push_back(cast<IntegerAttr>(attr).getInt());
-        continue;
-      } else if (sz.is<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
-        ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
-        if (params.size() != 1)
-          return emitSilenceableFailure(getLoc()) << "expected a single param";
-        padToMultipleOf.push_back(
-            cast<IntegerAttr>(params.front()).getValue().getSExtValue());
-        continue;
-      }
-
-      auto szPayloads = state.getPayloadOps(sz.get<Value>());
-      if (!llvm::hasSingleElement(szPayloads)) {
-        auto diag = this->emitOpError()
-                    << "requires " << kPadToMultipleOfKeyword
-                    << " handle that is mapped to 1 payload op";
-        diag.attachNote(sz.get<Value>().getLoc())
-            << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
-        return DiagnosedSilenceableFailure::definiteFailure();
-      }
-
-      Operation *szPayloadOp = *szPayloads.begin();
-      if (szPayloadOp->getNumResults() != 1 ||
-          !szPayloadOp->getResult(0).getType().isIndex()) {
-        auto diag = this->emitOpError()
-                    << "requires " << kPadToMultipleOfKeyword
-                    << " to be result of op with 1 index result";
-        diag.attachNote(szPayloadOp->getLoc())
-            << kPadToMultipleOfKeyword << " payload op";
-        return DiagnosedSilenceableFailure::definiteFailure();
-      }
-
-      IntegerAttr attr;
-      if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
-        auto diag = this->emitOpError()
-                    << "requires constant " << kPadToMultipleOfKeyword;
-        diag.attachNote(szPayloadOp->getLoc())
-            << kPadToMultipleOfKeyword << " payload op";
-        return DiagnosedSilenceableFailure::definiteFailure();
-      }
-
-      padToMultipleOf.push_back(attr.getInt());
-    }
+    DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
+        state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
+    if (!status.succeeded())
+      return status;
     if (padToMultipleOf.empty())
       padToMultipleOf =
           SmallVector<int64_t>(options.paddingDimensions.size(), 1);
@@ -3362,49 +3364,12 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
   auto targets = state.getPayloadOps(getTarget());
   if (std::empty(targets))
     return DiagnosedSilenceableFailure::success();
-
+  auto transformOp = cast<TransformOpInterface>(getOperation());
   SmallVector<int64_t> vectorSizes;
-  for (OpFoldResult sz : getMixedVectorSizes()) {
-    if (sz.is<Attribute>()) {
-      auto attr = sz.get<Attribute>();
-      vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
-      continue;
-    } else if (sz.is<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
-      ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
-      if (params.size() != 1)
-        return emitSilenceableFailure(getLoc()) << "expected a single param";
-      vectorSizes.push_back(
-          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
-      continue;
-    }
-
-    auto szPayloads = state.getPayloadOps(sz.get<Value>());
-    if (!llvm::hasSingleElement(szPayloads)) {
-      auto diag = this->emitOpError(
-          "requires vector size handle that is mapped to 1 payload op");
-      diag.attachNote(sz.get<Value>().getLoc())
-          << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
-      return DiagnosedSilenceableFailure::definiteFailure();
-    }
-
-    Operation *szPayloadOp = *szPayloads.begin();
-    if (szPayloadOp->getNumResults() != 1 ||
-        !szPayloadOp->getResult(0).getType().isIndex()) {
-      auto diag = this->emitOpError(
-          "requires vector size payload op with 1 index result");
-      diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
-      return DiagnosedSilenceableFailure::definiteFailure();
-    }
-
-    IntegerAttr attr;
-    if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
-      auto diag = this->emitOpError("requires constant vector size");
-      diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
-      return DiagnosedSilenceableFailure::definiteFailure();
-    }
-
-    vectorSizes.push_back(attr.getInt());
-  }
+  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
+      state, transformOp, getMixedVectorSizes(), vectorSizes);
+  if (!status.succeeded())
+    return status;
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {


