[Mlir-commits] [mlir] [mlir][vector] Notify the rewriter when sinking out of warp ops (PR #71964)
llvmlistbot at llvm.org
Fri Nov 10 09:39:43 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
Author: Quinn Dawkins (qedawkins)
Changes
A number of the warp distribution patterns rewrite a warp op in place by moving one of the ops it contains outside of the warp region. This patch notifies the rewriter that the warp op is being modified in those cases.
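For context, every hunk below applies the same idiom: the in-place mutation of the matched warp op is bracketed with `startRootUpdate`/`finalizeRootUpdate` so the pattern driver and any attached listeners learn that the root op changed without being replaced. Here is a minimal sketch of that idiom; `MyWarpPattern` and the elided mutation are hypothetical, only the bracketing calls come from the patch.

```cpp
// Sketch only: a hypothetical pattern showing the in-place update bracketing
// used throughout this patch. The actual mutation each pattern performs
// (sinking a contained op out of the warp region, rewriting the yielded
// value, ...) is elided.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct MyWarpPattern : OpRewritePattern<vector::WarpExecuteOnLane0Op> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Announce that `warpOp` (the root of this pattern) is about to be
    // modified in place rather than replaced.
    rewriter.startRootUpdate(warpOp);

    // ... mutate warpOp in place here ...

    // Close the update so the driver and listeners are notified of the
    // in-place change.
    rewriter.finalizeRootUpdate(warpOp);
    return success();
  }
};
```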
---
Full diff: https://github.com/llvm/llvm-project/pull/71964.diff
1 File Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+30)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 645caa9c1378821..ffe2a3e44fdec03 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -645,6 +645,10 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
});
if (!yieldOperand)
return failure();
+
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
+
Operation *elementWise = yieldOperand->get().getDefiningOp();
unsigned operandIndex = yieldOperand->getOperandNumber();
Value distributedVal = warpOp.getResult(operandIndex);
@@ -683,6 +687,7 @@ struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
{newWarpOp.getResult(operandIndex).getType()});
rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
newOp->getResult(0));
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
@@ -713,6 +718,8 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
if (!dense)
return failure();
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
unsigned operandIndex = yieldOperand->getOperandNumber();
Attribute scalarAttr = dense.getSplatValue<Attribute>();
auto newAttr = DenseElementsAttr::get(
@@ -721,6 +728,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
rewriter.setInsertionPointAfter(warpOp);
Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
@@ -823,7 +831,9 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
OpBuilder::InsertionGuard g(rewriter);
WarpExecuteOnLane0Op newWarpOp = warpOp;
Value newMask = read.getMask();
+ bool hasMask = false;
if (read.getMask()) {
+ hasMask = true;
// TODO: Distribution of masked reads with non-trivial permutation maps
// requires the distribution of the mask to elementwise match the
// distribution of the permuted written vector. Currently the details
@@ -894,6 +904,11 @@ struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
return failure();
rewriter.replaceAllUsesWith(distributedVal, newRead);
+ if (hasMask) {
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
+ rewriter.finalizeRootUpdate(warpOp);
+ }
return success();
}
};
@@ -996,7 +1011,10 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
if (!valForwarded)
return failure();
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
@@ -1024,6 +1042,8 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
vector::BroadcastableToResult::Success)
return failure();
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
@@ -1032,6 +1052,7 @@ struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
broadcasted);
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
@@ -1046,6 +1067,10 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
if (!operand)
return failure();
+
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
+
auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
unsigned int operandNumber = operand->getOperandNumber();
@@ -1074,6 +1099,7 @@ struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
oldCastOp.getLoc(), castResultType,
newWarpOp->getResult(newRetIndices[0]));
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
@@ -1133,6 +1159,9 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
mask, "cannot delinearize lane ID for distribution");
assert(!delinearizedIds.empty());
+ // Notify the rewriter that the warp op is changing.
+ rewriter.startRootUpdate(warpOp);
+
AffineExpr s0, s1;
bindSymbols(rewriter.getContext(), s0, s1);
SmallVector<Value> newOperands;
@@ -1151,6 +1180,7 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
auto newMask =
rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
+ rewriter.finalizeRootUpdate(warpOp);
return success();
}
};
``````````
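For completeness, the same notification can also be expressed with the callback form of the rewriter API, `updateRootInPlace`, which opens and finalizes the root update around the callback automatically. A hedged fragment, assuming a `rewriter` and `warpOp` as in the sketch above (the callback body is a placeholder, not code from this patch):

```cpp
// Sketch only: equivalent bracketing via the callback form of the rewriter
// API. The rewriter emits the same start/finalize notifications around the
// callback.
rewriter.updateRootInPlace(warpOp, [&] {
  // ... mutate warpOp in place here ...
});
```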
https://github.com/llvm/llvm-project/pull/71964