[Mlir-commits] [mlir] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)
lorenzo chelini
llvmlistbot at llvm.org
Fri Jan 12 07:42:18 PST 2024
================
@@ -235,6 +236,97 @@ void transform::TestTileUsingForallOp::getEffects(
modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// TestFuseUsingForallOp
+//===----------------------------------------------------------------------===//
+
+/// Apply a tiling transformation to all payload ops and store both the
+/// tiled operation as well as the created tile loops.
+/// Tile-and-fuse every op in `payloadOps` with `applyFn`, then replace the
+/// uses of the original root op and of each fused producer with the values
+/// computed by the generated loop, erasing any of them left without users.
+///
+/// On success, result 0 of `transformOp` is set to the first tiled-and-fused
+/// op of each payload op, and result 1 to the single `scf.forall` loop
+/// generated for each payload op (one entry per payload op in both lists).
+/// Emits an error if a payload op does not implement `TilingInterface`, and
+/// returns failure if `applyFn` fails.
+template <typename Range>
+static LogicalResult applyTilingToAll(
+    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
+    unsigned numLoops, transform::TransformResults &transformResults,
+    function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
+        applyFn) {
+  SmallVector<Operation *> tiledLinalgOps;
+  // Exactly one `scf.forall` is produced per payload op (asserted below), so
+  // a flat list is enough; no per-depth vector-of-vectors is needed.
+  SmallVector<Operation *> forallOps;
+
+  for (Operation *target : payloadOps) {
+    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+    if (!tilingInterfaceOp)
+      return transformOp->emitError("only TilingInterface ops are supported");
+
+    rewriter.setInsertionPoint(target);
+    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
+        applyFn(tilingInterfaceOp);
+    if (failed(tiledResults))
+      return failure();
+
+    // Perform the replacement of tiled and fused values. The root op is
+    // visited before the fused producers, so erasing it first can leave a
+    // producer without users, which is then erased in the same sweep.
+    SmallVector<Operation *> opsToReplace{target};
+    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
+    for (Operation *toReplace : opsToReplace) {
+      for (OpResult res : toReplace->getResults())
+        if (auto replacement = tiledResults->replacements.lookup(res))
+          rewriter.replaceAllUsesWith(res, replacement);
+      if (toReplace->use_empty())
+        rewriter.eraseOp(toReplace);
+    }
+
+    // Report back the relevant handles to the transform op.
+    assert(!tiledResults->tiledAndFusedOps.empty() &&
+           "expected tile and fuse to produce at least one tiled op");
+    tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
+    assert(tiledResults->loops.size() == 1 &&
+           cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
+           "Mismatched number of loops, tile and fuse transform should have "
+           "failed");
+    forallOps.push_back(tiledResults->loops[0]);
+  }
+
+  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
+  transformResults.set(transformOp->getOpResult(1), forallOps);
+  return success();
+}
+
+DiagnosedSilenceableFailure transform::TestFuseUsingForallOp::apply(
+ transform::TransformRewriter &rewriter,
+ mlir::transform::TransformResults &transformResults,
----------------
chelini wrote:
Let's drop the mlir:: prefix here, please.
https://github.com/llvm/llvm-project/pull/77874
More information about the Mlir-commits
mailing list