[Mlir-commits] [mlir] c9c8527 - [mlir][Linalg] NFC - Pass TrackingListeners to all transforms

Nicolas Vasilache llvmlistbot at llvm.org
Tue Apr 18 05:01:34 PDT 2023


Author: Nicolas Vasilache
Date: 2023-04-18T05:00:53-07:00
New Revision: c9c8527a51e9907947f925489d2e909c51f361a3

URL: https://github.com/llvm/llvm-project/commit/c9c8527a51e9907947f925489d2e909c51f361a3
DIFF: https://github.com/llvm/llvm-project/commit/c9c8527a51e9907947f925489d2e909c51f361a3.diff

LOG: [mlir][Linalg] NFC - Pass TrackingListeners to all transforms

This commit properly passes tracking listeners to all Linalg transform operations.
It introduces no visible API or behavioral change, but it helps downstream clients make better use of the tracking facilities.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8c31e4184b78..33de27265afb 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -148,7 +148,8 @@ transform::BufferizeToAllocationOp::apply(transform::TransformResults &results,
                                           transform::TransformState &state) {
   Attribute memorySpace =
       getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   auto transformed = llvm::to_vector(
       llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) {
         return linalg::bufferizeToAllocation(rewriter, v, memorySpace);
@@ -207,7 +208,7 @@ transform::DecomposeOp::applyToOne(LinalgOp target,
 /// Apply a tiling transformation to all payload ops and store both the
 /// tiled operation as well as the created tile loops.
 static LogicalResult applyTilingToAll(
-    Operation *transformOp, ArrayRef<Operation *> payloadOps, unsigned numLoops,
+    RewriterBase &rewriter, Operation *transformOp, ArrayRef<Operation *> payloadOps, unsigned numLoops,
     transform::TransformResults &transformResults,
     function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
         applyFn) {
@@ -221,7 +222,6 @@ static LogicalResult applyTilingToAll(
     if (!tilingInterfaceOp)
       return transformOp->emitError("only TilingInterface ops are supported");
 
-    IRRewriter rewriter(target->getContext());
     rewriter.setInsertionPoint(target);
     FailureOr<scf::SCFTileAndFuseResult> tiledResults =
         applyFn(tilingInterfaceOp);
@@ -300,12 +300,13 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
   tilingOptions = tilingOptions.setTileSizes(tileSizes);
   scf::SCFTileAndFuseOptions tileAndFuseOptions;
   tileAndFuseOptions.tilingOptions = tilingOptions;
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   LogicalResult result = applyTilingToAll(
-      getOperation(), state.getPayloadOps(getTarget()),
+      rewriter, getOperation(), state.getPayloadOps(getTarget()),
       tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
       [&](TilingInterface tilingInterfaceOp)
           -> FailureOr<scf::SCFTileAndFuseResult> {
-        IRRewriter rewriter(getContext());
         return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
             rewriter, tilingInterfaceOp, tileAndFuseOptions);
       });
@@ -620,7 +621,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
     return failure();
   };
 
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   while (!remainingProducers.empty()) {
     auto nextProducer = getNextProducer();
     if (failed(nextProducer)) {
@@ -692,7 +694,8 @@ transform::GeneralizeOp::applyToOne(LinalgOp target,
     results.push_back(target);
     return DiagnosedSilenceableFailure::success();
   }
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   rewriter.setInsertionPoint(target);
   FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
   if (succeeded(generic)) {
@@ -716,7 +719,8 @@ transform::InterchangeOp::applyToOne(GenericOp target,
     results.push_back(target);
     return DiagnosedSilenceableFailure::success();
   }
-  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   FailureOr<GenericOp> res =
       interchangeGenericOp(rewriter, target,
                            SmallVector<unsigned>(interchangeVector.begin(),
@@ -866,7 +870,8 @@ static FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     tensor::PackOp target, transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
-  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   rewriter.setInsertionPoint(target);
   FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
   if (failed(res)) {
@@ -995,7 +1000,8 @@ static FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
-  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   rewriter.setInsertionPoint(target);
   FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
   if (failed(res)) {
@@ -1241,7 +1247,8 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
   DiagnosedSilenceableFailure status = unpackSingleIndexResultPDLOperations(
       state, *this, packedSizes, getMixedPackedSizes());
 
-  IRRewriter rewriter(linalgOp->getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   rewriter.setInsertionPoint(linalgOp);
   FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
   if (failed(maybeResult))
@@ -1425,7 +1432,8 @@ PackGreedilyOp::apply(transform::TransformResults &transformResults,
   ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
 
   SmallVector<Operation *> results;
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   for (Operation *op : targetOps) {
     auto linalgOp = dyn_cast<LinalgOp>(op);
     if (!linalgOp)
@@ -1598,7 +1606,8 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
   assert(packOp && linalgOp && "unexpected null op");
 
   // Step 3. Actually transpose the ops.
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   FailureOr<PackTransposeResult> res = packTranspose(
       rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
   // Preconditions have been checked, it is an error to fail here.
@@ -1671,7 +1680,8 @@ transform::PadOp::applyToOne(LinalgOp target,
     transposePaddings.push_back(
         extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
 
-  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   LinalgOp paddedOp;
   FailureOr<SmallVector<Value>> result = rewriteAsPaddedOp(
       rewriter, target, extractFromI64ArrayAttr(getPaddingDimensions()),
@@ -1744,7 +1754,8 @@ DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
   if (!padOp || !loopOp)
     return emitDefiniteFailure() << "requires exactly 2 non-null handles";
 
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   FailureOr<linalg::detail::PackingResult> result =
       linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
                                            getTranspose());
@@ -1789,7 +1800,7 @@ transform::HoistPadOp::applyToOne(tensor::PadOp target,
   tensor::PadOp hoistedPadOp;
   SmallVector<GenericOp> transposeOps;
   TrackingListener listener(state, *this);
-  IRRewriter rewriter(target->getContext(), &listener);
+  IRRewriter rewriter(getContext(), &listener);
   FailureOr<Value> result =
       hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
                             hoistedPadOp, transposeOps);
@@ -1872,7 +1883,8 @@ transform::PromoteOp::applyToOne(LinalgOp target,
   if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
     return emitDefaultDefiniteFailure(target);
 
-  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   rewriter.setInsertionPoint(target);
   FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
   if (failed(res))
@@ -1901,7 +1913,8 @@ transform::ReplaceOp::apply(TransformResults &transformResults,
   }
 
   // Clone and replace.
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   Operation *pattern = &getBodyRegion().front().front();
   SmallVector<Operation *> replacements;
   for (Operation *target : payload) {
@@ -1957,7 +1970,8 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
     AffineMap map = target.getShapesToLoopsMap();
     if (!map)
       return tileSizes;
-    IRRewriter rewriter(b);
+    TrackingListener listener(state, *this);
+    IRRewriter rewriter(getContext(), &listener);
     SmallVector<OpFoldResult> shapeSizes =
         makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
                                                  allShapeSizes);
@@ -1971,7 +1985,8 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
     return tileSizes;
   });
   SmallVector<int64_t> emptyTileSizes;
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   rewriter.setInsertionPoint(target);
   FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
       rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
@@ -1998,7 +2013,8 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
     Operation *target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   SmallVector<Operation *> res;
-  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   rewriter.setInsertionPoint(target);
   FailureOr<Operation *> maybeResult =
       TypeSwitch<Operation *, FailureOr<Operation *>>(target)
@@ -2020,7 +2036,8 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
                                            TransformState &state) {
   // Collect the dynamic split points if provided.
   ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   SmallVector<OpFoldResult> splitPoints;
   splitPoints.reserve(payload.size());
   if (getDynamicSplitPoint()) {
@@ -2227,7 +2244,8 @@ DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
                                          unsigned(getInsertSplitDimension()),
                                          bool(getInnerParallel())};
   };
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   rewriter.setInsertionPoint(target);
   FailureOr<SplitReductionResult> splitResult =
       (getUseScalingAlgorithm())
@@ -2267,7 +2285,8 @@ void transform::TileReductionUsingScfOp::build(
 DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
     LinalgOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   rewriter.setInsertionPoint(target);
   FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
       rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
@@ -2310,7 +2329,8 @@ void transform::TileReductionUsingForallOp::build(
 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
     LinalgOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   rewriter.setInsertionPoint(target);
   SmallVector<OpFoldResult> numThreads =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
@@ -2497,7 +2517,8 @@ transform::TileOp::apply(TransformResults &transformResults,
     }
 
     tilingOptions.setInterchange(getInterchange());
-    IRRewriter rewriter(op->getContext());
+    TrackingListener listener(state, *this);
+    IRRewriter rewriter(getContext(), &listener);
     FailureOr<scf::SCFTilingResult> maybeTilingResult =
         tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
     if (failed(maybeTilingResult))
@@ -2739,7 +2760,8 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
 DiagnosedSilenceableFailure
 transform::TileToForallOp::apply(transform::TransformResults &transformResults,
                                  transform::TransformState &state) {
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   auto transformOp = cast<TransformOpInterface>(getOperation());
   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
 
@@ -2821,10 +2843,10 @@ LogicalResult TileToForallOp::verify() {
 // TileToScfForOp
 //===----------------------------------------------------------------------===//
 
-void transform::TileToScfForOp::build(OpBuilder &builder, OperationState &result,
-                              Value target,
-                              ArrayRef<OpFoldResult> mixedTileSizes,
-                              ArrayRef<int64_t> interchange) {
+void transform::TileToScfForOp::build(OpBuilder &builder,
+                                      OperationState &result, Value target,
+                                      ArrayRef<OpFoldResult> mixedTileSizes,
+                                      ArrayRef<int64_t> interchange) {
   SmallVector<int64_t> staticTileSizes;
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
@@ -2834,9 +2856,9 @@ void transform::TileToScfForOp::build(OpBuilder &builder, OperationState &result
   auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
   int64_t numExpectedLoops =
       staticTileSizes.size() - llvm::count(staticTileSizes, 0);
-  SmallVector<Type> resultTypes;
-  resultTypes.reserve(numExpectedLoops);
-  build(builder, result, 
+  SmallVector<Type> resultTypes(numExpectedLoops,
+                                pdl::OperationType::get(builder.getContext()));
+  build(builder, result,
         /*tiled_linalg_op=*/target.getType(),
         /*loops=*/resultTypes,
         /*target=*/target,
@@ -2914,7 +2936,8 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
     }
 
     tilingOptions.setInterchange(getInterchange());
-    IRRewriter rewriter(tilingInterfaceOp.getContext());
+    TrackingListener listener(state, *this);
+    IRRewriter rewriter(getContext(), &listener);
     FailureOr<scf::SCFTilingResult> tilingResult =
         tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions);
     if (failed(tilingResult))
@@ -3079,7 +3102,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
 DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
     mlir::transform::TransformResults &transformResults,
     mlir::transform::TransformState &state) {
-  IRRewriter rewriter(getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
   if (targets.empty())
     return DiagnosedSilenceableFailure::success();
@@ -3184,7 +3208,8 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
     linalg::LinalgOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  IRRewriter rewriter(target->getContext());
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(getContext(), &listener);
   rewriter.setInsertionPoint(target);
   auto maybeTransformed =
       TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
@@ -3219,7 +3244,7 @@ transform::HoistRedundantTensorSubsetsOp::applyToOne(
     Operation *target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
   TrackingListener listener(state, *this);
-  IRRewriter rewriter(target->getContext(), &listener);
+  IRRewriter rewriter(getContext(), &listener);
   auto forOp = dyn_cast<scf::ForOp>(target);
   if (forOp) {
     linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);


        


More information about the Mlir-commits mailing list