[Mlir-commits] [mlir] 768615b - [mlir][Transform] NFC - Refactor forall mapping to threads and blocks into one thing
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Mar 15 05:14:02 PDT 2023
Author: Nicolas Vasilache
Date: 2023-03-15T05:09:39-07:00
New Revision: 768615bba0b05ad1b6f798edfb021e05a243b5b8
URL: https://github.com/llvm/llvm-project/commit/768615bba0b05ad1b6f798edfb021e05a243b5b8
DIFF: https://github.com/llvm/llvm-project/commit/768615bba0b05ad1b6f798edfb021e05a243b5b8.diff
LOG: [mlir][Transform] NFC - Refactor forall mapping to threads and blocks into one thing
Differential Revision: https://reviews.llvm.org/D146095
Added:
Modified:
mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
index 7c6aa7e069879..579922a3a9c03 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -39,11 +39,11 @@ namespace gpu {
/// Dynamic, `scf.forall` trip counts are currently not supported.
/// Dynamic block dim sizes are currently not supported.
DiagnosedSilenceableFailure mapForallToBlocksImpl(
- RewriterBase &rewriter, scf::ForallOp forallOp,
+ RewriterBase &rewriter, TransformOpInterface transformOp,
+ scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
+ const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes,
function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
- blockIdGenerator,
- SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
- const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);
+ blockIdGenerator);
/// Search `scf.forall` ops nested under `target` and map each such op to GPU
/// threads. Mapping is one-to-one and the induction variables of `scf.forall`
@@ -54,12 +54,12 @@ DiagnosedSilenceableFailure mapForallToBlocksImpl(
/// Dynamic, `scf.forall` trip counts are currently not supported.
/// Dynamic block dim sizes are currently not supported.
DiagnosedSilenceableFailure mapNestedForallToThreadsImpl(
- RewriterBase &rewriter, Operation *target,
- const SmallVectorImpl<int64_t> &kernelBlockDims,
+ RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
+ Operation *target, const SmallVectorImpl<int64_t> &kernelBlockDims,
+ bool syncAfterDistribute,
+ const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes,
function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
- threadIdGenerator,
- bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
- const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
+ threadIdGenerator);
/// Find the unique top level scf::ForallOp within a given target op.
DiagnosedSilenceableFailure
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 930bf46fca215..27c27756b3918 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -124,6 +124,9 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
SmallVector<OpFoldResult>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
llvm::function_ref<bool(Attribute, Attribute)> compare);
+SmallVector<int64_t>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
+ llvm::function_ref<bool(Attribute, Attribute)> compare);
} // namespace mlir
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 93f00fe9cca01..6d87604aee0ba 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -16,17 +16,26 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
+#define DEBUG_TYPE "gpu-transforms"
+
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
namespace {
-/// Helper type forfunctions that generate ids for the mapping of a scf.forall.
+/// Helper type for functions that generate ids for the mapping of a scf.forall.
using IdGeneratorFnType = llvm::function_ref<void(RewriterBase &, scf::ForallOp,
SmallVectorImpl<Value> &)>;
@@ -86,7 +95,7 @@ static DiagnosedSilenceableFailure
failureHelper(std::optional<TransformOpInterface> transformOp,
scf::ForallOp forallOp, const Twine &message) {
if (transformOp.has_value())
- return transformOp->emitSilenceableError() << message;
+ return emitDefiniteFailure(*transformOp, message);
return emitDefiniteFailure(forallOp, message);
}
@@ -273,30 +282,35 @@ alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch,
// MapForallToBlocks
//===----------------------------------------------------------------------===//
-DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
- RewriterBase &rewriter, scf::ForallOp forallOp,
- IdGeneratorFnType blockIdGenerator, SmallVectorImpl<int64_t> &gridDims,
- TransformOpInterface transformOp,
- const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
+static FailureOr<SmallVector<int64_t>> rewriteOneForallCommonImpl(
+ RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
+ scf::ForallOp forallOp,
+ const SmallVectorImpl<int64_t> &availableMappingSizes,
+ const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
+ IdGeneratorFnType idGenerator) {
+ LDBG("Start rewriteOneForallCommonImpl");
// Step 0. GPU-specific verifications. There is no better place to anchor
// those right now: the ForallOp is target-independent and the transform op
// does not apply to individual ForallOp.
DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp);
if (!diag.succeeded())
- return diag;
-
- SmallVector<Attribute> blockMapping =
+ return failure();
+
+ // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
+ SmallVector<int64_t> tmpMappingSizes = llvm::to_vector(
+ llvm::map_range(forallOp.getMixedUpperBound(), [](OpFoldResult ofr) {
+ auto maybeStaticValue = getConstantIntValue(ofr);
+ assert(maybeStaticValue && "expected static value");
+ return maybeStaticValue.value();
+ }));
+ SmallVector<Attribute> forallMappings =
llvm::to_vector(forallOp.getMapping()->getValue());
-
- // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
- SmallVector<OpFoldResult> numBlocks = forallOp.getMixedUpperBound();
- // Ensure we have 3 block sizes, one for each id.
- for (auto attr : mappingAttributes) {
- if (!llvm::is_contained(blockMapping, attr)) {
- blockMapping.push_back(attr);
- numBlocks.push_back(rewriter.getIndexAttr(1));
- }
+ for (auto attr : allMappingAttributes) {
+ if (llvm::is_contained(forallMappings, attr))
+ continue;
+ forallMappings.push_back(attr);
+ tmpMappingSizes.push_back(1);
}
// Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
@@ -304,43 +318,116 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
DeviceMappingAttrInterface b) -> bool {
return a.getMappingId() < b.getMappingId();
};
- SmallVector<OpFoldResult> gridDimValues =
- getValuesSortedByKey(blockMapping, numBlocks, comparator);
- gridDims =
- llvm::to_vector(llvm::map_range(gridDimValues, [](OpFoldResult ofr) {
- return getConstantIntValue(ofr).value();
- }));
+ SmallVector<int64_t> mappingSizes =
+ getValuesSortedByKey(forallMappings, tmpMappingSizes, comparator);
+ LLVM_DEBUG(llvm::interleaveComma(mappingSizes, DBGS() << "mappingSizes: ");
+ llvm::dbgs() << "\n";
+ llvm::interleaveComma(forallMappings, DBGS() << "mappingAttrs: ");
+ llvm::dbgs() << "\n");
+
+ // Step 3. Generate the mappingIdOps using the provided generator and map the
+ // induction variables to the newly created ops. Replace ids of dimension
+ // known to be of size 1 by zero to simplify the IR.
+ SmallVector<Value> mappingIdOps;
+ Location loc = forallOp.getLoc();
+ idGenerator(rewriter, forallOp, mappingIdOps);
+ LLVM_DEBUG(llvm::interleaveComma(mappingIdOps, DBGS() << "mappingIdOps: ");
+ llvm::dbgs() << "\n");
+ assert(mappingIdOps.size() == mappingSizes.size() && "expect equal sizes");
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ if (!availableMappingSizes.empty()) {
+ for (size_t i : llvm::seq(size_t(0), availableMappingSizes.size())) {
+ if (availableMappingSizes[i] == 1)
+ mappingIdOps[i] = zero;
+ }
+ }
- // Step 3. Generate the blockids using the provided generator and map the
- // induction variables to the newly created ops.
- SmallVector<Value> blockOps;
- blockIdGenerator(rewriter, forallOp, blockOps);
IRMapping bvm;
- for (auto [blockIdx, blockDim] :
- llvm::zip(forallOp.getInductionVars(), blockMapping)) {
- bvm.map(blockIdx,
- blockOps[static_cast<int64_t>(
- blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
+ for (auto [iv, dim] :
+ llvm::zip_equal(forallOp.getInductionVars(),
+ ArrayRef<Attribute>{forallMappings}.take_front(
+ forallOp.getInductionVars().size()))) {
+ Value peIdOp = mappingIdOps[static_cast<int64_t>(
+ dim.cast<DeviceMappingAttrInterface>().getMappingId())];
+ bvm.map(iv, peIdOp);
}
- // Step 4. Move the body of forallOp.
- // Erase the terminator first, it will not be used since we are on buffers.
+ // Step 4. Maybe create conditionals to predicate the region.
+ // Skip this step when availableMappingSizes is empty.
+ Value predicate;
+ if (!availableMappingSizes.empty()) {
+ LLVM_DEBUG(llvm::interleaveComma(availableMappingSizes,
+ DBGS() << "availableMappingSizes: ");
+ llvm::dbgs() << "\n");
+ for (auto [id, mappingSize, availableMappingSize] :
+ llvm::zip_equal(mappingIdOps, mappingSizes, availableMappingSizes)) {
+ if (mappingSize > availableMappingSize) {
+ (void)failureHelper(
+ transformOp, forallOp,
+ "Trying to map to fewer GPU threads than loop iterations but "
+ "overprovisioning is not yet supported. "
+ "Try additional tiling of the before mapping or map to more "
+ "threads.");
+ return failure();
+ }
+ if (mappingSize == availableMappingSize)
+ continue;
+ Value idx = rewriter.create<arith::ConstantIndexOp>(loc, mappingSize);
+ Value tmpPredicate = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ult, id, idx);
+ LDBG("predicate: " << tmpPredicate);
+ predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
+ tmpPredicate)
+ : tmpPredicate;
+ }
+ }
+
+ // Step 5. Move the body of forallOp.
+ // Erase the terminator first, it will not be used.
rewriter.eraseOp(forallOp.getTerminator());
- Block *targetBlock = forallOp->getBlock();
- Block::iterator insertionPoint = Block::iterator(forallOp);
+ Block *targetBlock;
+ Block::iterator insertionPoint;
+ if (predicate) {
+ // Step 5.a. If predicated, move at the beginning.
+ auto ifOp =
+ rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
+ targetBlock = ifOp.thenBlock();
+ insertionPoint = ifOp.thenBlock()->begin();
+ } else {
+ // Step 5.b. Otherwise, move inline just at the rewriter insertion point.
+ targetBlock = forallOp->getBlock();
+ insertionPoint = rewriter.getInsertionPoint();
+ }
Block &sourceBlock = forallOp.getRegion().front();
targetBlock->getOperations().splice(insertionPoint,
sourceBlock.getOperations());
- // Step 5. RAUW thread indices to thread ops.
+ // Step 6. RAUW thread indices to thread ops.
for (Value loopIndex : forallOp.getInductionVars()) {
- Value blockIdx = bvm.lookup(loopIndex);
- rewriter.replaceAllUsesWith(loopIndex, blockIdx);
+ Value threadIdx = bvm.lookup(loopIndex);
+ rewriter.replaceAllUsesWith(loopIndex, threadIdx);
}
- // Step 6. Erase old op.
+ // Step 7. Erase old op.
rewriter.eraseOp(forallOp);
+ return mappingSizes;
+}
+
+DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
+ RewriterBase &rewriter, TransformOpInterface transformOp,
+ scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
+ const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
+ IdGeneratorFnType idGenerator) {
+ // Pass an empty anyAvailableMappingSizes.
+ SmallVector<int64_t> anyAvailableMappingSizes;
+ FailureOr<SmallVector<int64_t>> maybeMappingSizes =
+ rewriteOneForallCommonImpl(rewriter, transformOp, forallOp,
+ anyAvailableMappingSizes, allMappingAttributes,
+ idGenerator);
+ if (failed(maybeMappingSizes))
+ return DiagnosedSilenceableFailure::definiteFailure();
+ gridDims = *maybeMappingSizes;
return DiagnosedSilenceableFailure::success();
}
@@ -389,8 +476,8 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
return diag;
}
- SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
- if (!getGenerateGpuLaunch() && gridDim.size() != 3)
+ SmallVector<int64_t> gridDims = extractFromI64ArrayAttr(getGridDim());
+ if (!getGenerateGpuLaunch() && gridDims.size() != 3)
return transformOp.emitDefiniteFailure("transform require size-3 mapping");
OpBuilder::InsertionGuard guard(rewriter);
@@ -415,14 +502,14 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
MappingToGpuBlocksHelper helper(getContext());
diag = mlir::transform::gpu::mapForallToBlocksImpl(
- rewriter, topLevelForallOp, helper.idGenerator, gridDim, transformOp,
- helper.mappingAttributes);
+ rewriter, transformOp, topLevelForallOp, gridDims,
+ helper.mappingAttributes, helper.idGenerator);
if (!diag.succeeded())
return diag;
diag = alterGpuLaunch(rewriter, gpuLaunch,
- cast<TransformOpInterface>(getOperation()), gridDim[0],
- gridDim[1], gridDim[2]);
+ cast<TransformOpInterface>(getOperation()), gridDims[0],
+ gridDims[1], gridDims[2]);
results.push_back(gpuLaunch);
return diag;
@@ -432,147 +519,33 @@ transform::MapForallToBlocks::applyToOne(Operation *target,
// MapNestedForallToThreads
//===----------------------------------------------------------------------===//
-static DiagnosedSilenceableFailure rewriteOneForallToGpuThreads(
- RewriterBase &rewriter, scf::ForallOp forallOp,
- const SmallVectorImpl<int64_t> &kernelBlockDims,
- const SmallVectorImpl<Value> &threadOps, bool syncAfterDistribute,
- std::optional<TransformOpInterface> transformOp,
- const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
-
- // Step 0. GPU-specific verifications. There is no better place to anchor
- // those right now: the ForallOp is target-independent and the transform op
- // does not apply to individual ForallOp.
- DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp);
- if (!diag.succeeded())
- return diag;
-
- Location loc = forallOp->getLoc();
-
- SmallVector<Attribute> mapping =
- llvm::to_vector(forallOp.getMapping()->getValue());
-
- // Step 1. Complete the mapping to a full mapping (with 1s) if
- // necessary.
- SmallVector<OpFoldResult> numThreads = forallOp.getMixedUpperBound();
- Attribute one = rewriter.getIndexAttr(1);
- for (auto attr : mappingAttributes) {
- if (std::find(mapping.begin(), mapping.end(), attr) == mapping.end()) {
- mapping.push_back(attr);
- numThreads.push_back(one);
- }
- }
-
- // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
- auto comparator = [&](DeviceMappingAttrInterface a,
- DeviceMappingAttrInterface b) -> bool {
- return a.getMappingId() < b.getMappingId();
- };
- SmallVector<OpFoldResult> blockDimValues =
- getValuesSortedByKey(mapping, numThreads, comparator);
- SmallVector<int64_t> blockDims =
- llvm::to_vector(llvm::map_range(blockDimValues, [](OpFoldResult ofr) {
- return getConstantIntValue(ofr).value();
- }));
-
- // Step 3. Create the gpu.thread ops and map the induction variables to the
- // newly created ops.
- // Replace ids of dimension size 1 by zero to simplify the IR.
- // TODO
- SmallVector<Value> threadOpsUpdated(threadOps.begin(), threadOps.end());
- assert(threadOps.size() == kernelBlockDims.size());
- Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- for (size_t i : llvm::seq(size_t(0), kernelBlockDims.size())) {
- if (kernelBlockDims[i] == 1)
- threadOpsUpdated[i] = zero;
- }
- IRMapping bvm;
- for (auto [threadIdx, blockDim] :
- llvm::zip(forallOp.getInductionVars(), mapping)) {
- bvm.map(threadIdx,
- threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
- .getMappingId()]);
- }
-
- // Step 4. Maybe create conditionals to predicate the region.
- Value predicate;
- for (auto [threadId, blockDim, globalBlockDim] :
- llvm::zip(threadOpsUpdated, blockDims, kernelBlockDims)) {
- if (blockDim > globalBlockDim) {
- return failureHelper(
- transformOp, forallOp,
- "Trying to map to fewer GPU threads than loop iterations but "
- "overprovisioning is not yet supported. "
- "Try additional tiling of the before mapping or map to more "
- "threads.");
- }
- if (blockDim == globalBlockDim)
- continue;
- Value threadIdx = rewriter.create<arith::ConstantIndexOp>(loc, blockDim);
- Value tmpPredicate = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, threadId, threadIdx);
- predicate =
- predicate ? rewriter.create<arith::AndIOp>(loc, predicate, tmpPredicate)
- : tmpPredicate;
- }
-
- // Step 5. Move the body of forallOp.
- // Erase the terminator first, it will not be used.
- rewriter.eraseOp(forallOp.getTerminator());
- Block *targetBlock;
- Block::iterator insertionPoint;
- if (predicate) {
- // Step 5.a. If predicated, move at the beginning.
- auto ifOp =
- rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
- targetBlock = ifOp.thenBlock();
- insertionPoint = ifOp.thenBlock()->begin();
- } else {
- // Step 5.b. Otherwise, move inline just before forallOp.
- targetBlock = forallOp->getBlock();
- insertionPoint = Block::iterator(forallOp);
- }
- Block &sourceBlock = forallOp.getRegion().front();
- targetBlock->getOperations().splice(insertionPoint,
- sourceBlock.getOperations());
-
- // Step 6. RAUW thread indices to thread ops.
- for (Value loopIndex : forallOp.getInductionVars()) {
- Value threadIdx = bvm.lookup(loopIndex);
- rewriter.replaceAllUsesWith(loopIndex, threadIdx);
- }
-
- // Step 7. syncthreads.
- // TODO: Need warpsync
- if (syncAfterDistribute)
- rewriter.create<BarrierOp>(loc);
-
- // Step 8. Erase old op.
- rewriter.eraseOp(forallOp);
-
- return DiagnosedSilenceableFailure::success();
-}
-
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
- RewriterBase &rewriter, Operation *target,
- const SmallVectorImpl<int64_t> &blockDim, IdGeneratorFnType idGenerator,
- bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
- const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
+ RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
+ Operation *target, const SmallVectorImpl<int64_t> &kernelBlockDims,
+ bool syncAfterDistribute,
+ const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
+ IdGeneratorFnType idGenerator) {
DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
target->walk([&](scf::ForallOp forallOp) {
// Ignore cases with different attributes.
for (Attribute map : forallOp.getMapping()->getValue()) {
- if (!llvm::is_contained(mappingAttributes, map)) {
+ if (!llvm::is_contained(allMappingAttributes, map)) {
return WalkResult::skip();
}
}
diag = verifyGpuMapping(transformOp, forallOp);
if (diag.succeeded()) {
- rewriter.setInsertionPoint(forallOp);
- SmallVector<Value> threadOps;
- idGenerator(rewriter, forallOp, threadOps);
- diag = rewriteOneForallToGpuThreads(rewriter, forallOp, blockDim,
- threadOps, syncAfterDistribute,
- transformOp, mappingAttributes);
+ // Take the loc ahead of time
+ Location loc = forallOp.getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointAfter(forallOp);
+ if (failed(rewriteOneForallCommonImpl(rewriter, transformOp, forallOp,
+ kernelBlockDims,
+ allMappingAttributes, idGenerator)))
+ diag = DiagnosedSilenceableFailure::definiteFailure();
+ // Add a syncthreads if needed. TODO: warpsync
+ if (syncAfterDistribute)
+ rewriter.create<BarrierOp>(loc);
}
return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
});
@@ -588,13 +561,13 @@ DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
if (!gpuLaunch)
return emitSilenceableError() << "Given target is not a gpu.launch";
- SmallVector<int64_t> blockDim = extractFromI64ArrayAttr(getBlockDim());
- if (blockDim.size() != 3)
+ SmallVector<int64_t> blockDims = extractFromI64ArrayAttr(getBlockDim());
+ if (blockDims.size() != 3)
return transformOp.emitDefiniteFailure("transform require size-3 mapping");
DiagnosedSilenceableFailure diag =
checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
- blockDim[0], blockDim[1], blockDim[2]);
+ blockDims[0], blockDims[1], blockDims[2]);
if (diag.isSilenceableFailure()) {
diag.attachNote(getLoc()) << getBlockDimAttrName() << " is too large";
return diag;
@@ -602,18 +575,17 @@ DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
MLIRContext *ctx = getContext();
IRRewriter rewriter(ctx);
- rewriter.setInsertionPoint(target);
MappingToGpuThreadsHelper helper(ctx);
diag = mlir::transform::gpu::mapNestedForallToThreadsImpl(
- rewriter, target, blockDim, helper.idGenerator, getSyncAfterDistribute(),
- transformOp, helper.mappingAttributes);
+ rewriter, transformOp, target, blockDims, getSyncAfterDistribute(),
+ helper.mappingAttributes, helper.idGenerator);
if (!diag.succeeded())
return diag;
diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
- std::nullopt, std::nullopt, blockDim[0], blockDim[1],
- blockDim[2]);
+ std::nullopt, std::nullopt, blockDims[0], blockDims[1],
+ blockDims[2]);
results.push_back(gpuLaunch.getOperation());
return diag;
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 907a8c1e4914d..e646de95a76c9 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -222,4 +222,10 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
return getValuesSortedByKeyImpl(keys, values, compare);
}
+SmallVector<int64_t>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
+ llvm::function_ref<bool(Attribute, Attribute)> compare) {
+ return getValuesSortedByKeyImpl(keys, values, compare);
+}
+
} // namespace mlir
More information about the Mlir-commits
mailing list