[Mlir-commits] [mlir] [MLIR] Add continuous tiling to transform dialect (PR #82792)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Jun 7 02:37:18 PDT 2024


================
@@ -2286,103 +2299,168 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
             return OpFoldResult(op->getResult(0));
           }));
     } else {
-      splitPoints = llvm::to_vector(
-          llvm::map_range(state.getParams(getDynamicSplitPoint()),
+      chunkSizes = llvm::to_vector(
+          llvm::map_range(state.getParams(getDynamicChunkSizes()),
                           [](Attribute attr) { return OpFoldResult(attr); }));
     }
     if (diag.isSilenceableFailure())
       return diag;
 
-    if (splitPoints.size() != payload.size()) {
+    // For multiway split, a single payload is expected to have multiple
+    // split points.
+    if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
       return emitDefiniteFailure()
              << "expected the dynamic split point handle to point to as "
                 "many operations ("
-             << splitPoints.size() << ") as the target handle ("
+             << chunkSizes.size() << ") as the target handle ("
              << payload.size() << ")";
     }
   } else {
-    splitPoints.resize(payload.size(),
-                       rewriter.getIndexAttr(getStaticSplitPoint()));
+    chunkSizes.resize(payload.size(),
+                       rewriter.getIndexAttr(getStaticChunkSizes()));
   }
 
-  // Split each target operation.
-  SmallVector<Operation *> first, second;
-  Operation *noSecondPart = nullptr;
-  for (const auto &pair : llvm::zip(payload, splitPoints)) {
-    Operation *target = std::get<0>(pair);
-    auto linalgOp = dyn_cast<LinalgOp>(target);
+  auto checkStructuredOpAndDimensions =
+      [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
     if (!linalgOp) {
       auto diag = emitSilenceableError() << "only applies to structured ops";
-      diag.attachNote(target->getLoc()) << "target op";
+      diag.attachNote(loc) << "target op";
       return diag;
     }
 
     if (getDimension() >= linalgOp.getNumLoops()) {
       auto diag = emitSilenceableError() << "dimension " << getDimension()
-                                         << " does not exist in target op";
-      diag.attachNote(target->getLoc()) << "target op";
+                                          << " does not exist in target op";
+      diag.attachNote(loc) << "target op";
       return diag;
     }
+    return DiagnosedSilenceableFailure::success();
+  };
 
-    rewriter.setInsertionPoint(linalgOp);
-    std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
-        rewriter, cast<TilingInterface>(linalgOp.getOperation()),
-        getDimension(), std::get<1>(pair));
-
-    // Propagate errors.
-    if (!first.back() && !second.back()) {
+  auto checkFailureInSplitting =
+      [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
+    if (hasFailed) {
       auto diag = emitDefiniteFailure() << "internal failure in splitting";
-      diag.attachNote(target->getLoc()) << "target op";
+      diag.attachNote(loc) << "target op";
       return diag;
     }
+    return DiagnosedSilenceableFailure::success();
+  };
+
+  if (isMultiwaySplit) {
 
-    // Do not add null second parts.
-    if (!second.back()) {
-      noSecondPart = target;
-      second.pop_back();
+    // Split a single target operation at multiple points.
+    SmallVector<Operation *> opList;
+    Operation *head, *tail;
+    Operation *target = payload.front();
+
+    auto linalgOp = dyn_cast<LinalgOp>(target);
----------------
ftynse wrote:

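Please document that `linalgOp` may be null here (the `dyn_cast` can fail) and that the null case is handled inside `checkStructuredOpAndDimensions`. Or, better, have `checkStructuredOpAndDimensions` take an `Operation *` and perform the `dyn_cast` itself.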

https://github.com/llvm/llvm-project/pull/82792


More information about the Mlir-commits mailing list