[Mlir-commits] [mlir] 8542794 - Relax requirements for TileOp.

Johannes Reifferscheid llvmlistbot at llvm.org
Tue Jan 24 06:17:01 PST 2023


Author: Johannes Reifferscheid
Date: 2023-01-24T15:16:55+01:00
New Revision: 85427941e7e1790aca1aa6c7c3876f9b7571e30d

URL: https://github.com/llvm/llvm-project/commit/85427941e7e1790aca1aa6c7c3876f9b7571e30d
DIFF: https://github.com/llvm/llvm-project/commit/85427941e7e1790aca1aa6c7c3876f9b7571e30d.diff

LOG: Relax requirements for TileOp.

The op doesn't need to be a LinalgOp; implementing TilingInterface and
DestinationStyleOpInterface is sufficient.

Reviewed By: nicolasvasilache, ftynse

Differential Revision: https://reviews.llvm.org/D142460

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8d06ef6cca0c5..4e0be5aa8dbd4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1902,20 +1902,21 @@ transform::TileOp::apply(TransformResults &transformResults,
   SmallVector<Operation *> tiled;
   SmallVector<SmallVector<Operation *, 4>, 4> loops;
   loops.resize(getLoops().size());
-  for (auto &en : llvm::enumerate(targets)) {
-    auto linalgOp = dyn_cast<LinalgOp>(en.value());
-    if (!linalgOp) {
-      DiagnosedSilenceableFailure diag = emitSilenceableError()
-                                         << "only linalg ops are supported";
-      diag.attachNote(en.value()->getLoc()) << "target op";
+  for (auto &[i, op] : llvm::enumerate(targets)) {
+    auto tilingInterface = dyn_cast<TilingInterface>(op);
+    auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(op);
+    if (!tilingInterface || !dpsInterface) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "only ops implementing TilingInterface and "
+                                    "DestinationStyleOpInterface are supported";
+      diag.attachNote(op->getLoc()) << "target op";
       return diag;
     }
 
     scf::SCFTilingOptions tilingOptions;
-    unsigned index = en.index();
     if (!tileSizes.empty()) {
-      tilingOptions.setTileSizeComputationFunction([&, index](OpBuilder &b,
-                                                              Operation *) {
+      tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
+                                                                  Operation *) {
         SmallVector<Value, 4> sizes;
         sizes.reserve(tileSizes.size());
         unsigned dynamicIdx = 0;
@@ -1942,18 +1943,16 @@ transform::TileOp::apply(TransformResults &transformResults,
     }
 
     tilingOptions.setInterchange(getInterchange());
-    TrivialPatternRewriter rewriter(linalgOp.getContext());
-    FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
-        rewriter, cast<TilingInterface>(linalgOp.getOperation()),
-        tilingOptions);
+    TrivialPatternRewriter rewriter(op->getContext());
+    FailureOr<scf::SCFTilingResult> maybeTilingResult =
+        tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
     if (failed(maybeTilingResult))
       return DiagnosedSilenceableFailure::definiteFailure();
 
-    if (linalgOp.hasBufferSemantics())
-      rewriter.eraseOp(linalgOp);
+    if (dpsInterface.hasBufferSemantics())
+      rewriter.eraseOp(op);
     else
-      rewriter.replaceOp(linalgOp,
-                         maybeTilingResult->loops.front()->getResults());
+      rewriter.replaceOp(op, maybeTilingResult->loops.front()->getResults());
 
     tiled.append(maybeTilingResult->tiledOps);
     for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))


        


More information about the Mlir-commits mailing list