[Mlir-commits] [mlir] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)

lorenzo chelini llvmlistbot at llvm.org
Fri Jan 12 07:42:18 PST 2024


================
@@ -235,6 +236,97 @@ void transform::TestTileUsingForallOp::getEffects(
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// TestFuseUsingForallOp
+//===----------------------------------------------------------------------===//
+
+/// Apply a tiling transformation to all payload ops and store both the
+/// tiled operation as well as the created tile loops.
+///
+/// For each op in `payloadOps`, `applyFn` is invoked to tile it (and fuse its
+/// producers). Uses of the original op and of every fused producer are then
+/// redirected to the corresponding tiled replacement values. Any of those ops
+/// left without uses are erased. Result 0 of `transformOp` receives the first
+/// tiled-and-fused op of each payload op, and result 1 receives the generated
+/// `scf.forall` loop of each payload op.
+///
+/// `numLoops` is only used to sanity-check the rank of the generated
+/// `scf.forall` in assert-enabled builds.
+template <typename Range>
+static LogicalResult applyTilingToAll(
+    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
+    unsigned numLoops, transform::TransformResults &transformResults,
+    function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
+        applyFn) {
+  SmallVector<Operation *> tiledLinalgOps;
+  // Exactly one loop is recorded per payload op (see the assert below), so a
+  // flat list is sufficient to hold the loop handles.
+  SmallVector<Operation *> loopOps;
+
+  for (Operation *target : payloadOps) {
+    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+    if (!tilingInterfaceOp)
+      return transformOp->emitError("only TilingInterface ops are supported");
+
+    rewriter.setInsertionPoint(target);
+    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
+        applyFn(tilingInterfaceOp);
+    if (failed(tiledResults))
+      return failure();
+    // `tiledAndFusedOps.front()` is read below. Bail out before mutating the
+    // IR rather than reading from an empty set.
+    if (tiledResults->tiledAndFusedOps.empty())
+      return transformOp->emitError("tiling did not produce any tiled op");
+
+    // Perform the replacement of tiled and fused values. Ops whose results
+    // are no longer used after the replacement are erased.
+    SmallVector<Operation *> opsToReplace{target};
+    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
+    for (Operation *toReplace : opsToReplace) {
+      for (OpResult res : toReplace->getResults())
+        if (auto replacement = tiledResults->replacements.lookup(res))
+          rewriter.replaceAllUsesWith(res, replacement);
+      if (toReplace->use_empty())
+        rewriter.eraseOp(toReplace);
+    }
+
+    // Report back the relevant handles to the transform op.
+    tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
+    assert(tiledResults->loops.size() == 1 &&
+           cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
+           "Mismatched number of loops, tile and fuse transform should have "
+           "failed");
+    loopOps.push_back(tiledResults->loops[0].getOperation());
+  }
+
+  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
+  transformResults.set(transformOp->getOpResult(1), loopOps);
+
+  return success();
+}
+
+DiagnosedSilenceableFailure transform::TestFuseUsingForallOp::apply(
+    transform::TransformRewriter &rewriter,
+    mlir::transform::TransformResults &transformResults,
----------------
chelini wrote:

let's drop mlir:: please.

https://github.com/llvm/llvm-project/pull/77874


More information about the Mlir-commits mailing list