[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 23 05:05:40 PDT 2024


================
@@ -2581,6 +2643,157 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
+
+  SmallVector<Operation *> targetOps =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one target (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+
+  auto target = dyn_cast<LinalgOp>(*targetOps.begin());
+
+  OpBuilder builder(target.getContext());
+
+  if (!target)
+    return emitDefiniteFailure() << "expected Linalg Op";
+
+  if (isa<TransformParamTypeInterface>(getSplitPoints().getType())) {
+    if (target.hasDynamicShape()) {
+      auto diag = emitSilenceableError()
+                  << "cannot compute parametric tile sizes for dynamically "
+                     "shaped payload op";
+      diag.attachNote(target->getLoc()) << "payload op";
+      return diag;
+    }
+
+    FailureOr<StaticContinuousTileSizeSpecification> spec =
+        computeStaticContinuousTileSizes(target, getDimension(),
+                                         getTargetSize());
+    if (failed(spec)) {
+      return emitSilenceableError()
+             << "failed to compute multi-size tiling sizes";
+    }
+
+    SmallVector<int64_t> splitPoints;
+
+    auto tileSizeTripCountPairs =
+        llvm::zip_equal(spec->tileSizes, spec->tripCounts);
+
+    for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs))
+      splitPoints.push_back(std::get<0>(pair) * std::get<1>(pair));
+
+    auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+      return llvm::to_vector(
+          llvm::map_range(values, [&](int64_t value) -> Attribute {
+            return builder.getI64IntegerAttr(value);
+          }));
+    };
+    transformResults.setParams(cast<OpResult>(getTileSizes()),
+                               makeI64AttrsFromI64(spec->tileSizes));
+    transformResults.setParams(cast<OpResult>(getSplitPoints()),
+                               makeI64AttrsFromI64(splitPoints));
+
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  builder.setInsertionPoint(target);
+
+  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+  unsigned dimension = getDimension();
+
+  FailureOr<ContinuousTileSizeSpecification> spec =
+      computeContinuousTileSizes(builder, target, dimension, targetSize, true);
+  if (failed(spec)) {
+    return emitSilenceableError() << "could not generate tile size computation";
+  }
+
+  auto tileSizeTripCountPairs =
+      llvm::zip_equal(spec->tileSizes, spec->tripCounts);
+
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(builder, target->getLoc(), expr,
+                                           ofrs);
+  };
+
+  SmallVector<Value> splitPoints;
+  Value splitPoint;
+  for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs)) {
+    splitPoint = apply(s0 * s1, {std::get<0>(pair), std::get<1>(pair)});
----------------
muneebkhan85 wrote:

`Affine Apply` is needed to generate the IR that computes the chunk sizes dynamically. This is needed to support the case when the bounds are not known. I'm not sure if there's another way to support that.

https://github.com/llvm/llvm-project/pull/82792


More information about the Mlir-commits mailing list