[llvm-branch-commits] [mlir] [mlir][SCF][GPU] Add DeviceMaskingAttrInterface support to scf::Foral… (PR #146943)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jul 3 12:29:04 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Nicolas Vasilache (nicolasvasilache)
<details>
<summary>Changes</summary>
Add DeviceMaskingAttrInterface support to scf::ForallOp and use it to implement warp specialization.
This revision adds DeviceMaskingAttrInterface and extends DeviceMappingArrayAttr to accept a union of DeviceMappingAttrInterface and DeviceMaskingAttrInterface.
The first implementation is in the form of a GPUMappingMaskAttr, which can additionally be passed in the `mapping` attribute of scf.forall to specify a mask of the compute resources that should be active.
Support is added to GPUTransformOps to take advantage of this information and lower to block/warpgroup/warp/thread specialization when mapped to linear ids.
---
Patch is 35.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146943.diff
12 Files Affected:
- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td (+18)
- (modified) mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h (+10-5)
- (modified) mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td (+44-1)
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+12)
- (modified) mlir/lib/Dialect/GPU/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+45)
- (modified) mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp (+39-19)
- (modified) mlir/lib/Dialect/GPU/TransformOps/Utils.cpp (+73-27)
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+37-6)
- (modified) mlir/test/Dialect/GPU/transform-gpu-failing.mlir (+61)
- (modified) mlir/test/Dialect/GPU/transform-gpu.mlir (+81)
- (modified) mlir/test/Dialect/SCF/invalid.mlir (+18)
``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
index 63f228ca3157f..e8540027e7b77 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
@@ -252,6 +252,24 @@ def GPULaneMappingAttr
}];
}
+def GPUMappingMaskAttr : GPU_Attr<"GPUMappingMask", "mask", [
+ DeclareAttrInterfaceMethods<DeviceMaskingAttrInterface> ] > {
+ let parameters = (ins "uint64_t":$mask);
+ let assemblyFormat = "`<` params `>`";
+ let description = [{
+ Attribute describing how to filter the processing units that a
+ region is mapped to.
+
+ In this first implementation, the mask is a 64-bit bitfield that specifies,
+ for each linear processing unit id i (bit i), whether it is active or not.
+
+ In the future, we may want to implement this as a symbol to refer to
+ dynamically defined values.
+
+ Extending op semantics with an operand is deemed too intrusive at this time.
+ }];
+}
+
def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
let parameters = (ins
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
index de512ded59fec..0a11b8f8d3fa0 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
@@ -78,7 +78,8 @@ struct GpuIdBuilder {
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
/// used for indexing rewrites as well as 1D sizes for predicate generation.
struct GpuBlockIdBuilder : public GpuIdBuilder {
- GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
+ GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
+ DeviceMaskingAttrInterface mask = nullptr);
};
/// Builder for warpgroup ids used to map scf.forall to reindexed warpgroups.
@@ -88,7 +89,8 @@ struct GpuBlockIdBuilder : public GpuIdBuilder {
/// used for indexing rewrites as well as 1D sizes for predicate generation.
struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
- bool useLinearMapping = false);
+ bool useLinearMapping = false,
+ DeviceMaskingAttrInterface mask = nullptr);
int64_t warpSize = 32;
/// In the future this may be configured by the transformation.
static constexpr int64_t kNumWarpsPerGroup = 4;
@@ -101,7 +103,8 @@ struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
/// used for indexing rewrites as well as 1D sizes for predicate generation.
struct GpuWarpIdBuilder : public GpuIdBuilder {
GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
- bool useLinearMapping = false);
+ bool useLinearMapping = false,
+ DeviceMaskingAttrInterface mask = nullptr);
int64_t warpSize = 32;
};
@@ -111,7 +114,8 @@ struct GpuWarpIdBuilder : public GpuIdBuilder {
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
/// used for indexing rewrites as well as 1D sizes for predicate generation.
struct GpuThreadIdBuilder : public GpuIdBuilder {
- GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
+ GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
+ DeviceMaskingAttrInterface mask = nullptr);
};
/// Builder for lane id.
@@ -119,7 +123,8 @@ struct GpuThreadIdBuilder : public GpuIdBuilder {
/// as 1D sizes for predicate generation.
/// This `useLinearMapping` case is the only supported case.
struct GpuLaneIdBuilder : public GpuIdBuilder {
- GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused);
+ GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused,
+ DeviceMaskingAttrInterface mask = nullptr);
int64_t warpSize = 32;
};
diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
index 96db2a40cf58e..353aaf05bee0c 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
@@ -60,8 +60,51 @@ def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> {
];
}
+def DeviceMaskingAttrInterface : AttrInterface<"DeviceMaskingAttrInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Attribute interface describing how to filter the processing units that a
+ region is mapped to.
+
+ For bitfield masks, a popcount can be applied to determine the logical
+ linear index that an active physical processing unit is responsible for.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the logical active id for a given (active) physical id.
+ Expects a physicalLinearMappingId of I64Type.
+ }],
+ /*retTy=*/"Value",
+ /*methodName=*/"getLogicalLinearMappingId",
+ /*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the dynamic condition determining whether a given physical id is
+ active under the mask.
+ Expects a physicalLinearMappingId of I64Type.
+ }],
+ /*retTy=*/"Value",
+ /*methodName=*/"getIsActiveIdPredicate",
+ /*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the maximal number of physical ids supported.
+ This accounts for temporary implementation limitations (e.g. an i64 mask)
+ so that callers can fail gracefully with actionable error messages.
+ }],
+ /*retTy=*/"int64_t",
+ /*methodName=*/"getMaxNumPhysicalIds",
+ /*args=*/(ins)
+ >,
+ ];
+}
+
def DeviceMappingArrayAttr :
- TypedArrayAttrBase<DeviceMappingAttrInterface,
+ TypedArrayAttrBase<AnyAttrOf<[DeviceMappingAttrInterface, DeviceMaskingAttrInterface]>,
"Device Mapping array attribute"> { }
#endif // MLIR_DEVICEMAPPINGINTERFACE
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8b14cef7437d4..2d15544e871b3 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -611,6 +611,18 @@ def ForallOp : SCF_Op<"forall", [
/// Returns operations within scf.forall.in_parallel whose destination
/// operand is the block argument `bbArg`.
SmallVector<Operation*> getCombiningOps(BlockArgument bbArg);
+
+ /// Returns the subset of DeviceMappingArrayAttrs of type
+ /// DeviceMappingAttrInterface.
+ SmallVector<DeviceMappingAttrInterface> getDeviceMappingAttrs();
+
+ /// Returns the at most one DeviceMaskingAttrInterface in the mapping.
+ /// If more than one DeviceMaskingAttrInterface is specified, returns
+ /// failure. If no mapping is present, returns nullptr.
+ FailureOr<DeviceMaskingAttrInterface> getDeviceMaskingAttr();
+
+ /// Returns true if the mapping specified for this forall op is linear.
+ bool usesLinearMapping();
}];
}
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index c8c53374d676b..4862d1f722785 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRGPUDialect
MLIRFunctionInterfaces
MLIRInferIntRangeInterface
MLIRIR
+ MLIRMathDialect
MLIRMemRefDialect
MLIRSideEffectInterfaces
MLIRSupport
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 56631f1aac084..9d74c23c24cc8 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -120,6 +121,50 @@ int64_t GPULaneMappingAttr::getRelativeIndex() const {
: getMappingId();
}
+int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; } // 1 bit/id
+
+/// Bit index     : 8 7 6 5 4 3 2 1 0
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+///
+/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
+/// E.g. the logical id of physical id 5 uses the filter ((1 << 5) - 1).
+///
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+/// Example filter: 0 0 0 0 1 1 1 1 1
+/// Intersection  : 0 0 0 0 1 0 1 0 0
+/// PopCnt        : 2
+Value GPUMappingMaskAttr::getLogicalLinearMappingId(
+ OpBuilder &b, Value physicalLinearMappingId) const {
+ Location loc = physicalLinearMappingId.getLoc();
+ Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));
+ Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
+ Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
+ filter = b.create<arith::SubIOp>(loc, filter, one);
+ Value filteredId = b.create<arith::AndIOp>(loc, mask, filter);
+ return b.create<math::CtPopOp>(loc, filteredId);
+}
+
+/// Bit index     : 8 7 6 5 4 3 2 1 0
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+///
+/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
+/// E.g. physical id 5 is tested against the filter (1 << 5).
+///
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+/// Example filter: 0 0 0 1 0 0 0 0 0
+/// Intersection  : 0 0 0 1 0 0 0 0 0
+/// Cmp (ne 0)    : 1
+Value GPUMappingMaskAttr::getIsActiveIdPredicate(
+ OpBuilder &b, Value physicalLinearMappingId) const {
+ Location loc = physicalLinearMappingId.getLoc();
+ Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));
+ Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
+ Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
+ Value filtered = b.create<arith::AndIOp>(loc, mask, filter);
+ Value zero = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
+ return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, filtered, zero);
+}
+
int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
return static_cast<int64_t>(getAddressSpace());
}
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 63f87d9b5877e..a8eaa20928b7f 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -351,16 +351,25 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
seen.insert(map);
}
- auto isLinear = [](Attribute a) {
- return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
+ auto isLinear = [](DeviceMappingAttrInterface attr) {
+ return attr.isLinearMapping();
};
- if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
- !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
+ if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) &&
+ !llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) {
return definiteFailureHelper(
transformOp, forallOp,
"cannot mix linear and non-linear mapping modes");
}
+ FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+ forallOp.getDeviceMaskingAttr();
+ if (succeeded(maybeMaskingAttr) && *maybeMaskingAttr &&
+ !forallOp.usesLinearMapping()) {
+ return definiteFailureHelper(
+ transformOp, forallOp,
+ "device masking is only available in linear mapping mode");
+ }
+
return DiagnosedSilenceableFailure::success();
}
@@ -381,9 +390,7 @@ verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
if (forallOp.getNumResults() > 0)
return definiteFailureHelper(transformOp, forallOp,
"only bufferized scf.forall can be mapped");
- bool useLinearMapping = cast<DeviceMappingAttrInterface>(
- forallOp.getMapping()->getValue().front())
- .isLinearMapping();
+ bool useLinearMapping = forallOp.usesLinearMapping();
// TODO: This would be more natural with support for Optional<EnumParameter>
// in GPUDeviceMappingAttr.
int64_t maxNumMappingsSupported =
@@ -682,12 +689,17 @@ DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
// The BlockIdBuilder adapts to whatever is thrown at it.
bool useLinearMapping = false;
- if (topLevelForallOp.getMapping()) {
- auto mappingAttr = cast<DeviceMappingAttrInterface>(
- topLevelForallOp.getMapping()->getValue().front());
- useLinearMapping = mappingAttr.isLinearMapping();
- }
- GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);
+ if (topLevelForallOp.getMapping())
+ useLinearMapping = topLevelForallOp.usesLinearMapping();
+
+ FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+ topLevelForallOp.getDeviceMaskingAttr();
+ assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
+ assert((!*maybeMaskingAttr || useLinearMapping) &&
+ "masking requires linear mapping");
+
+ GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping,
+ *maybeMaskingAttr);
diag = mlir::transform::gpu::mapForallToBlocksImpl(
rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
@@ -744,8 +756,7 @@ static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
- auto mappingAttr = cast<DeviceMappingAttrInterface>(
- forallOp.getMapping()->getValue().front());
+ auto mappingAttr = forallOp.getDeviceMappingAttrs().front();
bool useLinearMapping = mappingAttr.isLinearMapping();
// Sanity checks that may result in runtime verification errors.
@@ -768,21 +779,30 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
if (!diag.succeeded())
return diag;
+ FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+ forallOp.getDeviceMaskingAttr();
+ assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
+ assert((!*maybeMaskingAttr || useLinearMapping) &&
+ "masking requires linear mapping");
+
// Start mapping.
MLIRContext *ctx = forallOp.getContext();
gpuIdBuilder =
TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
.Case([&](GPUWarpgroupMappingAttr) {
- return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
+ return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping,
+ *maybeMaskingAttr);
})
.Case([&](GPUWarpMappingAttr) {
- return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
+ return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping,
+ *maybeMaskingAttr);
})
.Case([&](GPUThreadMappingAttr) {
- return GpuThreadIdBuilder(ctx, useLinearMapping);
+ return GpuThreadIdBuilder(ctx, useLinearMapping, *maybeMaskingAttr);
})
.Case([&](GPULaneMappingAttr) {
- return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping);
+ return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
+ *maybeMaskingAttr);
})
.Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
llvm_unreachable("unknown mapping attribute");
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index 795d643c05912..d1969dbc82997 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -44,7 +44,7 @@ using namespace mlir::transform::gpu;
#define DEBUG_TYPE "gpu-transforms"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
/// Build predicates to filter execution by only the activeIds. Along each
@@ -120,10 +120,23 @@ static Value buildLinearId(RewriterBase &rewriter, Location loc,
/// it in the basis of `forallMappingSizes`. The linear id builder returns an
/// n-D vector of ids for indexing and 1-D size + id for predicate generation.
template <typename ThreadOrBlockIdOp>
-static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
- auto res = [multiplicity](RewriterBase &rewriter, Location loc,
- ArrayRef<int64_t> forallMappingSizes,
- ArrayRef<int64_t> originalBasis) {
+static GpuIdBuilderFnType
+commonLinearIdBuilderFn(int64_t multiplicity = 1,
+ DeviceMaskingAttrInterface mask = nullptr) {
+ auto res = [multiplicity, mask](RewriterBase &rewriter, Location loc,
+ ArrayRef<int64_t> forallMappingSizes,
+ ArrayRef<int64_t> originalBasis) {
+ // 0. Early-exit mask case.
+ if (mask) {
+ if (computeProduct(originalBasis) >
+ mask.getMaxNumPhysicalIds() * multiplicity) {
+ return IdBuilderResult{
+ /*errorMsg=*/std::string(
+ "mask representation too short to capture all physical ids: ") +
+ std::to_string(mask.getMaxNumPhysicalIds())};
+ }
+ }
+
// 1. Compute linearId.
SmallVector<OpFoldResult> originalBasisOfr =
getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
@@ -132,9 +145,25 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
// 2. Compute scaledLinearId.
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
- OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
+ OpFoldResult scaledLinearIdOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, d0.floorDiv(multiplicity), {physicalLinearId});
+ // 2.b. Adjust with mask if needed.
+ Value scaledLinearIdI64;
+ Value scaledLinearId =
+ getValueOrCreateConstantIndexOp(rewriter, loc, scaledLinearIdOfr);
+ if (mask) {
+ scaledLinearId =
+ getValueOrCreateConstantIndexOp(rewriter, loc, scaledLinearIdOfr);
+ scaledLinearIdI64 = rewriter.create<arith::IndexCastUIOp>(
+ loc, rewriter.getI64Type(), scaledLinearId);
+ Value logicalLinearIdI64 =
+ mask.getLogicalLinearMappingId(rewriter, scaledLinearIdI64);
+ scaledLinearId = rewriter.create<arith::IndexCastUIOp>(
+ loc, rewriter.getIndexType(), logicalLinearIdI64);
+ LDBG("------adjusting linearId with mask: " << scaledLinearId);
+ }
+
// 3. Compute remapped indices.
SmallVector<Value> ids;
// Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
@@ -148,15 +177,23 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
}
- // 4. Handle predicates using physicalLinearId.
std::string errorMsg;
SmallVector<Value> predicateOps;
- FailureOr<SmallVector<Value>> maybePredicateOps =
- buildPredicates(rewriter, loc, physicalLinearId,
- computeProduct(forallMappingSizes) * multiplicity,
- computeProduct(originalBasis), errorMsg);
- if (succeeded(maybePredicateOps))
- predicateOps = *maybePredicateOps;
+ // 4. If mask present, it takes precedence to determine predication.
+ if (mask) {
+ Value isActiveIdPredicate =
+ mask.getIsActiveIdPredicate(rewriter, scaledLinearIdI64);
+ LDBG("------adjusting predicate with mask: " << isActiveIdPredicate);
+ predicateOps.push_back(isActiveIdPredicate);
+ } else {
+ // 4.b. Otherwise, handle predicates using physicalLinearId.
+ FailureOr<SmallVector<Value>> maybePredicateOps =
+ buildPredicates(rewriter, loc, physicalLinearId,
+ computeProduct(forallMappingSizes) * multiplicity,
+ ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/146943
More information about the llvm-branch-commits
mailing list