[Mlir-commits] [mlir] 88c5027 - [mlir] make multi-size tiling use transform parameters

Alex Zinenko llvmlistbot at llvm.org
Thu Jan 19 02:19:46 PST 2023


Author: Alex Zinenko
Date: 2023-01-19T10:19:37Z
New Revision: 88c5027b93a9f447a8b3ce02e5d74f1c10c14da1

URL: https://github.com/llvm/llvm-project/commit/88c5027b93a9f447a8b3ce02e5d74f1c10c14da1
DIFF: https://github.com/llvm/llvm-project/commit/88c5027b93a9f447a8b3ce02e5d74f1c10c14da1.diff

LOG: [mlir] make multi-size tiling use transform parameters

Use the recently introduced transform dialect parameter mechanism to
perform controllable multi-size tiling with sizes computed at
transformation time rather than at runtime.

This requires generalizing the tile and split structured transform
operations to work with any transform dialect handle type, which is
desirable in itself to avoid unchecked overuse of PDL OperationType.

Reviewed By: shabalin

Differential Revision: https://reviews.llvm.org/D140980

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/Dialect/LLVM/transform-e2e.mlir
    mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
    mlir/test/Dialect/Linalg/promotion_options.mlir
    mlir/test/Dialect/Linalg/tile-conv.mlir
    mlir/test/Dialect/Linalg/tile-indexed.mlir
    mlir/test/Dialect/Linalg/tile-tensors.mlir
    mlir/test/Dialect/Linalg/transform-op-fuse.mlir
    mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir
    mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
    mlir/test/Dialect/Linalg/transform-op-split.mlir
    mlir/test/Dialect/Linalg/transform-op-tile.mlir
    mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
    mlir/test/Dialect/Linalg/transform-ops.mlir
    mlir/test/Dialect/Linalg/transform-patterns.mlir
    mlir/test/Dialect/Transform/selective-targeting.mlir
    mlir/test/Dialect/Vector/transform-vector.mlir
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8ca661e9c9455..f6c601f73df8e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -19,6 +20,13 @@ include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/RegionKindInterface.td"
 
+// This is roughly similar to OpFoldResult assuming the handle produces a single
+// value in the payload IR.
+def TransformParamTypeOrAnyHandle : Type<
+    Or<[TransformHandleTypeInterface.predicate,
+        Transform_ParamType.predicate]>,
+    "transform 'param' type or any handle type">;
+
 //===----------------------------------------------------------------------===//
 // DecomposeOp
 //===----------------------------------------------------------------------===//
@@ -311,27 +319,41 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
     ```mlir
     %sz1, %sz2, %split = structured.multitile_sizes %target
                          { target_size = 10, dimension = 1 }
+                       : (!transform.any_op)
+                      -> !transform.param<i64>
     %low, %high = structured.split %target after %split { dimension = 1 }
-    %tiled_low = structured.tile %low [0, %sz1]
-    %tiled_high = structured.tile %high [0, %sz2]
-    %common = merge_handles %tiled_low, %tiled_high
+                : !transform.any_op, !transform.param<i64>
+    %tiled_low, %loop1 = structured.tile %low [0, %sz1]
+                       : (!transform.any_op, !transform.param<i64>)
+                      -> (!transform.any_op, !transform.any_op)
+    %tiled_high, %loop2 = structured.tile %high [0, %sz2]
+                        : (!transform.any_op, !transform.param<i64>)
+                       -> (!transform.any_op, !transform.any_op)
+    %common = merge_handles %tiled_low, %tiled_high : !transform.any_op
 
     %sz3, %sz4, %split = structured.multitile_size %target
                          { target_size = 42, dimension = 0 }
+                       : !transform.any_op, !transform.any_op,
+                         !transform.any_op, !transform.any_op
     %sz3r, %sz4r, %splitr = replicate num(%common) %sz3, %sz4, %splitr
+             : !transform.any_op, !transform.any_op, !transform.any_op
     structured.split %common after %splitr { dimension = 0 }
+             : !transform.any_op, !transform.any_op
     // ...
     ```
   }];
 
-  let arguments = (ins PDL_Operation:$target,
+  let arguments = (ins TransformHandleTypeInterface:$target,
                        I64Attr:$dimension,
                        I64Attr:$target_size,
                        DefaultValuedAttr<I64Attr, "1">:$divisor);
-  let results = (outs PDL_Operation:$low_size,
-                      PDL_Operation:$high_size,
-                      PDL_Operation:$split_point);
-  let assemblyFormat = "$target attr-dict";
+  let results = (outs TransformParamTypeOrAnyHandle:$low_size,
+                      TransformParamTypeOrAnyHandle:$high_size,
+                      TransformParamTypeOrAnyHandle:$split_point);
+  let hasVerifier = 1;
+  let assemblyFormat =
+    "$target attr-dict `:` custom<MultitileSizesTypes>("
+    "type($target), type($low_size), type($high_size), type($split_point))";
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -608,11 +630,12 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
     iteration space indices.
   }];
 
-  let arguments = (ins PDL_Operation:$target,
+  let arguments = (ins TransformHandleTypeInterface:$target,
                        I64Attr:$dimension,
-                       Optional<PDL_Operation>:$dynamic_split_point,
+                       Optional<TransformParamTypeOrAnyHandle>:$dynamic_split_point,
                        I64Attr:$static_split_point);
-  let results = (outs PDL_Operation:$first, PDL_Operation:$second);
+  let results = (outs TransformHandleTypeInterface:$first,
+                      TransformHandleTypeInterface:$second);
   let hasVerifier = 1;
   let hasCustomAssemblyFormat = 1;
 }
@@ -1046,19 +1069,28 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
     produces a definite failure.
   }];
 
-  let arguments = (ins PDL_Operation:$target,
-                   Variadic<PDL_Operation>:$dynamic_sizes,
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                   Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
-  let results = (outs PDL_Operation:$tiled_linalg_op,
-                      Variadic<PDL_Operation>:$loops);
+  let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
+                      Variadic<TransformHandleTypeInterface>:$loops);
   let builders = [
+    OpBuilder<(ins "TypeRange":$loopTypes,
+                   "Value":$target,
+                   "ArrayRef<int64_t>":$staticTileSizes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$interchange)>,
+    OpBuilder<(ins "TypeRange":$loopTypes,
+                   "Value":$target,
+                   "ArrayRef<OpFoldResult>":$mixedTileSizes,
+                   CArg<"ArrayRef<int64_t>", "{}">:$interchange)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<int64_t>":$staticTileSizes,
                    CArg<"ArrayRef<int64_t>", "{}">:$interchange)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedTileSizes,
                    CArg<"ArrayRef<int64_t>", "{}">:$interchange)>
+
   ];
 
   let hasCustomAssemblyFormat = 1;

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b4692776a61bd..286368550aaad 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -416,15 +416,23 @@ makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                     ArrayRef<OpFoldResult> allShapeSizes,
                     ArrayRef<OpFoldResult> allTileSizes);
 
-/// A description of a multi-size tiling comprising tile sizes and numbers of
-/// tiles, expressed as Values which may or may not be constant. Multi-size
-/// currently means two-size.
-struct MultiSizeSpecification {
+namespace detail {
+template <typename T>
+struct MultiSizeSpecificationBase {
   /// Tile sizes.
-  Value lowTileSize, highTileSize;
+  T lowTileSize, highTileSize;
   /// Number of tiles associated with each size.
-  Value lowTripCount, highTripCount;
+  T lowTripCount, highTripCount;
 };
+} // namespace detail
+
+/// A description of a multi-size tiling comprising tile sizes and numbers of
+/// tiles, expressed as Values which may or may not be constant. Multi-size
+/// currently means two-size.
+struct MultiSizeSpecification
+    : public detail::MultiSizeSpecificationBase<Value> {};
+struct StaticMultiSizeSpecification
+    : public detail::MultiSizeSpecificationBase<int64_t> {};
 
 /// Emits the IR computing the multi-sized tiling specification with two tile
 /// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
@@ -457,6 +465,9 @@ FailureOr<MultiSizeSpecification>
 computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
                       OpFoldResult targetSize, OpFoldResult divisor,
                       bool emitAssertions = true);
+FailureOr<StaticMultiSizeSpecification>
+computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
+                            int64_t divisor);
 
 /// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
 /// tiling by `numThreads`.

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
index f4d66c5f7fa67..22c0c94b27606 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -153,9 +153,13 @@ def TransformParamTypeInterface
     underlying type. A user of the value can assume that the parameter has been
     verified.
   }];
-
 }
 
+def Transform_AnyHandleOrParamType
+  : Type<Or<[TransformParamTypeInterface.predicate,
+             TransformHandleTypeInterface.predicate]>,
+         "any transform handle or parameter">;
+
 def FunctionalStyleTransformOpTrait
     : NativeOpTrait<"FunctionalStyleTransformOpTrait"> {
   let cppNamespace = "::mlir::transform";

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index d72d38a365b7b..699929fa17f0a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -371,8 +371,8 @@ def ReplicateOp : TransformDialectOp<"replicate",
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$pattern,
-                       Variadic<TransformHandleTypeInterface>:$handles);
-  let results = (outs Variadic<TransformHandleTypeInterface>:$replicated);
+                       Variadic<Transform_AnyHandleOrParamType>:$handles);
+  let results = (outs Variadic<Transform_AnyHandleOrParamType>:$replicated);
   let assemblyFormat = "`num` `(` $pattern `)` $handles attr-dict `:` "
                        "type($pattern) `,` type($handles)";
 }

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c2eafcf5048f7..b43469b08b9d0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -771,9 +772,65 @@ transform::MatchOp::apply(transform::TransformResults &results,
 // MultiTileSizesOp
 //===---------------------------------------------------------------------===//
 
+static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
+                                     Type targetType, Type lowSizeType, Type,
+                                     Type) {
+  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
+}
+
+static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
+                                            Type &targetType, Type &lowSizeType,
+                                            Type &highSizeType,
+                                            Type &splitPointType) {
+  FunctionType funcType;
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (failed(parser.parseType<FunctionType>(funcType)))
+    return failure();
+
+  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
+    return parser.emitError(typeLoc) << "expects a trailing functional type "
+                                        "with one argument and one result";
+  }
+  targetType = funcType.getInput(0);
+  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
+
+  return success();
+}
+
 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
     LinalgOp target, transform::ApplyToEachResultList &results,
     TransformState &state) {
+  if (getLowSize().getType().isa<TransformParamTypeInterface>()) {
+    if (target.hasDynamicShape()) {
+      results.assign(
+          ArrayRef<Attribute>({Attribute(), Attribute(), Attribute()}));
+      auto diag = emitSilenceableError()
+                  << "cannot compute parametric tile sizes for dynamically "
+                     "shaped payload op";
+      diag.attachNote(target->getLoc()) << "payload op";
+      return diag;
+    }
+
+    FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
+        target, getDimension(), getTargetSize(), getDivisor());
+    if (failed(spec)) {
+      results.assign(
+          ArrayRef<Attribute>({Attribute(), Attribute(), Attribute()}));
+      return emitSilenceableError()
+             << "failed to compute multi-size tiling sizes";
+    }
+
+    Builder builder(target.getContext());
+    results.assign(llvm::map_range(
+        ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
+                           spec->lowTileSize * spec->lowTripCount}),
+        [&builder, this](int64_t value) {
+          return builder.getIntegerAttr(
+              getLowSize().getType().cast<ParamType>().getType(), value);
+        }));
+    return DiagnosedSilenceableFailure::success();
+  }
+
   OpBuilder builder(target.getContext());
   builder.setInsertionPoint(target);
   OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
@@ -804,7 +861,18 @@ void transform::MultiTileSizesOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   onlyReadsHandle(getTarget(), effects);
   producesHandle(getResults(), effects);
-  modifiesPayload(effects);
+  if (getLowSize().getType().isa<TransformParamTypeInterface>())
+    onlyReadsPayload(effects);
+  else
+    modifiesPayload(effects);
+}
+
+LogicalResult transform::MultiTileSizesOp::verify() {
+  if (getLowSize().getType() != getHighSize().getType() ||
+      getLowSize().getType() != getSplitPoint().getType()) {
+    return emitOpError() << "expects all results type to be the same";
+  }
+  return success();
 }
 
 //===---------------------------------------------------------------------===//
@@ -1406,17 +1474,23 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
   splitPoints.reserve(payload.size());
   if (getDynamicSplitPoint()) {
     auto diag = DiagnosedSilenceableFailure::success();
-    splitPoints = llvm::to_vector(llvm::map_range(
-        state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
-          if (op->getNumResults() != 1 ||
-              !op->getResult(0).getType().isIndex()) {
-            diag = emitSilenceableError()
-                   << "expected dynamic split point handle to point to a "
-                      "single-result index-typed op";
-            diag.attachNote(op->getLoc()) << "dynamic split point";
-          }
-          return OpFoldResult(op->getResult(0));
-        }));
+    if (getDynamicSplitPoint().getType().isa<TransformHandleTypeInterface>()) {
+      splitPoints = llvm::to_vector(llvm::map_range(
+          state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
+            if (op->getNumResults() != 1 ||
+                !op->getResult(0).getType().isIndex()) {
+              diag = emitSilenceableError()
+                     << "expected dynamic split point handle to point to a "
+                        "single-result index-typed op";
+              diag.attachNote(op->getLoc()) << "dynamic split point";
+            }
+            return OpFoldResult(op->getResult(0));
+          }));
+    } else {
+      splitPoints = llvm::to_vector(
+          llvm::map_range(state.getParams(getDynamicSplitPoint()),
+                          [](Attribute attr) { return OpFoldResult(attr); }));
+    }
     if (diag.isSilenceableFailure()) {
       results.set(getFirst().cast<OpResult>(), {});
       results.set(getSecond().cast<OpResult>(), {});
@@ -1507,11 +1581,7 @@ void SplitOp::getEffects(
 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
   IntegerAttr staticSplitPoint;
-  auto pdlOperationType =
-      pdl::OperationType::get(parser.getBuilder().getContext());
-  if (parser.parseOperand(target) ||
-      parser.resolveOperand(target, pdlOperationType, result.operands) ||
-      parser.parseKeyword("after"))
+  if (parser.parseOperand(target) || parser.parseKeyword("after"))
     return failure();
 
   OptionalParseResult dynamicPointParseResult =
@@ -1523,9 +1593,19 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
 
     staticSplitPoint =
         parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
-  } else {
-    if (failed(*dynamicPointParseResult) ||
-        parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
+  }
+
+  Type targetType;
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(targetType) ||
+      parser.resolveOperand(target, targetType, result.operands)) {
+    return failure();
+  }
+  if (dynamicPointParseResult.has_value()) {
+    Type splitPointType;
+    if (failed(*dynamicPointParseResult) || parser.parseComma() ||
+        parser.parseType(splitPointType) ||
+        parser.resolveOperand(dynamicSplitPoint, splitPointType,
                               result.operands)) {
       return failure();
     }
@@ -1537,10 +1617,7 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
   result.addAttribute(
       SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
       staticSplitPoint);
-  if (failed(parser.parseOptionalAttrDict(result.attributes)))
-    return failure();
-
-  result.addTypes({pdlOperationType, pdlOperationType});
+  result.addTypes({targetType, targetType});
   return success();
 }
 
@@ -1554,6 +1631,9 @@ void SplitOp::print(OpAsmPrinter &printer) {
   printer << " ";
   printer.printOptionalAttrDict(getOperation()->getAttrs(),
                                 {getStaticSplitPointAttrName()});
+  printer << " : " << getTarget().getType();
+  if (staticSplitSize == ShapedType::kDynamic)
+    printer << ", " << getDynamicSplitPoint().getType();
 }
 
 LogicalResult SplitOp::verify() {
@@ -1716,31 +1796,58 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
 //===----------------------------------------------------------------------===//
 // TileOp
 //===----------------------------------------------------------------------===//
+
 void transform::TileOp::build(OpBuilder &builder, OperationState &result,
-                              Value target, ArrayRef<int64_t> staticTileSizes,
+                              TypeRange loopTypes, Value target,
+                              ArrayRef<int64_t> staticTileSizes,
                               ArrayRef<int64_t> interchange) {
-  return build(builder, result,
+  return build(builder, result, loopTypes,
                /*target=*/target,
                /*mixedTileSizes=*/
                getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
                interchange);
 }
 
+void transform::TileOp::build(OpBuilder &builder, OperationState &result,
+                              Value target, ArrayRef<int64_t> staticTileSizes,
+                              ArrayRef<int64_t> interchange) {
+  build(builder, result, target,
+        getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
+        interchange);
+}
+
 void transform::TileOp::build(OpBuilder &builder, OperationState &result,
                               Value target,
                               ArrayRef<OpFoldResult> mixedTileSizes,
                               ArrayRef<int64_t> interchange) {
+  // A single `!transform.any_op` loop type suffices; the callee splats it.
+  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
+  build(builder, result, loopTypes, target, mixedTileSizes, interchange);
+}
+
+void transform::TileOp::build(OpBuilder &builder, OperationState &result,
+                              TypeRange loopTypes, Value target,
+                              ArrayRef<OpFoldResult> mixedTileSizes,
+                              ArrayRef<int64_t> interchange) {
   SmallVector<int64_t> staticTileSizes;
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
   // Call the default builder which sets up the proper operands segment sizes
   // attributes for multiple variadic operands. In the absence of this, horrible
   // bugs ensue.
-  MLIRContext *ctx = builder.getContext();
-  auto operationType = pdl::OperationType::get(ctx);
   auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
-  build(builder, result,
-        /*resultTypes=*/TypeRange{operationType, operationType},
+  unsigned numExpectedLoops =
+      staticTileSizes.size() - llvm::count(staticTileSizes, 0);
+  SmallVector<Type> resultTypes;
+  resultTypes.reserve(numExpectedLoops);
+  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
+         "expected one loop type or as many as loops");
+  if (loopTypes.size() == 1)
+    resultTypes.append(numExpectedLoops, loopTypes[0]);
+  else
+    llvm::append_range(resultTypes, loopTypes);
+  build(builder, result, /*tiled_linalg_op=*/target.getType(),
+        /*loops=*/resultTypes,
         /*target=*/target,
         /*dynamic_sizes=*/dynamicTileSizes,
         /*static_sizes=*/staticTileSizesAttr,
@@ -1754,18 +1861,44 @@ transform::TileOp::apply(TransformResults &transformResults,
 
   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
   SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
+  SmallVector<SmallVector<int64_t>> paramSizes;
   dynamicSizeProducers.reserve(getDynamicSizes().size());
-  for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
-    dynamicSizeProducers.push_back(
-        state.getPayloadOps(dynamicSizeProducerHandle));
+  paramSizes.reserve(getDynamicSizes().size());
+  for (Value transformValue : getDynamicSizes()) {
+    if (transformValue.getType().isa<ParamType>()) {
+      dynamicSizeProducers.push_back({});
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      paramSizes.push_back(
+          llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
+            return attr.cast<IntegerAttr>().getValue().getSExtValue();
+          })));
+
+      if (paramSizes.back().size() != targets.size()) {
+        for (OpResult r : getResults())
+          transformResults.set(r, {});
+        DiagnosedSilenceableFailure diag =
+            emitSilenceableError()
+            << "expected as many parameter values ("
+            << paramSizes.back().size() << ") as target ops ("
+            << targets.size() << ")";
+        diag.attachNote(transformValue.getLoc()) << "for this parameter";
+        return diag;
+      }
+
+      continue;
+    }
+    paramSizes.push_back({});
+    dynamicSizeProducers.push_back(state.getPayloadOps(transformValue));
 
     if (dynamicSizeProducers.back().size() != targets.size()) {
+      for (OpResult r : getResults())
+        transformResults.set(r, {});
       DiagnosedSilenceableFailure diag =
           emitSilenceableError()
           << "expected as many dynamic size-producing operations ("
           << dynamicSizeProducers.back().size() << ") as target ops ("
           << targets.size() << ")";
-      diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
+      diag.attachNote(transformValue.getLoc()) << "for this handle";
       return diag;
     }
 
@@ -1773,11 +1906,14 @@ transform::TileOp::apply(TransformResults &transformResults,
       if (op->getNumResults() == 1 &&
           op->getResult(0).getType().isa<IndexType>())
         continue;
+
+      for (OpResult r : getResults())
+        transformResults.set(r, {});
       DiagnosedSilenceableFailure diag =
           emitSilenceableError() << "expected sizes to be produced by ops "
                                     "with a single index-type result";
       diag.attachNote(op->getLoc()) << "size producer op";
-      diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
+      diag.attachNote(transformValue.getLoc()) << "for this handle";
       return diag;
     }
   }
@@ -1806,9 +1942,19 @@ transform::TileOp::apply(TransformResults &transformResults,
               if (auto attr = ofr.dyn_cast<Attribute>()) {
                 sizes.push_back(b.create<arith::ConstantIndexOp>(
                     getLoc(), attr.cast<IntegerAttr>().getInt()));
-              } else {
+                continue;
+              }
+              ArrayRef<Operation *> dynamicSizes =
+                  dynamicSizeProducers[dynamicIdx];
+              ArrayRef<int64_t> params = paramSizes[dynamicIdx];
+              ++dynamicIdx;
+              assert((dynamicSizes.empty() ^ params.empty()) &&
+                     "expected either dynamic sizes or parameters");
+              if (!params.empty()) {
                 sizes.push_back(
-                    dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
+                    b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
+              } else {
+                sizes.push_back(dynamicSizes[index]->getResult(0));
               }
             }
             return sizes;
@@ -1890,20 +2036,34 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
   OpAsmParser::UnresolvedOperand target;
   SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
   DenseI64ArrayAttr staticSizes;
-  auto pdlOperationType = pdl::OperationType::get(parser.getContext());
-  if (parser.parseOperand(target) ||
-      parser.resolveOperand(target, pdlOperationType, result.operands) ||
+  FunctionType functionalType;
+  llvm::SMLoc operandLoc;
+  if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
       parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
-      parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
+      parseOptionalInterchange(parser, result) ||
+      parser.parseColonType(functionalType))
     return ParseResult::failure();
 
-  // Parse optional interchange.
-  if (failed(parseOptionalInterchange(parser, result)))
-    return ParseResult::failure();
-  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
   size_t numExpectedLoops =
       staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
-  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
+  if (functionalType.getNumResults() != numExpectedLoops + 1) {
+    return parser.emitError(parser.getNameLoc())
+           << "expected " << (numExpectedLoops + 1) << " result type(s)";
+  }
+  if (functionalType.getNumInputs() != dynamicSizes.size() + 1) {
+    return parser.emitError(operandLoc)
+           << "expected " << dynamicSizes.size() + 1 << " operand type(s)";
+  }
+  if (parser.resolveOperand(target, functionalType.getInputs().front(),
+                            result.operands) ||
+      parser.resolveOperands(dynamicSizes,
+                             functionalType.getInputs().drop_front(),
+                             operandLoc, result.operands)) {
+    return failure();
+  }
+
+  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
+  result.addTypes(functionalType.getResults());
   return success();
 }
 
@@ -1911,6 +2071,8 @@ void TileOp::print(OpAsmPrinter &p) {
   p << ' ' << getTarget();
   printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
   printOptionalInterchange(p, getInterchange());
+  p << " : ";
+  p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
 }
 
 void transform::TileOp::getEffects(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index c410b5d15a771..9df009f241c53 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -117,6 +117,38 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
       b.getStringAttr("expected strictly positive tile size and divisor"));
 }
 
+FailureOr<StaticMultiSizeSpecification>
+mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
+                                          int64_t targetSize, int64_t divisor) {
+  assert(!op.hasDynamicShape() &&
+         "cannot compute static multi-tile sizes for an op with dynamic shape");
+  assert(targetSize > 0 && "target size must be strictly positive");
+  assert(divisor > 0 && "divisor must be strictly positive");
+  // Bail out on dimension overflow rather than indexing out of bounds.
+  if (dimension >= op.getNumLoops())
+    return failure();
+
+  StaticMultiSizeSpecification spec;
+  int64_t tripCount = op.getStaticLoopRanges()[dimension];
+  int64_t a = tripCount / divisor;
+  int64_t t = (targetSize + divisor - 1) / divisor;
+  int64_t totalTripCount = (a + t - 1) / t;
+  // A dimension smaller than the divisor (including an empty one) cannot be
+  // covered by divisor-multiple tiles; fail instead of dividing by zero below.
+  if (totalTripCount == 0)
+    return failure();
+  spec.lowTileSize = (a / totalTripCount) * divisor;
+  spec.highTileSize = spec.lowTileSize + divisor;
+  spec.highTripCount = a % totalTripCount;
+  spec.lowTripCount = totalTripCount - spec.highTripCount;
+  if (spec.lowTileSize * spec.lowTripCount +
+          spec.highTileSize * spec.highTripCount !=
+      tripCount) {
+    return failure();
+  }
+  return spec;
+}
+
 FailureOr<MultiSizeSpecification>
 mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
                                     unsigned dimension, OpFoldResult targetSize,

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index f711419a11a2f..2e40d78ae326e 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -586,12 +586,23 @@ transform::ReplicateOp::apply(transform::TransformResults &results,
   unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
   for (const auto &en : llvm::enumerate(getHandles())) {
     Value handle = en.value();
-    ArrayRef<Operation *> current = state.getPayloadOps(handle);
-    SmallVector<Operation *> payload;
-    payload.reserve(numRepetitions * current.size());
-    for (unsigned i = 0; i < numRepetitions; ++i)
-      llvm::append_range(payload, current);
-    results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
+    if (handle.getType().isa<TransformHandleTypeInterface>()) {
+      ArrayRef<Operation *> current = state.getPayloadOps(handle);
+      SmallVector<Operation *> payload;
+      payload.reserve(numRepetitions * current.size());
+      for (unsigned i = 0; i < numRepetitions; ++i)
+        llvm::append_range(payload, current);
+      results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
+    } else {
+      assert(handle.getType().isa<TransformParamTypeInterface>() &&
+             "expected param type");
+      ArrayRef<Attribute> current = state.getParams(handle);
+      SmallVector<Attribute> params;
+      params.reserve(numRepetitions * current.size());
+      for (unsigned i = 0; i < numRepetitions; ++i)
+        llvm::append_range(params, current);
+      results.setParams(getReplicated()[en.index()].cast<OpResult>(), params);
+    }
   }
   return DiagnosedSilenceableFailure::success();
 }

diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 2525ea34c375c..f045e5c13c1ed 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -5,11 +5,11 @@
 try:
   from ..ir import *
   from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-  from ..dialects import pdl
+  from ..dialects import pdl, transform
 except ImportError as e:
   raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import List, Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union, overload
 
 IntOrAttrList = Sequence[Union[IntegerAttr, int]]
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
@@ -51,13 +51,13 @@ def _get_int_array_attr(
 
 def _get_dense_int64_array_attr(
-        values: Sequence[int]) -> DenseI64ArrayAttr:
-    """Creates a dense integer array from a sequence of integers.
-    Expects the thread-local MLIR context to have been set by the context 
-    manager.
-    """
-    if values is None:
-        return DenseI64ArrayAttr.get([])
-    return DenseI64ArrayAttr.get(values)
+    values: Optional[Sequence[int]]) -> DenseI64ArrayAttr:
+  """Creates a dense integer array from a sequence of integers.
+  A `None` input yields an empty array. Expects the thread-local MLIR context
+  to have been set by the context manager.
+  """
+  if values is None:
+    return DenseI64ArrayAttr.get([])
+  return DenseI64ArrayAttr.get(values)
 
 def _get_int_int_array_attr(
     values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr,
@@ -141,6 +141,7 @@ class MultiTileSizesOp:
   """Specialization for MultitileSizesOp class."""
 
   def __init__(self,
+               result_type: Type,
                target: Union[Operation, Value],
                *,
                dimension: Union[int, IntegerAttr],
@@ -149,9 +150,9 @@ def __init__(self,
                loc=None,
                ip=None):
     super().__init__(
-        pdl.OperationType.get(),
-        pdl.OperationType.get(),
-        pdl.OperationType.get(),
+        result_type,
+        result_type,
+        result_type,
         _get_op_result_or_value(target),
         dimension=_get_int64_attr(dimension),
         target_size=_get_int64_attr(target_size),
@@ -223,11 +224,12 @@ def __init__(self,
       static_split_point = _get_int64_attr(ShapedType.get_dynamic_size())
       dynamic_split_point = _get_op_result_or_value(split_point)
 
-    pdl_operation_type = pdl.OperationType.get()
+    target = _get_op_result_or_value(target)
+
     super().__init__(
-        pdl_operation_type,
-        pdl_operation_type,
-        _get_op_result_or_value(target),
+        target.type,
+        target.type,
+        target,
         dimension=dimension,
         static_split_point=static_split_point,
         dynamic_split_point=dynamic_split_point,
@@ -238,7 +240,9 @@ def __init__(self,
 class TileOp:
   """Specialization for TileOp class."""
 
+  @overload
   def __init__(self,
+               loop_types: Union[Type, List[Type]],
                target: Union[Operation, Value],
                *,
                sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
@@ -246,9 +250,28 @@ def __init__(self,
                interchange: OptionalIntList = None,
                loc=None,
                ip=None):
-    pdl_operation_type = pdl.OperationType.get()
-    i64_type = IntegerType.get_signless(64)
+    ...
 
+  @overload
+  def __init__(self,
+               target: Union[Operation, Value],
+               *,
+               sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
+                                                    Value]], ArrayAttr]] = None,
+               interchange: OptionalIntList = None,
+               loc=None,
+               ip=None):
+    ...
+
+  def __init__(self,
+               loop_types_or_target: Union[Type, List[Type], Operation, Value],
+               target_or_none: Optional[Union[Operation, Value]] = None,
+               *,
+               sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
+                                                    Value]], ArrayAttr]] = None,
+               interchange: OptionalIntList = None,
+               loc=None,
+               ip=None):
     if sizes is None:
       sizes = []
 
@@ -267,12 +290,26 @@ def __init__(self,
 
     num_loops = sum(
         v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
+
+    if isinstance(loop_types_or_target, (Operation, Value)):
+      loop_types = [transform.AnyOpType.get()] * num_loops
+      target = loop_types_or_target
+      assert target_or_none is None, "Cannot construct TileOp with two targets."
+    else:
+      loop_types = ([loop_types_or_target] * num_loops) if isinstance(
+          loop_types_or_target, Type) else loop_types_or_target
+      target = target_or_none
+
+    target = _get_op_result_or_value(target)
+
     super().__init__(
-        pdl_operation_type, [pdl_operation_type] * num_loops,
-        _get_op_result_or_value(target),
+        target.type,
+        loop_types,
+        target,
         dynamic_sizes=dynamic_sizes,
         static_sizes=sizes_attr,
-        interchange=_get_dense_int64_array_attr(interchange) if interchange else None,
+        interchange=_get_dense_int64_array_attr(interchange)
+        if interchange else None,
         loc=loc,
         ip=ip)
 

diff  --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index f899a81d1a5d0..9fae730a2bd16 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -15,7 +15,7 @@ func.func @matmul_tensors(
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op
-  %1, %loops:3 = transform.structured.tile %0 [2, 2, 2]
+  %1, %loops:3 = transform.structured.tile %0 [2, 2, 2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
   %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
   transform.structured.vectorize %2
   transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op

diff  --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
index e309276de75cc..2e6167651dc95 100644
--- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
+++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir
@@ -1,23 +1,29 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter --scf-for-loop-canonicalization --canonicalize %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --scf-for-loop-canonicalization --canonicalize --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s --check-prefix=NOCANON
 
 // This implements a 2D multisize tiling with target sizes [3, 10].
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3}
-  %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10}
-  %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 }
-  %3:2 = transform.structured.tile %2#0 [%1#0]
-  %4:2 = transform.structured.tile %2#1 [%1#1]
+  %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!pdl.operation) -> !pdl.operation
+  %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!pdl.operation) -> !pdl.operation
+  %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !pdl.operation, !pdl.operation
+  %3:2 = transform.structured.tile %2#0 [%1#0] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation)
+  %4:2 = transform.structured.tile %2#1 [%1#1] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation)
   %5 = merge_handles %3#0, %4#0 : !pdl.operation
   %tt:3 = replicate num(%5) %t#0, %t#1, %t#2 : !pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation
-  %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 }
-  transform.structured.tile %6#0 [0, %tt#0]
-  transform.structured.tile %6#1 [0, %tt#1]
+  %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !pdl.operation, !pdl.operation
+  transform.structured.tile %6#0 [0, %tt#0] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation)
+  transform.structured.tile %6#1 [0, %tt#1] : (!pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation)
 }
 
 func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
 
+// Without canonicalization, tile sizes are computed dynamically as affine maps.
+// NOCANON-LABEL: @two_d
+// NOCANON-COUNT-8: affine.apply
+// NOCANON:         scf.for
+
 // CHECK-LABEL: @two_d
 // CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
 func.func @two_d(%arg0: tensor<10x34xf32>,
@@ -93,3 +99,96 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
 
   return %0 : tensor<10x34xf32>
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1:3 = transform.structured.multitile_sizes %0 { dimension = 0, target_size = 3} : (!pdl.operation) -> !transform.param<i64>
+  %t:3 = transform.structured.multitile_sizes %0 { dimension = 1, target_size = 10} : (!pdl.operation) -> !transform.param<i64>
+  %2:2 = transform.structured.split %0 after %1#2 { dimension = 0 } : !pdl.operation, !transform.param<i64>
+  %3:2 = transform.structured.tile %2#0 [%1#0] : (!pdl.operation, !transform.param<i64>) -> (!pdl.operation, !pdl.operation)
+  %4:2 = transform.structured.tile %2#1 [%1#1] : (!pdl.operation, !transform.param<i64>) -> (!pdl.operation, !pdl.operation)
+  %5 = merge_handles %3#0, %4#0 : !pdl.operation
+  %tt:3 = replicate num(%5) %t#0, %t#1, %t#2 : !pdl.operation, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
+  %6:2 = transform.structured.split %5 after %tt#2 { dimension = 1 } : !pdl.operation, !transform.param<i64>
+  transform.structured.tile %6#0 [0, %tt#0] : (!pdl.operation, !transform.param<i64>) -> (!pdl.operation, !pdl.operation)
+  transform.structured.tile %6#1 [0, %tt#1] : (!pdl.operation, !transform.param<i64>) -> (!pdl.operation, !pdl.operation)
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+// Even without canonicalization, tile sizes can be computed statically thanks
+// to parameters.
+// NOCANON-LABEL: @two_d_param
+// NOCANON-NOT:   affine.apply
+// NOCANON:       scf.for
+
+// CHECK-LABEL: @two_d_param
+// CHECK-SAME: %[[IN:.+]]: tensor<10x34xf32>, %[[OUT:.+]]: tensor<10x34xf32>
+func.func @two_d_param(%arg0: tensor<10x34xf32>,
+                       %arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  ins(%arg0: tensor<10x34xf32>)
+  outs(%arg1: tensor<10x34xf32>) {
+  ^bb0(%0: f32, %1: f32):
+    %i = linalg.index 0 : index
+    %j = linalg.index 1 : index
+    %call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<10x34xf32>
+
+  // CHECK:      %[[SLICE_1_IN:.+]] = tensor.extract_slice %[[IN]][0, 0] [4, 34] [1, 1]
+  // CHECK:      %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1]
+  // CHECK:      scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]])
+  // CHECK:        %[[OUTSLICE_1_IN:.+]] = tensor.extract_slice %[[SLICE_1_IN]][%[[I1]], 0] [2, 34] [1, 1]
+  // CHECK:        %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1]
+
+  // CHECK:        %[[SLICE_2_IN:.+]] = tensor.extract_slice %[[OUTSLICE_1_IN]][0, 0] [2, 16] [1, 1]
+  // CHECK:        %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
+  // CHECK:        %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]])
+  // CHECK:          %[[INSLICE_2:.+]] = tensor.extract_slice %[[SLICE_2_IN]][0, %[[I2]]] [2, 8] [1, 1]
+  // CHECK:          %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1]
+  // CHECK:          %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>)
+  // CHECK:          %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
+  // CHECK:          scf.yield %[[RESPARTIAL]]
+
+  // CHECK:        %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
+  // CHECK:        %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
+  // CHECK:        scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
+  // CHECK-COUNT-2:  tensor.extract_slice
+  // CHECK:          linalg.generic {{.*}} ins(%{{.*}} : tensor<2x9xf32>)
+  // CHECK:          tensor.insert_slice
+  // CHECK:          scf.yield
+  // CHECK:        %[[INSERTED_2:.+]] = tensor.insert_slice %{{.*}} into %[[INSERTED]]
+  // CHECK:        %[[INSERTED_3:.+]] = tensor.insert_slice %[[INSERTED_2]] into %[[ITERARG_1]]
+  // CHECK:        scf.yield %[[INSERTED_3]]
+
+  // CHECK:        tensor.insert_slice
+  // CHECK:        tensor.extract_slice
+  // CHECK:        scf.for
+  // CHECK-COUNT-2:  tensor.extract_slice
+  // CHECK:          scf.for
+  // CHECK-COUNT-2:    tensor.extract_slice
+  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>)
+  // CHECK:            tensor.insert_slice
+  // CHECK:            scf.yield
+  // CHECK:          tensor.insert_slice
+  // CHECK:          tensor.extract_slice
+  // CHECK:          scf.for
+  // CHECK-COUNT-2:    tensor.extract_slice
+  // CHECK:            linalg.generic {{.*}} ins(%{{.*}} : tensor<3x9xf32>)
+  // CHECK:            tensor.insert_slice
+  // CHECK:            scf.yield
+  // CHECK-COUNT-2:  tensor.insert_slice
+  // CHECK:          scf.yield
+  // CHECK:        %[[RESULT:.+]] = tensor.insert_slice
+  // CHECK:        return %[[RESULT]]
+
+  return %0 : tensor<10x34xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
index 7872ec4a236cc..70cac58acaf35 100644
--- a/mlir/test/Dialect/Linalg/promotion_options.mlir
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -37,6 +37,6 @@ func.func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-  %1, %loops:3 = transform.structured.tile %0 [16, 16, 16]
+  %1, %loops:3 = transform.structured.tile %0 [16, 16, 16] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
   %2 = transform.structured.promote %1 { operands_to_promote = [0, 2], force_full_tiles = [false, false], use_full_tiles_by_default }
 }

diff  --git a/mlir/test/Dialect/Linalg/tile-conv.mlir b/mlir/test/Dialect/Linalg/tile-conv.mlir
index a6998f81620ca..cd126be25dfc2 100644
--- a/mlir/test/Dialect/Linalg/tile-conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile-conv.mlir
@@ -12,7 +12,7 @@ func.func @conv(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, %arg2 : memref
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.conv_2d"]} in %arg1
-    %1, %loop:2 = transform.structured.tile %0 [2, 3]
+    %1, %loop:2 = transform.structured.tile %0 [2, 3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 //       CHECK: func @conv

diff  --git a/mlir/test/Dialect/Linalg/tile-indexed.mlir b/mlir/test/Dialect/Linalg/tile-indexed.mlir
index 0047e4409dec2..3e6e0d5c5c686 100644
--- a/mlir/test/Dialect/Linalg/tile-indexed.mlir
+++ b/mlir/test/Dialect/Linalg/tile-indexed.mlir
@@ -14,7 +14,7 @@ func.func @indexed_vector(%arg0: memref<50xindex>) {
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-    %1, %loop = transform.structured.tile %0 [10]
+    %1, %loop = transform.structured.tile %0 [10] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
 }
 
 // TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>
@@ -44,7 +44,7 @@ func.func @indexed_matrix(%arg0: memref<50x50xindex>) {
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-    %1, %loop:2 = transform.structured.tile %0 [10, 25]
+    %1, %loop:2 = transform.structured.tile %0 [10, 25] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 // TILE-10n25-DAG: [[$MAP:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0 + d1)>

diff  --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 273599ff53dea..484534e740ab3 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -30,7 +30,7 @@ func.func @matmul_tensors(
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-    %1, %loops:3 = transform.structured.tile %0 [2, 3, 4]
+    %1, %loops:3 = transform.structured.tile %0 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 // -----
@@ -61,7 +61,7 @@ func.func @generic_op_tensors(
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-    %1, %loops:3 = transform.structured.tile %0 [2, 3, 4]
+    %1, %loops:3 = transform.structured.tile %0 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 // CHECK-LABEL: func @generic_op_tensors
@@ -132,5 +132,5 @@ func.func @fold_extract_slice(
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-    %1, %loops:3 = transform.structured.tile %0 [2, 3, 4]
+    %1, %loops:3 = transform.structured.tile %0 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 580ad597ef30d..c50c8f4087e33 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -89,7 +89,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]}
-  %2, %loops_2 = transform.structured.tile %1 [0, 4]
+  %2, %loops_2 = transform.structured.tile %1 [0, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
 }
 
 // -----

diff  --git a/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir b/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir
index cd30ff7606ae9..b035d58d18864 100644
--- a/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-multitile-sizes.mlir
@@ -1,11 +1,11 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics | FileCheck %s
 
 // CHECK-DAG: #[[$MAP13:.+]] = affine_map<() -> (13)>
 
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-    transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 }
+    transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 } : (!pdl.operation) -> !pdl.operation
 }
 
 // CHECK-LABEL: @multitile_sizes_static
@@ -29,7 +29,34 @@ func.func @multitile_sizes_static(
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-    transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 }
+    %low_tile, %high_tile, %split_point =
+      transform.structured.multitile_sizes %0 { target_size = 3, dimension = 0 }
+      : (!pdl.operation) -> !transform.param<i64>
+    // expected-remark @below {{2 : i64}}
+    transform.test_print_param %low_tile : !transform.param<i64>
+    // expected-remark @below {{3 : i64}}
+    transform.test_print_param %high_tile : !transform.param<i64>
+    // expected-remark @below {{4 : i64}}
+    transform.test_print_param %split_point : !transform.param<i64>
+}
+
+// CHECK-LABEL: @multitile_sizes_static_gen
+func.func @multitile_sizes_static_gen(
+  %arg0: tensor<13x34xf32>, %arg1: tensor<34x42xf32>, %arg2: tensor<13x42xf32>)
+    -> tensor<13x42xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<13x34xf32>, tensor<34x42xf32>)
+                     outs(%arg2: tensor<13x42xf32>)
+    -> tensor<13x42xf32>
+
+  return %0 : tensor<13x42xf32>
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 } : (!pdl.operation) -> !pdl.operation
 }
 
 // CHECK: #[[$MAP_A:.+]] = affine_map<()[s0] -> ([[A_IMPL:s0 floordiv 2]])>
@@ -64,3 +91,24 @@ func.func @multitile_sizes_dynamic(
 
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    // expected-error @below {{cannot compute parametric tile sizes for dynamically shaped payload op}}
+    transform.structured.multitile_sizes %0 { target_size = 3, divisor = 2, dimension = 0 }
+      : (!pdl.operation) -> !transform.param<i64>
+}
+
+func.func @multitile_sizes_dynamic_gen(
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  // expected-note @below {{payload op}}
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+  return %0 : tensor<?x?xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
index fbf083c3d1ad8..e46f19d34ae2f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
@@ -21,6 +21,6 @@ func.func @scalarize(%arg0: tensor<24x12xf32>,
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-  %1, %loops = transform.structured.tile %0 [10, 0, 0]
+  %1, %loops = transform.structured.tile %0 [10, 0, 0] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
   %2 = transform.structured.scalarize %1
 }

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
index 1d7f15efe73cb..6313e77fc9dde 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -3,7 +3,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %1:2 = transform.structured.split %0 after 42 { dimension = 0 }
+  %1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !pdl.operation
 }
 
 func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
@@ -51,7 +51,7 @@ func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tenso
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %1:2 = transform.structured.split %0 after 42 { dimension = 0 }
+  %1:2 = transform.structured.split %0 after 42 { dimension = 0 } : !pdl.operation
 }
 
 func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
@@ -85,7 +85,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   %1 = transform.structured.match ops{["func.call"]} in %arg1
-  transform.structured.split %0 after %1 { dimension = 0 }
+  transform.structured.split %0 after %1 { dimension = 0 } : !pdl.operation, !pdl.operation
 }
 
 func.func private @get_size() -> index
@@ -132,8 +132,8 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %1:2 = transform.structured.split %0 after 4 { dimension = 0}
-  %2:2 = transform.structured.split %1#1 after 16 { dimension = 1 }
+  %1:2 = transform.structured.split %0 after 4 { dimension = 0 } : !pdl.operation
+  %2:2 = transform.structured.split %1#1 after 16 { dimension = 1 } : !pdl.operation
 }
 
 func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
@@ -199,7 +199,7 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   %1 = transform.structured.match ops{["func.call"]} in %arg1
   // expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}}
-  transform.structured.split %0 after %1 { dimension = 0 }
+  transform.structured.split %0 after %1 { dimension = 0 } : !pdl.operation, !pdl.operation
 }
 
 func.func private @get_size() -> i64
@@ -225,7 +225,7 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   %1 = transform.structured.match ops{["func.call"]} in %arg1
   // expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
-  transform.structured.split %0 after %1 { dimension = 0 }
+  transform.structured.split %0 after %1 { dimension = 0 } : !pdl.operation, !pdl.operation
 }
 
 func.func private @get_size() -> i64
@@ -248,7 +248,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["func.return"]} in %arg1
   // expected-error @below {{only applies to structured ops}}
-  transform.structured.split %0 after 16 { dimension = 1 }
+  transform.structured.split %0 after 16 { dimension = 1 } : !pdl.operation
 }
 
 func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
@@ -262,7 +262,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   // expected-error @below {{dimension 1 does not exist in target op}}
-  transform.structured.split %0 after 16 { dimension = 1 }
+  transform.structured.split %0 after 16 { dimension = 1 } : !pdl.operation
 }
 
 func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
@@ -285,7 +285,7 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   // expected-error @below {{splitting does not produce the second part for a subset of targets}}
   // expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
-  %1:2 = transform.structured.split %0 after 142 { dimension = 0 }
+  %1:2 = transform.structured.split %0 after 142 { dimension = 0 } : !pdl.operation
 }
 
 func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32

diff  --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index 46027dce04e00..10517050bf01d 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
 
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-  %1, %loops:3 = transform.structured.tile %0 [4, 4, 4]
+  %1, %loops:3 = transform.structured.tile %0 [4, 4, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 // CHECK-LABEL: func @tile_linalg_matmul(
@@ -40,7 +40,7 @@ transform.sequence failures(propagate) {
 ^bb0(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
   %1 = transform.structured.match ops{["func.call"]} in %arg1
-  %2, %loops:3 = transform.structured.tile %0 [%1, %1, 4]
+  %2, %loops:3 = transform.structured.tile %0 [%1, %1, 4] : (!pdl.operation, !pdl.operation, !pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 func.func private @get_dynamic_tile_size() -> index
@@ -73,3 +73,53 @@ func.func @tile_linalg_matmul_dynamic(
 //      CHECK: return %[[TD0]] : tensor<128x128xf32>
   return %0 : tensor<128x128xf32>
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+  // expected-note @below {{for this parameter}}
+  %1 = transform.test_produce_integer_param_with_type i64 : !transform.param<i64>
+  // expected-error @below {{expected as many parameter values (0) as target ops (2)}}
+  transform.structured.tile %0 [%1, %1, %1]
+    : (!pdl.operation, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>)
+    -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+func.func @tile_linalg_matmul(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+    -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  %1 = linalg.matmul  ins(%0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32>
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+  // expected-note @below {{for this handle}}
+  %1 = transform.structured.match ops{["arith.constant"]} in %arg1
+  // expected-error @below {{expected as many dynamic size-producing operations (0) as target ops (2)}}
+  transform.structured.tile %0 [%1, %1, 1]
+    : (!pdl.operation, !pdl.operation, !pdl.operation)
+    -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+func.func @tile_linalg_matmul(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+    -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  %1 = linalg.matmul  ins(%0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+  return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
index e21b21a8fa408..fb84018580305 100644
--- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -45,3 +45,12 @@ transform.sequence failures(propagate) {
   // expected-error@below {{'transform.structured.interchange' op attribute 'iterator_interchange' failed to satisfy constraint: i64 dense array attribute whose value is non-negative}}
   transform.structured.interchange %arg0 iterator_interchange = [-3, 1]
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{expects all results type to be the same}}
+  "transform.structured.multitile_sizes"(%arg0) { target_size = 3, divisor = 2, dimension = 0 }
+      : (!pdl.operation) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i32>)
+}

diff  --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir
index 64cf3fbd04f90..4fa426e1394e0 100644
--- a/mlir/test/Dialect/Linalg/transform-ops.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops.mlir
@@ -3,7 +3,13 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   // CHECK %{{.*}}, %{{.*}}:2 = transform.structured.tile
-  %0, %1:2 = transform.structured.tile %arg0 [2, 0, 3]
+  %0, %1:2 = transform.structured.tile %arg0 [2, 0, 3] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !transform.any_op):
+  %0:2 = transform.structured.split %arg0 after 42 { dimension = 0 } : !transform.any_op
+  transform.structured.split %0#0 after %0#1 { dimension = 1 } : !transform.any_op, !transform.any_op
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
index 65ff4d62809af..c34e94a0b511d 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -14,7 +14,7 @@ func.func @dot(%x: memref<?xf32, strided<[1], offset: ?>>,
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.dot"]} in %arg1
-    %1, %loop = transform.structured.tile %0 [8000]
+    %1, %loop = transform.structured.tile %0 [8000] : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
 }
 
 // CHECK-LABEL: func @dot
@@ -38,7 +38,7 @@ func.func @matvec(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1
-    %1, %loops:2 = transform.structured.tile %0 [5, 6]
+    %1, %loops:2 = transform.structured.tile %0 [5, 6] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 // CHECK-LABEL: func @matvec
@@ -65,10 +65,10 @@ func.func @matmul(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-    %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000]
-    %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400]
-    %3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40]
-    %4, %loops_4:3 = transform.structured.tile %3 [2, 3, 4]
+    %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
+    %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
+    %3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
+    %4, %loops_4:3 = transform.structured.tile %3 [2, 3, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 // CHECK-LABEL: func @matmul
@@ -164,7 +164,7 @@ func.func @matvec_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1
-    %1, %loops:2 = transform.structured.tile %0 [5, 6] {interchange = [1, 0]}
+    %1, %loops:2 = transform.structured.tile %0 [5, 6] {interchange = [1, 0]} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 // CHECK-LABEL: func @matvec_perm
@@ -191,9 +191,9 @@ func.func @matmul_perm(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
 transform.sequence failures(propagate) {
   ^bb0(%arg1: !pdl.operation):
     %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-    %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange = [1, 2, 0]}
-    %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange = [1, 0, 2]}
-    %3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40]
+    %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange = [1, 2, 0]} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
+    %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange = [1, 0, 2]} : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
+    %3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
 }
 
 // CHECK-LABEL: func @matmul_perm

diff  --git a/mlir/test/Dialect/Transform/selective-targeting.mlir b/mlir/test/Dialect/Transform/selective-targeting.mlir
index b11a5ea09ccb3..231ff3099d175 100644
--- a/mlir/test/Dialect/Transform/selective-targeting.mlir
+++ b/mlir/test/Dialect/Transform/selective-targeting.mlir
@@ -77,7 +77,7 @@ transform.with_pdl_patterns {
   transform.sequence %arg0 : !pdl.operation failures(propagate) {
   ^bb1(%arg1: !pdl.operation):
     %0 = pdl_match @pdl_target_attrA in %arg1 : (!pdl.operation) -> !pdl.operation
-    transform.structured.tile %0 [4, 4, 4]
+    transform.structured.tile %0 [4, 4, 4] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
     %1 = pdl_match @pdl_target_attrC in %arg1 : (!pdl.operation) -> !pdl.operation
     %2 = transform.get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
     transform.structured.vectorize %2

diff  --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index a753b229576a7..864fd8ffc3476 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -16,7 +16,7 @@ func.func @matmul_tensors(
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op
-  %1, %loops:3 = transform.structured.tile %0 [8, 4, 2]
+  %1, %loops:3 = transform.structured.tile %0 [8, 4, 2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
   %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
   transform.structured.vectorize %2
   transform.bufferization.one_shot_bufferize %module_op

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index b88f7465e08b8..e7696033980fe 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -56,8 +56,10 @@ def testInterchange():
 def testMultitileSizes():
   sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
   with InsertionPoint(sequence.body):
-    structured.MultiTileSizesOp(
-        sequence.bodyTarget, dimension=1, target_size=42)
+    structured.MultiTileSizesOp(pdl.OperationType.get(),
+                                sequence.bodyTarget,
+                                dimension=1,
+                                target_size=42)
     transform.YieldOp()
   # CHECK-LABEL: TEST: testMultitileSizes
   # CHECK: transform.sequence
@@ -110,7 +112,9 @@ def testSplit():
 def testTileCompact():
   sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
   with InsertionPoint(sequence.body):
-    structured.TileOp(sequence.bodyTarget, sizes=[4, 8], interchange=[0, 1])
+    structured.TileOp(sequence.bodyTarget,
+                      sizes=[4, 8],
+                      interchange=[0, 1])
     transform.YieldOp()
   # CHECK-LABEL: TEST: testTileCompact
   # CHECK: transform.sequence
@@ -123,7 +127,9 @@ def testTileAttributes():
   attr = DenseI64ArrayAttr.get([4, 8])
   ichange = DenseI64ArrayAttr.get([0, 1])
   with InsertionPoint(sequence.body):
-    structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange)
+    structured.TileOp(sequence.bodyTarget,
+                      sizes=attr,
+                      interchange=ichange)
     transform.YieldOp()
   # CHECK-LABEL: TEST: testTileAttributes
   # CHECK: transform.sequence
@@ -134,8 +140,9 @@ def testTileAttributes():
 def testTileZero():
   sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
   with InsertionPoint(sequence.body):
-    structured.TileOp(
-        sequence.bodyTarget, sizes=[4, 0, 2, 0], interchange=[0, 1, 2, 3])
+    structured.TileOp(sequence.bodyTarget,
+                      sizes=[4, 0, 2, 0],
+                      interchange=[0, 1, 2, 3])
     transform.YieldOp()
   # CHECK-LABEL: TEST: testTileZero
   # CHECK: transform.sequence
@@ -151,7 +158,8 @@ def testTileDynamic():
     with InsertionPoint(sequence.body):
       m1 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
       m2 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
-      structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0])
+      structured.TileOp(sequence.bodyTarget,
+                        sizes=[m1, 3, m2, 0])
       transform.YieldOp()
   # CHECK-LABEL: TEST: testTileDynamic
   # CHECK: %[[FIRST:.+]] = pdl_match
@@ -159,6 +167,37 @@ def testTileDynamic():
   # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
 
 
+@run
+def testTileExplicitLoopTypeSingle():
+  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
+                                  [], transform.AnyOpType.get())
+  with InsertionPoint(sequence.body):
+    structured.TileOp(transform.OperationType.get("scf.for"),
+                      sequence.bodyTarget,
+                      sizes=[2, 3, 4])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
+  # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
+  # CHECK-COUNT-3: !transform.op<"scf.for">
+
+
+
+@run
+def testTileExplicitLoopTypeAll():
+  sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
+                                  [], transform.AnyOpType.get())
+  types = [
+      transform.OperationType.get(x)
+      for x in ["scf.for", "scf.parallel", "scf.foreach_thread"]
+  ]
+  with InsertionPoint(sequence.body):
+    structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
+  # CHECK: = transform.structured.tile
+  # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
+  # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.foreach_thread">
+
 @run
 def testVectorize():
   sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())


        


More information about the Mlir-commits mailing list