[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Fri May 17 04:17:21 PDT 2024
================
@@ -160,6 +160,59 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
: DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// TestFuseConsumerOp
+//===----------------------------------------------------------------------===//
+
+/// Apply the consumer-fusion transformation to every op in `payloadOps`, recording both the original consumer operation and the fused consumer operation for each.
+/// Returns failure at the first payload op for which fusion fails (this function then sets nothing on `transformResults`); on success, result 0 of `transformOp` receives the original consumers and result 1 the fused consumers.
+template <typename Range>
+static LogicalResult
+applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
+ Range &&payloadOps, TransformResults &transformResults) {
+ SmallVector<Operation *> originalConsumerOps; // One entry per successfully processed payload op.
+ SmallVector<Operation *> fusedConsumerOps; // Kept parallel to `originalConsumerOps`.
+
+ for (Operation *target : payloadOps) {
+ rewriter.setInsertionPoint(target); // Anchor the rewriter at the payload op being processed.
+
+ FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
+ scf::tileAndFuseConsumerOfSlice(rewriter, target);
+
+ if (failed(fuseConsumerResults))
+ return failure(); // Bail out immediately; ops collected so far are not reported.
+
+ // Report back the relevant handles to the transform op: the ops that own the original and the tiled-and-fused consumer operands.
+ originalConsumerOps.push_back(
+ fuseConsumerResults->origConsumerOperand->getOwner());
+ fusedConsumerOps.push_back(
+ fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
+ }
+
+ transformResults.set(transformOp->getOpResult(0), originalConsumerOps); // Result #0: original consumers.
+ transformResults.set(transformOp->getOpResult(1), fusedConsumerOps); // Result #1: fused consumers.
+ return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
+ TransformResults &transformResults,
+ TransformState &state) {
+ LogicalResult result = // Run consumer fusion over the payload ops associated with the target handle.
+ applyFuseConsumer(rewriter, getOperation(),
+ state.getPayloadOps(getTarget()), transformResults);
+ return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() // Any fusion failure is surfaced as a definite failure.
+ : DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseConsumerOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTarget(), effects);
----------------
ftynse wrote:
This may be acceptable for tests, but in a real transform we need to take a handle to the consumer being fused so we can invalidate it properly.
https://github.com/llvm/llvm-project/pull/88712
More information about the Mlir-commits
mailing list