[Mlir-commits] [mlir] e158b56 - [mlir][linalg] Make fusion on tensor rewriter friendly (NFC).

Tobias Gysi llvmlistbot at llvm.org
Mon Sep 27 04:32:31 PDT 2021


Author: Tobias Gysi
Date: 2021-09-27T11:28:25Z
New Revision: e158b5634aa67ea3039a62c3d8bda79b77b3b21c

URL: https://github.com/llvm/llvm-project/commit/e158b5634aa67ea3039a62c3d8bda79b77b3b21c
DIFF: https://github.com/llvm/llvm-project/commit/e158b5634aa67ea3039a62c3d8bda79b77b3b21c.diff

LOG: [mlir][linalg] Make fusion on tensor rewriter friendly (NFC).

Let the calling pass or pattern replace the uses of the original root operation. Internally, tileAndFuse still replaces uses and updates operands, but only for newly created operations.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110169

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 8d01b333e311..24c5784f2a9d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -199,9 +199,11 @@ class TileLoopNest {
 
   /// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the
   /// fused producer or fails if fusion is not possible.
-  // TODO: add replace uses callback to support passes and patterns.
   FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *rootOpOperand);
 
+  /// Returns the replacement results for the original untiled root operation.
+  ValueRange getRootOpReplacementResults();
+
   /// Returns the tiled root operation.
   LinalgOp getRootOp() { return rootOp; }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 31e53f7cf93d..448e677e7ac1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -245,10 +245,15 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
                       .setLoopType(LinalgTilingLoopType::Loops);
   Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
 
-  // Replace all uses of the root operation.
+  // Exit if tiling the root operation fails.
   if (!tiledRootOp.hasValue())
     return failure();
-  rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
+
+  // Replace all uses of the root operation if it has been tiled before. All
+  // uses of the original untiled root operation are updated by the calling pass
+  // or pattern.
+  if (!isEmpty())
+    rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
 
   // Update the root operation and append the loops and tile loop dimensions.
   rootOp = tiledRootOp->op;
@@ -323,6 +328,11 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
   return clonedOp;
 }
 
+ValueRange TileLoopNest::getRootOpReplacementResults() {
+  assert(!isEmpty() && "expect tile loop nest to be non-empty");
+  return loopOps.front()->getOpResults();
+}
+
 //===----------------------------------------------------------------------===//
 // Tile and fuse entry-points.
 //===----------------------------------------------------------------------===//
@@ -433,9 +443,13 @@ struct LinalgTileAndFuseTensorOps
           "expect the tile interchange permutes the root loops");
 
     // Tile `rootOp` and fuse its producers.
-    if (failed(tileConsumerAndFuseProducers(b, rootOp, rootTileSizes,
-                                            rootInterchange)))
+    FailureOr<TileLoopNest> tileLoopNest =
+        tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, rootInterchange);
+    if (failed(tileLoopNest))
       return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");
+
+    // Replace the uses of the original root operation by the loop nest results.
+    rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
   }
 };
 } // namespace


        


More information about the Mlir-commits mailing list