[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 23 04:57:00 PDT 2024


================
@@ -2305,57 +2319,105 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
                        rewriter.getIndexAttr(getStaticSplitPoint()));
   }
 
-  // Split each target operation.
-  SmallVector<Operation *> first, second;
-  Operation *noSecondPart = nullptr;
-  for (const auto &pair : llvm::zip(payload, splitPoints)) {
-    Operation *target = std::get<0>(pair);
-    auto linalgOp = dyn_cast<LinalgOp>(target);
-    if (!linalgOp) {
-      auto diag = emitSilenceableError() << "only applies to structured ops";
-      diag.attachNote(target->getLoc()) << "target op";
-      return diag;
-    }
+  if (isMultiwaySplit) {
 
-    if (getDimension() >= linalgOp.getNumLoops()) {
-      auto diag = emitSilenceableError() << "dimension " << getDimension()
-                                         << " does not exist in target op";
-      diag.attachNote(target->getLoc()) << "target op";
-      return diag;
+    // Split a single target operation at multiple points.
+    SmallVector<Operation *> opList;
+    Operation *head, *tail;
+    for (const auto [idx, splitPoint] : llvm::enumerate(splitPoints)) {
+
+      Operation *target;
+      if (idx == 0)
+        target = payload.front();
+      else
+        target = tail;
+
+      if (!target)
+        break;
+
+      auto linalgOp = dyn_cast<LinalgOp>(target);
+
+      if (!linalgOp) {
+        auto diag = emitSilenceableError() << "only applies to structured ops";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      if (getDimension() >= linalgOp.getNumLoops()) {
+        auto diag = emitSilenceableError() << "dimension " << getDimension()
+                                           << " does not exist in target op";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      rewriter.setInsertionPoint(linalgOp);
+      std::tie(head, tail) = linalg::splitOp(
+          rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+          getDimension(), splitPoint);
+
+      opList.push_back(head);
----------------
muneebkhan85 wrote:

Fixed and refactored. The check is now done by the lambda `checkFailureInSplitting`.

https://github.com/llvm/llvm-project/pull/82792


More information about the Mlir-commits mailing list