[Mlir-commits] [mlir] f6c79c6 - [mlir][Vector] Fix bug where vector::WarpExecuteOnLane0Op ops are created with 2 blocks in the region
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Jun 24 07:34:04 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-24T07:33:58-07:00
New Revision: f6c79c6ae49f3a642bebe32a2346186c38bb83d7
URL: https://github.com/llvm/llvm-project/commit/f6c79c6ae49f3a642bebe32a2346186c38bb83d7
DIFF: https://github.com/llvm/llvm-project/commit/f6c79c6ae49f3a642bebe32a2346186c38bb83d7.diff
LOG: [mlir][Vector] Fix bug where vector::WarpExecuteOnLane0Op ops are created with 2 blocks in the region
Differential Revision: https://reviews.llvm.org/D128534
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 9f308d46900d5..08ea44225d4f0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -46,8 +46,8 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
Value bbArg = warpOpBody->getArgument(it.index());
rewriter.setInsertionPoint(ifOp);
- Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
- bbArg.getType());
+ Value buffer =
+ options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());
// Store arg vector into buffer.
rewriter.setInsertionPoint(ifOp);
@@ -68,7 +68,7 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
// Insert sync after all the stores and before all the loads.
if (!warpOp.getArgs().empty()) {
rewriter.setInsertionPoint(ifOp);
- options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
+ options.warpSyncronizationFn(loc, rewriter, warpOp);
}
// Move body of warpOp to ifOp.
@@ -82,8 +82,8 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
Value val = it.value();
Type resultType = warpOp->getResultTypes()[it.index()];
rewriter.setInsertionPoint(ifOp);
- Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
- val.getType());
+ Value buffer =
+ options.warpAllocationFn(loc, rewriter, warpOp, val.getType());
// Store yielded value into buffer.
rewriter.setInsertionPoint(yieldOp);
@@ -121,7 +121,7 @@ rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
// Insert sync after all the stores and before all the loads.
if (!yieldOp.operands().empty()) {
rewriter.setInsertionPointAfter(ifOp);
- options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
+ options.warpSyncronizationFn(loc, rewriter, warpOp);
}
// Delete terminator and add empty scf.yield.
@@ -148,7 +148,12 @@ static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
Region &opBody = warpOp.getBodyRegion();
Region &newOpBody = newWarpOp.getBodyRegion();
+ Block &newOpFirstBlock = newOpBody.front();
rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+ rewriter.eraseBlock(&newOpFirstBlock);
+ assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+ "expected WarpOp with single block");
+
auto yield =
cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
More information about the Mlir-commits mailing list