[Mlir-commits] [mlir] [mlir][TilingInterface] Move TilingInterface tests to use transform dialect ops. (PR #77204)
Quinn Dawkins
llvmlistbot at llvm.org
Thu Jan 11 13:44:57 PST 2024
================
@@ -0,0 +1,267 @@
+//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines transform dialect operations used for testing the
+// TilingInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+#define GET_OP_CLASSES
+#include "TestTilingInterfaceTransformOps.h.inc"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+//===----------------------------------------------------------------------===//
+// TestFuseAndYieldOp
+//===----------------------------------------------------------------------===//
+
+/// Walks the use-def chain backwards from `op` and returns `op` together with
+/// every transitive producer that implements `TilingInterface`. The walk stops
+/// at operands with no defining op and at producers that do not implement the
+/// interface.
+static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
+  llvm::SmallDenseSet<Operation *> visited;
+  SmallVector<Operation *> pending = {op};
+  visited.insert(op);
+  while (!pending.empty()) {
+    Operation *consumer = pending.pop_back_val();
+    for (Value operandValue : consumer->getOperands()) {
+      Operation *definingOp = operandValue.getDefiningOp();
+      if (!definingOp || !isa<TilingInterface>(definingOp))
+        continue;
+      // `insert` reports whether the op was newly added, so each producer is
+      // queued (and later visited) at most once.
+      if (visited.insert(definingOp).second)
+        pending.push_back(definingOp);
+    }
+  }
+  return visited;
+}
+
+/// Tile every payload op with the given tile sizes and interchange using
+/// `scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp`, fusing its
+/// `TilingInterface` producers into the generated loops. Any of these ops
+/// (the payload op itself or a collected producer) that has a user properly
+/// dominated by the payload op is asked to yield a replacement value from the
+/// loops, so that such later uses can be redirected to the tiled computation.
+///
+/// On success, result 0 of `transformOp` is set to the first entry of
+/// `tiledAndFusedOps` for each payload op (presumably the tiled payload op —
+/// NOTE(review): confirm), and result `i + 1` to the i-th generated loop of
+/// each payload op. `numLoops` must equal the number of loops produced by
+/// tiling; this is asserted, not diagnosed.
+///
+/// Returns failure after emitting an error on `transformOp` if a payload op
+/// does not implement `TilingInterface`. If tile-and-fuse itself fails,
+/// failure is returned without emitting a diagnostic here.
+template <typename Range>
+static LogicalResult
+applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
+ Range &&payloadOps, unsigned numLoops,
+ ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<int64_t> interchange,
+ transform::TransformResults &transformResults) {
+ SmallVector<Operation *> tiledOps;
+ // One list of loop handles per loop depth; index i collects the i-th loop
+ // of every payload op.
+ SmallVector<SmallVector<Operation *>> loopOps(numLoops);
+
+ for (Operation *target : payloadOps) {
+ auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+ if (!tilingInterfaceOp)
+ return transformOp->emitError("only TilingInterface ops are supported");
+ // Built before any rewriting of this target; it is used both to select
+ // which ops must yield replacements and to filter uses replaced below.
+ DominanceInfo dominanceInfo(tilingInterfaceOp);
+
+ // An op in the fusion set whose results are used after the payload op
+ // (i.e. by a user the payload op properly dominates) needs its tiled value
+ // yielded out of the loops so those uses can still be served.
+ llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
+ collectTiledAndFusedOps(tilingInterfaceOp);
+ llvm::DenseSet<Operation *> yieldReplacementsFor;
+ for (auto op : tiledAndFusedOps) {
+ if (llvm::any_of(op->getUsers(), [&](Operation *user) {
+ return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
+ })) {
+ yieldReplacementsFor.insert(op);
+ }
+ }
+
+ scf::SCFTilingOptions tilingOptions;
+ tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+
+ scf::SCFTileAndFuseOptions tileAndFuseOptions;
+ tileAndFuseOptions.setTilingOptions(tilingOptions);
+
+ // Fuse every candidate producer (first tuple element is always true);
+ // request a yielded replacement only for the producers selected above.
+ scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
+ [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
+ bool isDestinationOperand) {
+ Operation *owner = originalProducer.getOwner();
+ bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
+ return std::make_tuple(true, yieldProducerReplacement);
+ };
+ tileAndFuseOptions.setFusionControlFn(controlFn);
+
+ rewriter.setInsertionPoint(target);
+ FailureOr<scf::SCFTileAndFuseResult> tiledResults =
+ scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
+ rewriter, tilingInterfaceOp, tileAndFuseOptions);
+ // No diagnostic is emitted here on failure; the caller turns this into a
+ // definite failure.
+ if (failed(tiledResults))
+ return failure();
+
+ // Perform the replacement of tiled and fused values.
+ SmallVector<Operation *> opsToReplace{target};
+ llvm::append_range(opsToReplace, tiledResults->fusedProducers);
+ for (Operation *toReplace : opsToReplace) {
+ for (OpResult res : toReplace->getResults())
+ if (auto replacement = tiledResults->replacements.lookup(res)) {
+ Operation *replacementOp = replacement.getDefiningOp();
+ // Only rewire uses that the replacement's defining op properly
+ // dominates and that share its parent op; other uses keep
+ // referring to the original result.
+ rewriter.replaceUsesWithIf(
+ res, replacement, [&](mlir::OpOperand &use) {
+ Operation *user = use.getOwner();
+ return dominanceInfo.properlyDominates(replacementOp, user) &&
+ user->getParentOp() == replacementOp->getParentOp();
+ });
+ }
+
+ // Erase the original op only once nothing refers to it anymore; ops
+ // with remaining (unreplaced) uses are left in place.
+ if (toReplace->use_empty()) {
+ rewriter.eraseOp(toReplace);
+ }
+ }
+
+ // Report back the relevant handles to the transform op.
+ tiledOps.push_back(tiledResults->tiledAndFusedOps.front());
+ assert(tiledResults->loops.size() == numLoops &&
+ "Mismatched number of loops, tile and fuse transform should have "
+ "failed");
+ for (unsigned int i = 0; i < numLoops; ++i)
+ loopOps[i].push_back(tiledResults->loops[i]);
+ }
+
+ // Result 0 holds the tiled ops; results 1..numLoops hold the loop handles.
+ transformResults.set(transformOp->getOpResult(0), tiledOps);
+ for (unsigned int i = 0; i < numLoops; ++i)
+ transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
+
+ return success();
+}
+
+/// Runs tile-and-fuse (with yielding of replacements for producers used after
+/// the consumer) on every payload op of the target handle. The tile sizes and
+/// the interchange come from the op's integer array attributes. Any failure
+/// of the helper is reported as a definite failure.
+DiagnosedSilenceableFailure transform::TestFuseAndYieldOp::apply(
+    transform::TransformRewriter &rewriter,
+    mlir::transform::TransformResults &transformResults,
+    mlir::transform::TransformState &state) {
+  // Decode the static tiling configuration carried by the op's attributes.
+  SmallVector<int64_t> staticTileSizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
+  SmallVector<int64_t> interchangeOrder =
+      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
+
+  // Only non-zero tile sizes are counted as generated loops; the helper
+  // asserts that tiling produced exactly this many loops.
+  unsigned numTiledLoops =
+      staticTileSizes.size() - llvm::count(staticTileSizes, 0);
+
+  if (failed(applyTileAndFuseToAll(
+          rewriter, getOperation(), state.getPayloadOps(getTarget()),
+          numTiledLoops,
+          getAsIndexOpFoldResult(rewriter.getContext(), staticTileSizes),
+          interchangeOrder, transformResults)))
+    return DiagnosedSilenceableFailure::definiteFailure();
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestTileUsingForallOp
+//===----------------------------------------------------------------------===//
+
+/// Apply a tiling transformation to all payload ops and store both the
+/// tiled operation as well as the created tile loops.
+template <typename Range>
+static LogicalResult
+applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
----------------
qedawkins wrote:
nit: `applyTileUsingForallToAll`?
https://github.com/llvm/llvm-project/pull/77204
More information about the Mlir-commits
mailing list