[llvm-branch-commits] [mlir] [mlir] NFC - refactor id builder and avoid leaking impl details (PR #146922)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jul 3 09:35:50 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir
Author: Nicolas Vasilache (nicolasvasilache)
Changes:
---
Full diff: https://github.com/llvm/llvm-project/pull/146922.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h (+14-17)
- (modified) mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp (+6-27)
- (modified) mlir/lib/Dialect/GPU/TransformOps/Utils.cpp (+107-69)
``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
index 111c67638efc8..de512ded59fec 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
@@ -28,27 +28,24 @@ namespace transform {
namespace gpu {
/// Helper type for functions that generate ids for the mapping of a scf.forall.
-/// Operates on both 1) an "original" basis that represents the individual
-/// thread and block ids and 2) a "scaled" basis that represents grouped ids
-/// (e.g. block clusters, warpgroups and warps).
-/// The mapping of ids is done in the "scaled" basis (i.e. when mapping to warps
-/// a division by 32 occurs).
-/// The predication is in the "original" basis using the "active" quantities
-/// (`activeMappingSizes`, `availableMappingSizes` and `activeIdOps`).
struct IdBuilderResult {
- // Ops used to replace the forall induction variables.
+ /// Error message; if non-empty, building the ids failed.
+ std::string errorMsg;
+ /// Values used to replace the forall induction variables.
SmallVector<Value> mappingIdOps;
- // Available mapping sizes used to predicate the forall body when they are
- // larger than the predicate mapping sizes.
- SmallVector<int64_t> availableMappingSizes;
- // Actual mapping sizes used to predicate the forall body when they are
- // smaller than the available mapping sizes.
- SmallVector<int64_t> activeMappingSizes;
- // Ops used to predicate the forall body when activeMappingSizes is smaller
- // than the available mapping sizes.
- SmallVector<Value> activeIdOps;
+ /// Values used to predicate the forall body when activeMappingSizes is
+ /// smaller than the available mapping sizes.
+ SmallVector<Value> predicateOps;
};
+inline raw_ostream &operator<<(raw_ostream &os, const IdBuilderResult &res) {
+ llvm::interleaveComma(res.mappingIdOps, os << "----mappingIdOps: ");
+ os << "\n";
+ llvm::interleaveComma(res.predicateOps, os << "----predicateOps: ");
+ os << "\n";
+ return os;
+}
+
/// Common gpu id builder type, allows the configuration of lowering for various
/// mapping schemes. Takes:
/// - A rewriter with insertion point set before the forall op to rewrite.
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 20d1c94409238..63f87d9b5877e 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -491,6 +491,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
IdBuilderResult builderResult =
gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
+ if (!builderResult.errorMsg.empty())
+ return definiteFailureHelper(transformOp, forallOp, builderResult.errorMsg);
+
+ LLVM_DEBUG(DBGS() << builderResult);
// Step 4. Map the induction variables to the mappingIdOps, this may involve
// a permutation.
@@ -501,7 +505,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
- LDBG("----map: " << iv << " to" << peIdOp);
+ LDBG("----map: " << iv << " to " << peIdOp);
bvm.map(iv, peIdOp);
}
@@ -510,32 +514,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
// originalBasis and no predication occurs.
Value predicate;
if (originalBasisWasProvided) {
- SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
- SmallVector<int64_t> availableMappingSizes =
- builderResult.availableMappingSizes;
- SmallVector<Value> activeIdOps = builderResult.activeIdOps;
- LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
- LDBG("----availableMappingSizes: "
- << llvm::interleaved(availableMappingSizes));
- LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
- for (auto [activeId, activeMappingSize, availableMappingSize] :
- llvm::zip_equal(activeIdOps, activeMappingSizes,
- availableMappingSizes)) {
- if (activeMappingSize > availableMappingSize) {
- return definiteFailureHelper(
- transformOp, forallOp,
- "Trying to map to fewer GPU threads than loop iterations but "
- "overprovisioning is not yet supported. "
- "Try additional tiling of the before mapping or map to more "
- "threads.");
- }
- if (activeMappingSize == availableMappingSize)
- continue;
- Value idx =
- rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
- Value tmpPredicate = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, activeId, idx);
- LDBG("----predicate: " << tmpPredicate);
+ for (Value tmpPredicate : builderResult.predicateOps) {
predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
tmpPredicate)
: tmpPredicate;
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index c693a2fa01e89..795d643c05912 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -47,12 +47,57 @@ using namespace mlir::transform::gpu;
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
+/// Build predicates to filter execution by only the activeIds. Along each
+/// dimension, 3 cases appear:
+/// 1. activeMappingSize > availableMappingSize: this is an unsupported case
+/// as this requires additional looping. An error message is produced to
+/// advise the user to tile more or to use more threads.
+/// 2. activeMappingSize == availableMappingSize: no predication is needed.
+/// 3. activeMappingSize < availableMappingSize: only a subset of threads
+/// should be active and we produce the boolean `id < activeMappingSize`
+/// for further use in building predicated execution.
+static FailureOr<SmallVector<Value>>
+buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds,
+ ArrayRef<int64_t> activeMappingSizes,
+ ArrayRef<int64_t> availableMappingSizes,
+ std::string &errorMsg) {
+ // clang-format off
+ LLVM_DEBUG(
+ llvm::interleaveComma(
+ activeMappingSizes, DBGS() << "----activeMappingSizes: ");
+ DBGS() << "\n";
+ llvm::interleaveComma(
+ availableMappingSizes, DBGS() << "----availableMappingSizes: ");
+ DBGS() << "\n";);
+ // clang-format on
+
+ SmallVector<Value> predicateOps;
+ for (auto [activeId, activeMappingSize, availableMappingSize] :
+ llvm::zip_equal(activeIds, activeMappingSizes, availableMappingSizes)) {
+ if (activeMappingSize > availableMappingSize) {
+ errorMsg = "Trying to map to fewer GPU threads than loop iterations but "
+ "overprovisioning is not yet supported. Try additional tiling "
+ "before mapping or map to more threads.";
+ return failure();
+ }
+ if (activeMappingSize == availableMappingSize)
+ continue;
+ Value idx = rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
+ Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+ activeId, idx);
+ predicateOps.push_back(pred);
+ }
+ return predicateOps;
+}
+
/// Return a flattened thread id for the workgroup with given sizes.
template <typename ThreadOrBlockIdOp>
static Value buildLinearId(RewriterBase &rewriter, Location loc,
ArrayRef<OpFoldResult> originalBasisOfr) {
- LLVM_DEBUG(DBGS() << "----buildLinearId with originalBasisOfr: "
- << llvm::interleaved(originalBasisOfr) << "\n");
+ LLVM_DEBUG(llvm::interleaveComma(
+ originalBasisOfr,
+ DBGS() << "----buildLinearId with originalBasisOfr: ");
+ llvm::dbgs() << "\n");
assert(originalBasisOfr.size() == 3 && "expected 3 sizes");
IndexType indexType = rewriter.getIndexType();
AffineExpr tx, ty, tz, bdx, bdy;
@@ -79,44 +124,43 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
auto res = [multiplicity](RewriterBase &rewriter, Location loc,
ArrayRef<int64_t> forallMappingSizes,
ArrayRef<int64_t> originalBasis) {
+ // 1. Compute linearId.
SmallVector<OpFoldResult> originalBasisOfr =
getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
- OpFoldResult linearId =
+ Value physicalLinearId =
buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr);
+
+ // 2. Compute scaledLinearId.
+ AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+ OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, d0.floorDiv(multiplicity), {physicalLinearId});
+
+ // 3. Compute remapped indices.
+ SmallVector<Value> ids;
// Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
// "row-major" order.
SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
- AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
- OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
- rewriter, loc, d0.floorDiv(multiplicity), {linearId});
SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
- SmallVector<Value> ids;
// Reverse back to be in [0 .. n] order.
for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
ids.push_back(
affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
}
- LLVM_DEBUG(DBGS() << "--delinearization basis: "
- << llvm::interleaved(reverseBasisSizes) << "\n";
- DBGS() << "--delinearization strides: "
- << llvm::interleaved(strides) << "\n";
- DBGS() << "--delinearization exprs: "
- << llvm::interleaved(delinearizingExprs) << "\n";
- DBGS() << "--ids: " << llvm::interleaved(ids) << "\n");
-
- // Return n-D ids for indexing and 1-D size + id for predicate generation.
- return IdBuilderResult{
- /*mappingIdOps=*/ids,
- /*availableMappingSizes=*/
- SmallVector<int64_t>{computeProduct(originalBasis)},
- // `forallMappingSizes` iterate in the scaled basis, they need to be
- // scaled back into the original basis to provide tight
- // activeMappingSizes quantities for predication.
- /*activeMappingSizes=*/
- SmallVector<int64_t>{computeProduct(forallMappingSizes) * multiplicity},
- /*activeIdOps=*/SmallVector<Value>{cast<Value>(linearId)}};
+ // 4. Handle predicates using physicalLinearId.
+ std::string errorMsg;
+ SmallVector<Value> predicateOps;
+ FailureOr<SmallVector<Value>> maybePredicateOps =
+ buildPredicates(rewriter, loc, physicalLinearId,
+ computeProduct(forallMappingSizes) * multiplicity,
+ computeProduct(originalBasis), errorMsg);
+ if (succeeded(maybePredicateOps))
+ predicateOps = *maybePredicateOps;
+
+ return IdBuilderResult{/*errorMsg=*/errorMsg,
+ /*mappingIdOps=*/ids,
+ /*predicateOps=*/predicateOps};
};
return res;
@@ -143,15 +187,18 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
// In the 3-D mapping case, unscale the first dimension by the multiplicity.
SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes);
forallMappingSizeInOriginalBasis[0] *= multiplicity;
- return IdBuilderResult{
- /*mappingIdOps=*/scaledIds,
- /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis},
- // `forallMappingSizes` iterate in the scaled basis, they need to be
- // scaled back into the original basis to provide tight
- // activeMappingSizes quantities for predication.
- /*activeMappingSizes=*/
- SmallVector<int64_t>{forallMappingSizeInOriginalBasis},
- /*activeIdOps=*/ids};
+
+ std::string errorMsg;
+ SmallVector<Value> predicateOps;
+ FailureOr<SmallVector<Value>> maybePredicateOps =
+ buildPredicates(rewriter, loc, ids, forallMappingSizeInOriginalBasis,
+ originalBasis, errorMsg);
+ if (succeeded(maybePredicateOps))
+ predicateOps = *maybePredicateOps;
+
+ return IdBuilderResult{/*errorMsg=*/errorMsg,
+ /*mappingIdOps=*/scaledIds,
+ /*predicateOps=*/predicateOps};
};
return res;
}
@@ -159,55 +206,46 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
/// Create a lane id builder that takes the `originalBasis` and decompose
/// it in the basis of `forallMappingSizes`. The linear id builder returns an
/// n-D vector of ids for indexing and 1-D size + id for predicate generation.
-static GpuIdBuilderFnType laneIdBuilderFn(int64_t periodicity) {
- auto res = [periodicity](RewriterBase &rewriter, Location loc,
- ArrayRef<int64_t> forallMappingSizes,
- ArrayRef<int64_t> originalBasis) {
+static GpuIdBuilderFnType laneIdBuilderFn(int64_t warpSize) {
+ auto res = [warpSize](RewriterBase &rewriter, Location loc,
+ ArrayRef<int64_t> forallMappingSizes,
+ ArrayRef<int64_t> originalBasis) {
+ // 1. Compute linearId.
SmallVector<OpFoldResult> originalBasisOfr =
getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
- OpFoldResult linearId =
+ Value physicalLinearId =
buildLinearId<ThreadIdOp>(rewriter, loc, originalBasisOfr);
+
+ // 2. Compute laneId.
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
- linearId = affine::makeComposedFoldedAffineApply(
- rewriter, loc, d0 % periodicity, {linearId});
+ OpFoldResult laneId = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, d0 % warpSize, {physicalLinearId});
+ // 3. Compute remapped indices.
+ SmallVector<Value> ids;
// Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
// "row-major" order.
SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
- SmallVector<Value> ids;
// Reverse back to be in [0 .. n] order.
for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
ids.push_back(
- affine::makeComposedAffineApply(rewriter, loc, e, {linearId}));
+ affine::makeComposedAffineApply(rewriter, loc, e, {laneId}));
}
- // clang-format off
- LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
- DBGS() << "--delinearization basis: ");
- llvm::dbgs() << "\n";
- llvm::interleaveComma(strides,
- DBGS() << "--delinearization strides: ");
- llvm::dbgs() << "\n";
- llvm::interleaveComma(delinearizingExprs,
- DBGS() << "--delinearization exprs: ");
- llvm::dbgs() << "\n";
- llvm::interleaveComma(ids, DBGS() << "--ids: ");
- llvm::dbgs() << "\n";);
- // clang-format on
-
- // Return n-D ids for indexing and 1-D size + id for predicate generation.
- return IdBuilderResult{
- /*mappingIdOps=*/ids,
- /*availableMappingSizes=*/
- SmallVector<int64_t>{computeProduct(originalBasis)},
- // `forallMappingSizes` iterate in the scaled basis, they need to be
- // scaled back into the original basis to provide tight
- // activeMappingSizes quantities for predication.
- /*activeMappingSizes=*/
- SmallVector<int64_t>{computeProduct(forallMappingSizes)},
- /*activeIdOps=*/SmallVector<Value>{linearId.get<Value>()}};
+ // 4. Handle predicates using laneId.
+ std::string errorMsg;
+ SmallVector<Value> predicateOps;
+ FailureOr<SmallVector<Value>> maybePredicateOps = buildPredicates(
+ rewriter, loc, cast<Value>(laneId), computeProduct(forallMappingSizes),
+ computeProduct(originalBasis), errorMsg);
+ if (succeeded(maybePredicateOps))
+ predicateOps = *maybePredicateOps;
+
+ return IdBuilderResult{/*errorMsg=*/errorMsg,
+ /*mappingIdOps=*/ids,
+ /*predicateOps=*/predicateOps};
};
return res;
``````````
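
To make the refactor easier to follow outside the diff, here is a minimal standalone C++ sketch of the arithmetic the rewritten builders perform, using plain integers in place of MLIR `Value`s. All names here (`IdBuilderResultModel`, `buildLinearIds`, `product`) are illustrative, not part of the MLIR API: linearize the physical id, floor-divide by the multiplicity, delinearize against strides of `forallMappingSizes`, and emit an `id < activeSize` predicate only in the strictly-smaller case of `buildPredicates`.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

// Integer model of the refactored id-builder logic. The real code emits
// affine.apply / arith ops through a RewriterBase; here every id is a plain
// int64_t so the sketch can run anywhere.
struct IdBuilderResultModel {
  std::string errorMsg;              // non-empty => building the ids failed
  std::vector<int64_t> mappingIdOps; // one id per forall dimension
  std::vector<int64_t> predicateOps; // 0/1 "id < activeSize" predicates
};

// Suffix-product strides, mirroring mlir::computeStrides: {4, 2} -> {2, 1}.
static std::vector<int64_t> computeStrides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * sizes[i + 1];
  return strides;
}

static int64_t product(const std::vector<int64_t> &v) {
  return std::accumulate(v.begin(), v.end(), int64_t(1),
                         std::multiplies<int64_t>());
}

static IdBuilderResultModel
buildLinearIds(int64_t tx, int64_t ty, int64_t tz,
               const std::vector<int64_t> &originalBasis, // {bdx, bdy, bdz}
               const std::vector<int64_t> &forallMappingSizes,
               int64_t multiplicity) {
  IdBuilderResultModel res;
  // 1. physicalLinearId = tx + ty * bdx + tz * bdx * bdy.
  int64_t physicalLinearId =
      tx + ty * originalBasis[0] + tz * originalBasis[0] * originalBasis[1];
  // 2. scaledLinearId: ids grouped by `multiplicity` (e.g. 32 for warps).
  int64_t scaledLinearId = physicalLinearId / multiplicity;
  // 3. Delinearize: reverse sizes to get "row-major" strides, then reverse
  //    the resulting ids back, as in commonLinearIdBuilderFn.
  std::vector<int64_t> rev(forallMappingSizes.rbegin(),
                           forallMappingSizes.rend());
  std::vector<int64_t> idsRev;
  int64_t remaining = scaledLinearId;
  for (int64_t stride : computeStrides(rev)) {
    idsRev.push_back(remaining / stride);
    remaining %= stride;
  }
  res.mappingIdOps.assign(idsRev.rbegin(), idsRev.rend());
  // 4. Predicates in the original basis: the three buildPredicates cases.
  int64_t active = product(forallMappingSizes) * multiplicity;
  int64_t available = product(originalBasis);
  if (active > available)      // case 1: needs looping, unsupported
    res.errorMsg = "more active ids than available ids";
  else if (active < available) // case 3: mask off ids >= active
    res.predicateOps.push_back(physicalLinearId < active ? 1 : 0);
  // case 2 (active == available): no predicate needed.
  return res;
}

int main() {
  // 64 threads; map a 2x4 forall with multiplicity 4: ids [0, 32) are
  // active, threads 32..63 are predicated off.
  IdBuilderResultModel r =
      buildLinearIds(/*tx=*/13, /*ty=*/0, /*tz=*/0,
                     /*originalBasis=*/{64, 1, 1},
                     /*forallMappingSizes=*/{2, 4}, /*multiplicity=*/4);
  assert(r.errorMsg.empty());
  std::cout << "ids: " << r.mappingIdOps[0] << ", " << r.mappingIdOps[1]
            << "  predicate: " << r.predicateOps[0]
            << "\n"; // prints "ids: 1, 1  predicate: 1"
  return 0;
}
```

In the actual pass these quantities are `Value`s built with `affine::makeComposedFoldedAffineApply` and `arith` ops, and the caller in `GPUTransformOps.cpp` now simply `and`s the returned `predicateOps` together instead of recomputing the active/available sizes itself.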
https://github.com/llvm/llvm-project/pull/146922