[Mlir-commits] [mlir] acc2a12 - [mlir][Linalg] Expose the implementation of the tiling to scf.foreach_thread.

Mahesh Ravishankar llvmlistbot at llvm.org
Thu Sep 22 15:19:49 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-09-22T22:19:19Z
New Revision: acc2a12c3318aad2ebdc161a0a7f4c2fec52e18d

URL: https://github.com/llvm/llvm-project/commit/acc2a12c3318aad2ebdc161a0a7f4c2fec52e18d
DIFF: https://github.com/llvm/llvm-project/commit/acc2a12c3318aad2ebdc161a0a7f4c2fec52e18d.diff

LOG: [mlir][Linalg] Expose the implementation of the tiling to scf.foreach_thread.

This allows downstream users to reuse the tiling implementation
itself while performing the other transformations that must
accompany it.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D134335

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 45b0600adcbd..f7952db7e2a2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -15,6 +15,7 @@
 
 namespace mlir {
 class TilingInterface;
+class RewriterBase;
 namespace linalg {
 class GenericOp;
 class LinalgOp;
@@ -33,6 +34,17 @@ class LinalgOp;
 namespace mlir {
 class DialectRegistry;
 
+namespace transform {
+
+/// Implementation of tiling operations using `scf.foreach_thread`.
+DiagnosedSilenceableFailure tileToForeachThreadOpImpl(
+    RewriterBase &rewriter, transform::TransformState &state,
+    TransformOpInterface transformOp, ArrayRef<Operation *> targets,
+    ArrayRef<OpFoldResult> mixedNumThreads,
+    ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping,
+    SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
+} // namespace transform
+
 namespace linalg {
 void registerTransformDialectExtension(DialectRegistry &registry);
 } // namespace linalg

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 108570f53843..ca3c932c77cc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1339,22 +1339,15 @@ transform::MapNestedForeachThreadToGpuThreads::applyToOne(
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
-    transform::TransformResults &transformResults,
-    transform::TransformState &state) {
-  IRRewriter rewriter(getContext());
-  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
-
-  // If there the target payload ops are empty, there is nothing to do.
-  if (targets.empty()) {
-    transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
-    transformResults.set(getTiledOp().cast<OpResult>(), {});
+DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
+    RewriterBase &rewriter, transform::TransformState &state,
+    TransformOpInterface transformOp, ArrayRef<Operation *> targets,
+    ArrayRef<OpFoldResult> mixedNumThreads,
+    ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping,
+    SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
+
+  if (targets.empty())
     return DiagnosedSilenceableFailure(success());
-  }
-
-  // Result payload ops.
-  SmallVector<Operation *> tileOps;
-  SmallVector<Operation *> tiledOps;
 
   // Given a list of OpFoldResults that are either index attrs or op handles,
   // return a list of OpFoldResults where all op handles are replaced with the
@@ -1372,7 +1365,7 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
               state.getPayloadOps(ofr.get<Value>());
           if (dynamicNumThreads.size() != 1) {
             DiagnosedSilenceableFailure diag =
-                emitSilenceableError()
+                transformOp.emitSilenceableError()
                 << "handle must be mapped to exactly 1 payload op";
             diag.attachNote(ofr.get<Value>().getLoc())
                 << "mapped to " << dynamicNumThreads.size() << " ops";
@@ -1382,7 +1375,7 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
           if (op->getNumResults() != 1 ||
               !op->getResult(0).getType().isIndex()) {
             DiagnosedSilenceableFailure diag =
-                emitSilenceableError()
+                transformOp.emitSilenceableError()
                 << "payload op must have exactly 1 index result";
             diag.attachNote(op->getLoc())
                 << "has " << op->getNumResults() << " results";
@@ -1398,14 +1391,14 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
   // Convert to OpFoldResults[index attributes or payload op].
   SmallVector<OpFoldResult> numThreads;
   DiagnosedSilenceableFailure status =
-      getOpResultsOrIndexAttrs(numThreads, getMixedNumThreads());
+      getOpResultsOrIndexAttrs(numThreads, mixedNumThreads);
   if (!status.succeeded())
     return status;
 
   // getMixedTileSizes are OpFoldResults[index attributes or PDL operation].
   // Convert to OpFoldResults[index attributes or payload op].
   SmallVector<OpFoldResult> tileSizes;
-  status = getOpResultsOrIndexAttrs(tileSizes, getMixedTileSizes());
+  status = getOpResultsOrIndexAttrs(tileSizes, mixedTileSizes);
   if (!status.succeeded())
     return status;
 
@@ -1414,19 +1407,20 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
     auto tilableOp = dyn_cast<TilingInterface>(target);
     if (!tilableOp) {
       DiagnosedSilenceableFailure diag =
-          emitSilenceableError() << "only TilingInterface ops are supported";
+          transformOp.emitSilenceableError()
+          << "only TilingInterface ops are supported";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
     rewriter.setInsertionPoint(tilableOp);
-    auto maybeThreadDimMappingAttr = getThreadDimMapping();
+    auto maybeThreadDimMappingAttr = threadDimMapping;
     auto dimMapping = llvm::to_vector(
         maybeThreadDimMappingAttr
             ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
             : ArrayRef<int64_t>{});
 
-    FailureOr<ForeachThreadTilingResult> tilingResult = failure();
-    if (!getMixedNumThreads().empty()) {
+    FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
+    if (!mixedNumThreads.empty()) {
       tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
                                                    numThreads, dimMapping);
     } else {
@@ -1435,12 +1429,32 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
     }
 
     if (failed(tilingResult))
-      return emitDefaultSilenceableFailure(tilableOp);
+      return transformOp.emitDefaultSilenceableFailure(tilableOp);
     rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
 
     tileOps.push_back(tilingResult->tileOp);
     tiledOps.push_back(tilingResult->tiledOp);
   }
+  return DiagnosedSilenceableFailure(success());
+}
+
+DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
+    transform::TransformResults &transformResults,
+    transform::TransformState &state) {
+  IRRewriter rewriter(getContext());
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+
+  // Result payload ops.
+  SmallVector<Operation *> tileOps;
+  SmallVector<Operation *> tiledOps;
+
+  DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl(
+      rewriter, state, cast<TransformOpInterface>(getOperation()), targets,
+      getMixedNumThreads(), getMixedTileSizes(), getThreadDimMapping(), tileOps,
+      tiledOps);
+
+  if (!diag.succeeded())
+    return diag;
 
   transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
   transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);


        


More information about the Mlir-commits mailing list