[Mlir-commits] [mlir] 12831be - [mlir][Linalg] NFC - Cleanup internal transform APIs and produce better messages on failure to apply.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Sep 19 04:21:15 PDT 2022


Author: Nicolas Vasilache
Date: 2022-09-19T04:16:15-07:00
New Revision: 12831be96cdc152e6b07e74e3841da3d9a7a93ab

URL: https://github.com/llvm/llvm-project/commit/12831be96cdc152e6b07e74e3841da3d9a7a93ab
DIFF: https://github.com/llvm/llvm-project/commit/12831be96cdc152e6b07e74e3841da3d9a7a93ab.diff

LOG: [mlir][Linalg] NFC - Cleanup internal transform APIs and produce better messages on failure to apply.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc3f10c717c3f..bf4396b446f9c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -228,12 +228,16 @@ LogicalResult transform::FuseOp::verify() {
 /// Find the first "extract" user of `producerOp` and tile it right before its
 /// use. The tiled op is fused under the `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
-static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
-                                             Operation *containingOp,
-                                             RewriterBase &rewriter) {
+static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
+                                             Diagnostic &diag,
+                                             Operation *producerOp,
+                                             Operation *containingOp) {
   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
-  if (!tileableProducer)
+  if (!tileableProducer) {
+    diag.attachNote(producerOp->getLoc())
+        << "producer is not a TileableInterface: " << *producerOp;
     return nullptr;
+  }
 
   // Search the producer slices accessed within the containing operation.
   // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
@@ -244,8 +248,11 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
   });
 
   // Find a fusion opportunity.
-  if (it == tileableProducer->getUsers().end())
+  if (it == tileableProducer->getUsers().end()) {
+    diag.attachNote(tileableProducer->getLoc())
+        << "could not find fusion opportunity for: " << *tileableProducer;
     return nullptr;
+  }
   auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
 
   // Try to fuse the producer in-place.
@@ -256,8 +263,11 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
   FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
       rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
       sliceOpToTile.getMixedSizes());
-  if (failed(tiledProducer))
+  if (failed(tiledProducer)) {
+    diag.attachNote(tileableProducer->getLoc())
+        << "failed to tile producer op: " << *tileableProducer;
     return nullptr;
+  }
 
   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
@@ -272,11 +282,25 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
 /// `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
 static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
-    Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) {
+    RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
+    Operation *containingOp) {
 
   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
-  if (!tileableProducer)
+  if (!tileableProducer) {
+    diag.attachNote(producerOp->getLoc())
+        << "producer is not a TileableInterface: " << *producerOp;
+    return nullptr;
+  }
+
+  // Ensure `tileableProducer` has exactly one destination operand that we can
+  // replace the ForeachThreadOp bbArg with.
+  auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
+  if (destinationOperands.size() != 1) {
+    diag.attachNote(tileableProducer->getLoc())
+        << "tileableProducer must have exactly one destination operand: "
+        << *tileableProducer;
     return nullptr;
+  }
 
   // Search the first use by a "scf::ForeachThreadOp" user.
   scf::ForeachThreadOp foreachThreadOp;
@@ -286,8 +310,11 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
         return foreachThreadOp;
       });
   // If it's not from the containing op, return.
-  if (!foreachThreadOp || foreachThreadOp != containingOp)
+  if (!foreachThreadOp || foreachThreadOp != containingOp) {
+    diag.attachNote(tileableProducer->getLoc())
+        << "could not find a use by the containing op: " << *tileableProducer;
     return nullptr;
+  }
 
   // Search the producer slices accessed within the containing
   // operation.
@@ -305,16 +332,13 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   });
 
   // Find a fusion opportunity.
-  if (itBBArgUsers == bbArg.getUsers().end())
+  if (itBBArgUsers == bbArg.getUsers().end()) {
+    diag.attachNote(containingOp->getLoc())
+        << "could not find fusion opportunity for bbArg: " << bbArg;
     return nullptr;
+  }
   auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
 
-  // Ensure `tileableProducer` has exactly one destination operand that we can
-  // replace the ForeachThreadOp bbArg with.
-  auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
-  if (destinationOperands.size() != 1)
-    return nullptr;
-
   // Try to fuse the producer in-place.
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(sliceOpToTile);
@@ -333,8 +357,11 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
       tileableProducerClone.generateResultTileValue(
           rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
           sliceOpToTile.getMixedSizes());
-  if (failed(tiledProducer))
+  if (failed(tiledProducer)) {
+    diag.attachNote(tileableProducer->getLoc())
+        << "failed to tile producer op: " << *tileableProducer;
     return nullptr;
+  }
 
   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
@@ -349,9 +376,9 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   return fusedOp;
 }
 
-static Operation *cloneAndFuseFirstUse(Operation *producerOp,
-                                       Operation *containingOp,
-                                       RewriterBase &rewriter) {
+static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
+                                       Operation *producerOp,
+                                       Operation *containingOp) {
   // Gather all uses inside the containing op.
   SmallVector<OpOperand *> uses;
   for (OpResult result : producerOp->getOpResults()) {
@@ -362,14 +389,19 @@ static Operation *cloneAndFuseFirstUse(Operation *producerOp,
       }
       // Cannot clone and fuse if the use is by the containing op itself: fail
       // immediately.
-      if (containingOp == use.getOwner())
+      if (containingOp == use.getOwner()) {
+        diag.attachNote(producerOp->getLoc())
+            << "producer op use by containing op cannot be fused by cloning";
         return nullptr;
+      }
     }
   }
 
   // Check for a non-empty list of fusion opportunities.
-  if (uses.empty())
+  if (uses.empty()) {
+    diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
     return nullptr;
+  }
 
   // Clone and fuse inside the containing op.
   Operation *fusedOp = nullptr;
@@ -441,18 +473,23 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
     auto nextProducer = getNextProducer();
     if (failed(nextProducer)) {
       Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
-      diag << "could not fuse ops into container";
+      diag << "could not find next producer to fuse into container";
       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
     }
 
     Operation *producerOp = *nextProducer;
+
+    // Default diagnostic, to be complemented with more failure information.
+    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
+    diag << "could not fuse " << *producerOp << " into " << *containingOp;
+
     // TODO: If there are multiple uses of the producer in the containing op,
     // we currently tile/clone the op multiple times (once per use). In some
     // cases, we can tile/clone once and reuse the value for each use.
     // Futhermore, producers should then be traversed according to a
     // topological sorting.
     Operation *tiled =
-        tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter);
+        tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
     if (tiled) {
       fusedOps.push_back(tiled);
       continue;
@@ -460,21 +497,19 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
 
     Operation *tiledContainingOpOperand =
         tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
-            producerOp, containingOp, rewriter);
+            rewriter, diag, producerOp, containingOp);
     if (tiledContainingOpOperand) {
       fusedOps.push_back(tiledContainingOpOperand);
       continue;
     }
 
     Operation *cloned =
-        cloneAndFuseFirstUse(producerOp, containingOp, rewriter);
+        cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
     if (cloned) {
       fusedOps.push_back(cloned);
       continue;
     }
 
-    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
-    diag << "could not fuse " << *producerOp << "into " << *containingOp;
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
 


        


More information about the Mlir-commits mailing list