[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 23 05:06:37 PDT 2024


================
@@ -2581,6 +2643,157 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
+
+  SmallVector<Operation *> targetOps =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one target (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+
+  auto target = dyn_cast<LinalgOp>(*targetOps.begin());
+
+  OpBuilder builder(target.getContext());
+
+  if (!target)
+    return emitDefiniteFailure() << "expected Linalg Op";
+
+  if (isa<TransformParamTypeInterface>(getSplitPoints().getType())) {
+    if (target.hasDynamicShape()) {
+      auto diag = emitSilenceableError()
+                  << "cannot compute parametric tile sizes for dynamically "
+                     "shaped payload op";
+      diag.attachNote(target->getLoc()) << "payload op";
+      return diag;
+    }
+
+    FailureOr<StaticContinuousTileSizeSpecification> spec =
+        computeStaticContinuousTileSizes(target, getDimension(),
+                                         getTargetSize());
+    if (failed(spec)) {
+      return emitSilenceableError()
+             << "failed to compute multi-size tiling sizes";
+    }
+
+    SmallVector<int64_t> splitPoints;
+
+    auto tileSizeTripCountPairs =
+        llvm::zip_equal(spec->tileSizes, spec->tripCounts);
+
+    for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs))
+      splitPoints.push_back(std::get<0>(pair) * std::get<1>(pair));
+
+    auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+      return llvm::to_vector(
+          llvm::map_range(values, [&](int64_t value) -> Attribute {
+            return builder.getI64IntegerAttr(value);
+          }));
+    };
+    transformResults.setParams(cast<OpResult>(getTileSizes()),
+                               makeI64AttrsFromI64(spec->tileSizes));
+    transformResults.setParams(cast<OpResult>(getSplitPoints()),
+                               makeI64AttrsFromI64(splitPoints));
+
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  builder.setInsertionPoint(target);
+
+  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+  unsigned dimension = getDimension();
+
+  FailureOr<ContinuousTileSizeSpecification> spec =
+      computeContinuousTileSizes(builder, target, dimension, targetSize, true);
+  if (failed(spec)) {
+    return emitSilenceableError() << "could not generate tile size computation";
+  }
+
+  auto tileSizeTripCountPairs =
+      llvm::zip_equal(spec->tileSizes, spec->tripCounts);
+
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(builder, target->getLoc(), expr,
+                                           ofrs);
+  };
+
+  SmallVector<Value> splitPoints;
+  Value splitPoint;
+  for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs)) {
+    splitPoint = apply(s0 * s1, {std::get<0>(pair), std::get<1>(pair)});
+    splitPoints.push_back(splitPoint);
+  }
+
+  auto makeOpFromValue = [&](ArrayRef<Value> values) {
+    return llvm::to_vector(
+        llvm::map_range(values, [&](Value value) -> Operation * {
+          return value.getDefiningOp();
+        }));
+  };
+
+  transformResults.set(cast<OpResult>(getTileSizes()),
+                       makeOpFromValue(spec->tileSizes));
+  transformResults.set(cast<OpResult>(getSplitPoints()),
+                       makeOpFromValue(splitPoints));
----------------
muneebkhan85 wrote:

This is a good suggestion, but I'm skipping it for consistency with the surrounding code.

https://github.com/llvm/llvm-project/pull/82792


More information about the Mlir-commits mailing list