[Mlir-commits] [mlir] d61e68f - [mlir][linalg][transform] linalg::tileToForallOpImpl processes only one target

Matthias Springer llvmlistbot at llvm.org
Fri May 5 04:35:19 PDT 2023


Author: Matthias Springer
Date: 2023-05-05T20:35:10+09:00
New Revision: d61e68ff0e0c8ad91cdbab8bb25cdd40930afb6f

URL: https://github.com/llvm/llvm-project/commit/d61e68ff0e0c8ad91cdbab8bb25cdd40930afb6f
DIFF: https://github.com/llvm/llvm-project/commit/d61e68ff0e0c8ad91cdbab8bb25cdd40930afb6f.diff

LOG: [mlir][linalg][transform] linalg::tileToForallOpImpl processes only one target

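`tileToForallOpImpl` now takes a single target (`Operation *target`) instead of a list of all targets, and returns its result through a `linalg::ForallTilingResult` out-parameter. Callers such as `TileToForallOp::apply` now loop over the targets themselves and collect the `tileOp`/`tiledOp` results.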

This is in preparation for D149847.

Differential Revision: https://reviews.llvm.org/D149929

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 775377e5e5bce..95133e296c757 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -23,6 +23,7 @@ class TilingInterface;
 class RewriterBase;
 
 namespace linalg {
+struct ForallTilingResult;
 class GenericOp;
 class LinalgOp;
 } // namespace linalg
@@ -48,12 +49,13 @@ class DialectRegistry;
 namespace transform {
 
 /// Implementation of tiling operations using `scf.forall`.
-DiagnosedSilenceableFailure tileToForallOpImpl(
-    RewriterBase &rewriter, transform::TransformState &state,
-    TransformOpInterface transformOp, ArrayRef<Operation *> targets,
-    ArrayRef<OpFoldResult> mixedNumThreads,
-    ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
-    SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
+DiagnosedSilenceableFailure
+tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
+                   TransformOpInterface transformOp, Operation *target,
+                   ArrayRef<OpFoldResult> mixedNumThreads,
+                   ArrayRef<OpFoldResult> mixedTileSizes,
+                   std::optional<ArrayAttr> mapping,
+                   linalg::ForallTilingResult &tilingResult);
 
 } // namespace transform
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9c7236e21842a..74eb3a2df0f96 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2506,15 +2506,11 @@ void transform::TileToForallOp::build(OpBuilder &builder,
 
 DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     RewriterBase &rewriter, transform::TransformState &state,
-    TransformOpInterface transformOp, ArrayRef<Operation *> targets,
+    TransformOpInterface transformOp, Operation *target,
     ArrayRef<OpFoldResult> mixedNumThreads,
     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
-    SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
-  if (targets.empty())
-    return DiagnosedSilenceableFailure::success();
-
+    linalg::ForallTilingResult &tilingResult) {
   // Transform all targets one by one.
-  for (Operation *target : targets) {
     auto tileableOp = dyn_cast<TilingInterface>(target);
     if (!tileableOp) {
       DiagnosedSilenceableFailure diag =
@@ -2524,23 +2520,21 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
       return diag;
     }
     rewriter.setInsertionPoint(tileableOp);
-    FailureOr<linalg::ForallTilingResult> tilingResult = failure();
+    FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
     if (!mixedNumThreads.empty()) {
-      tilingResult = linalg::tileToForallOp(rewriter, tileableOp,
-                                            mixedNumThreads, mapping);
+      maybeTilingResult = linalg::tileToForallOp(rewriter, tileableOp,
+                                                 mixedNumThreads, mapping);
     } else {
-      tilingResult = linalg::tileToForallOpUsingTileSizes(
+      maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
           rewriter, tileableOp, mixedTileSizes, mapping);
     }
 
-    if (failed(tilingResult))
+    if (failed(maybeTilingResult))
       return transformOp.emitDefaultSilenceableFailure(tileableOp);
-    rewriter.replaceOp(tileableOp, tilingResult->tileOp->getResults());
+    rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
 
-    tileOps.push_back(tilingResult->tileOp);
-    tiledOps.push_back(tilingResult->tiledOp);
-  }
-  return DiagnosedSilenceableFailure::success();
+    tilingResult = *maybeTilingResult;
+    return DiagnosedSilenceableFailure::success();
 }
 
 DiagnosedSilenceableFailure
@@ -2577,12 +2571,16 @@ transform::TileToForallOp::apply(transform::TransformResults &transformResults,
   if (!status.succeeded())
     return status;
 
-  DiagnosedSilenceableFailure diag =
-      tileToForallOpImpl(rewriter, state, transformOp, targets, mixedNumThreads,
-                         mixedTileSizes, getMapping(), tileOps, tiledOps);
-
-  if (!diag.succeeded())
-    return diag;
+  for (Operation *target : targets) {
+    linalg::ForallTilingResult tilingResult;
+    DiagnosedSilenceableFailure diag = tileToForallOpImpl(
+        rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
+        getMapping(), tilingResult);
+    if (!diag.succeeded())
+      return diag;
+    tileOps.push_back(tilingResult.tileOp);
+    tiledOps.push_back(tilingResult.tiledOp);
+  }
 
   transformResults.set(getForallOp().cast<OpResult>(), tileOps);
   transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);


        


More information about the Mlir-commits mailing list