[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 23 05:01:44 PDT 2024


================
@@ -2581,6 +2643,157 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
+
+  SmallVector<Operation *> targetOps =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one target (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+
+  auto target = dyn_cast<LinalgOp>(*targetOps.begin());
+
+  OpBuilder builder(target.getContext());
+
+  if (!target)
+    return emitDefiniteFailure() << "expected Linalg Op";
+
+  if (isa<TransformParamTypeInterface>(getSplitPoints().getType())) {
+    if (target.hasDynamicShape()) {
+      auto diag = emitSilenceableError()
+                  << "cannot compute parametric tile sizes for dynamically "
+                     "shaped payload op";
+      diag.attachNote(target->getLoc()) << "payload op";
+      return diag;
+    }
+
+    FailureOr<StaticContinuousTileSizeSpecification> spec =
+        computeStaticContinuousTileSizes(target, getDimension(),
+                                         getTargetSize());
+    if (failed(spec)) {
+      return emitSilenceableError()
+             << "failed to compute multi-size tiling sizes";
+    }
+
+    SmallVector<int64_t> splitPoints;
+
+    auto tileSizeTripCountPairs =
+        llvm::zip_equal(spec->tileSizes, spec->tripCounts);
+
+    for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs))
----------------
muneebkhan85 wrote:

Fixed.

https://github.com/llvm/llvm-project/pull/82792


More information about the Mlir-commits mailing list