[Mlir-commits] [mlir] [mlir][vector] Refactor parts of `VectorToSCF.cpp` (NFC) (PR #75855)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 18 13:10:39 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Rik Huijzer (rikhuijzer)
<details>
<summary>Changes</summary>
I'm trying to fix https://github.com/llvm/llvm-project/issues/71326. However, the code in `VectorToSCF.cpp` is quite complex. This PR suggests the following refactorings to make it slightly simpler to read:
- Specify the type explicitly instead of using `auto`, except when the type is obvious (https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable).
- Replace a few Systems Hungarian names with less verbose ones (related to the previous point). For example, changing `auto signlessSourceVectorType` to `VectorType signlessSource`.
- Use `cast` instead of `dyn_cast` when the result is not tested.
- Extract one method. The code is still complex, but (in my opinion) it is now slightly more clear which variables are used where.
---
Full diff: https://github.com/llvm/llvm-project/pull/75855.diff
1 Files Affected:
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+70-65)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..39bf5fb07f1a70 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -321,7 +321,7 @@ static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
// It may be possible to support these in future by using dynamic memref dims.
if (vectorType.getScalableDims().front())
return failure();
- auto memrefShape = type.getShape();
+ ArrayRef<int64_t> memrefShape = type.getShape();
SmallVector<int64_t, 8> newMemrefShape;
newMemrefShape.append(memrefShape.begin(), memrefShape.end());
newMemrefShape.push_back(vectorType.getDimSize(0));
@@ -350,7 +350,7 @@ struct Strategy<TransferReadOp> {
/// result to the temporary buffer allocation.
static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
- auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
+ auto storeOp = cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
assert(storeOp && "Expected TransferReadOp result used by StoreOp");
return storeOp;
}
@@ -369,8 +369,8 @@ struct Strategy<TransferReadOp> {
/// Retrieve the indices of the current StoreOp that stores into the buffer.
static void getBufferIndices(TransferReadOp xferOp,
SmallVector<Value, 8> &indices) {
- auto storeOp = getStoreOp(xferOp);
- auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
+ memref::StoreOp storeOp = getStoreOp(xferOp);
+ ValueRange prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
indices.append(prevIndices.begin(), prevIndices.end());
}
@@ -413,8 +413,8 @@ struct Strategy<TransferReadOp> {
getXferIndices(b, xferOp, iv, xferIndices);
Location loc = xferOp.getLoc();
- auto bufferType = dyn_cast<ShapedType>(buffer.getType());
- auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
+ auto bufferType = cast<ShapedType>(buffer.getType());
+ auto vecType = cast<VectorType>(bufferType.getElementType());
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto newXferOp = b.create<vector::TransferReadOp>(
loc, vecType, xferOp.getSource(), xferIndices,
@@ -437,8 +437,8 @@ struct Strategy<TransferReadOp> {
storeIndices.push_back(iv);
Location loc = xferOp.getLoc();
- auto bufferType = dyn_cast<ShapedType>(buffer.getType());
- auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
+ auto bufferType = cast<ShapedType>(buffer.getType());
+ auto vecType = cast<VectorType>(bufferType.getElementType());
auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
@@ -476,7 +476,7 @@ struct Strategy<TransferWriteOp> {
static void getBufferIndices(TransferWriteOp xferOp,
SmallVector<Value, 8> &indices) {
auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
- auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
+ ValueRange prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
indices.append(prevIndices.begin(), prevIndices.end());
}
@@ -591,11 +591,11 @@ struct PrepareTransferReadConversion
if (checkPrepareXferOp(xferOp, options).failed())
return failure();
- auto buffers = allocBuffers(rewriter, xferOp);
+ BufferAllocs buffers = allocBuffers(rewriter, xferOp);
auto *newXfer = rewriter.clone(*xferOp.getOperation());
newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
if (xferOp.getMask()) {
- dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
+ cast<TransferReadOp>(newXfer).getMaskMutable().assign(
buffers.maskBuffer);
}
@@ -641,7 +641,7 @@ struct PrepareTransferWriteConversion
return failure();
Location loc = xferOp.getLoc();
- auto buffers = allocBuffers(rewriter, xferOp);
+ BufferAllocs buffers = allocBuffers(rewriter, xferOp);
rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
buffers.dataBuffer);
auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
@@ -707,8 +707,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() > 1 && vectorType.isScalable())
return failure();
- auto loc = printOp.getLoc();
- auto value = printOp.getSource();
+ Location loc = printOp.getLoc();
+ Value value = printOp.getSource();
if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
// Oddly sized integers are (somewhat) buggy on a lot of backends, so to
@@ -719,27 +719,27 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
intTy.getSignedness());
// arith can only take signless integers, so we must cast back and forth.
- auto signlessSourceVectorType =
+ VectorType signlessSource =
vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
- auto signlessTargetVectorType =
+ VectorType signlessTarget =
vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
- auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
- value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
+ VectorType target = vectorType.cloneWith({}, legalIntTy);
+ value = rewriter.create<vector::BitCastOp>(loc, signlessSource,
value);
- if (value.getType() != signlessTargetVectorType) {
+ if (value.getType() != signlessTarget) {
if (width == 1 || intTy.isUnsigned())
- value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
+ value = rewriter.create<arith::ExtUIOp>(loc, signlessTarget,
value);
else
- value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
+ value = rewriter.create<arith::ExtSIOp>(loc, signlessTarget,
value);
}
- value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
- vectorType = targetVectorType;
+ value = rewriter.create<vector::BitCastOp>(loc, target, value);
+ vectorType = target;
}
- auto scalableDimensions = vectorType.getScalableDims();
- auto shape = vectorType.getShape();
+ ArrayRef<bool> scalableDimensions = vectorType.getScalableDims();
+ ArrayRef<int64_t> shape = vectorType.getShape();
constexpr int64_t singletonShape[] = {1};
if (vectorType.getRank() == 0)
shape = singletonShape;
@@ -748,11 +748,11 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
// non-constant value (which can currently only be done via
// vector.extractelement for 1D vectors).
- auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
+ int flatLength = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
- auto flatVectorType =
+ VectorType flat =
VectorType::get({flatLength}, vectorType.getElementType());
- value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
+ value = rewriter.create<vector::ShapeCastOp>(loc, flat, value);
}
vector::PrintOp firstClose;
@@ -866,6 +866,37 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
this->setHasBoundedRewriteRecursion();
}
+ Value createNewTransferOp(OpBuilder &b, OpTy xferOp, Value castedDataBuffer,
+ Value iv, ValueRange loopState,
+ Value castedMaskBuffer, PatternRewriter &rewriter,
+ Location loc) const {
+ // Create new transfer op.
+ OpTy newXfer = Strategy<OpTy>::rewriteOp(b, this->options, xferOp,
+ castedDataBuffer, iv, loopState);
+
+ // If old transfer op has a mask: Set mask on new transfer op.
+ // Special case: If the mask of the old transfer op is 1D and
+   // the unpacked dim is not a broadcast, no mask is needed on the
+ // new transfer op.
+ if (xferOp.getMask() &&
+ (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() > 1)) {
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPoint(newXfer); // Insert load before newXfer.
+
+ SmallVector<Value, 8> loadIndices;
+ Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
+ // In case of broadcast: Use same indices to load from memref
+ // as before.
+ if (!xferOp.isBroadcastDim(0))
+ loadIndices.push_back(iv);
+
+ auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer, loadIndices);
+ rewriter.updateRootInPlace(
+ newXfer, [&]() { newXfer.getMaskMutable().assign(mask); });
+ }
+ return loopState.empty() ? Value() : newXfer->getResult(0);
+ }
+
LogicalResult matchAndRewrite(OpTy xferOp,
PatternRewriter &rewriter) const override {
if (!xferOp->hasAttr(kPassLabel))
@@ -873,9 +904,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// Find and cast data buffer. How the buffer can be found depends on OpTy.
ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
- auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
- auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
- auto castedDataType = unpackOneDim(dataBufferType);
+ Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+ auto dataBufferType = cast<MemRefType>(dataBuffer.getType());
+ FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
if (failed(castedDataType))
return failure();
@@ -885,8 +916,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// If the xferOp has a mask: Find and cast mask buffer.
Value castedMaskBuffer;
if (xferOp.getMask()) {
- auto maskBuffer = getMaskBuffer(xferOp);
- auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+ Value maskBuffer = getMaskBuffer(xferOp);
if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
// Do not unpack a dimension of the mask, if:
// * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +927,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
} else {
// It's safe to assume the mask buffer can be unpacked if the data
// buffer was unpacked.
- auto castedMaskType = *unpackOneDim(maskBufferType);
+ auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
+ MemRefType castedMaskType = *unpackOneDim(maskBufferType);
castedMaskBuffer =
locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
}
@@ -910,7 +941,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
auto step = locB.create<arith::ConstantIndexOp>(1);
// TransferWriteOps that operate on tensors return the modified tensor and
// require a loop state.
- auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
+ Value loopState = Strategy<OpTy>::initialLoopState(xferOp);
// Generate for loop.
auto result = locB.create<scf::ForOp>(
@@ -918,40 +949,14 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
Type stateType = loopState.empty() ? Type() : loopState[0].getType();
- auto result = generateInBoundsCheck(
+ Value result = generateInBoundsCheck(
b, xferOp, iv, unpackedDim(xferOp),
stateType ? TypeRange(stateType) : TypeRange(),
/*inBoundsCase=*/
[&](OpBuilder &b, Location loc) {
- // Create new transfer op.
- OpTy newXfer = Strategy<OpTy>::rewriteOp(
- b, this->options, xferOp, castedDataBuffer, iv, loopState);
-
- // If old transfer op has a mask: Set mask on new transfer op.
- // Special case: If the mask of the old transfer op is 1D and
- // the
- // unpacked dim is not a broadcast, no mask is
- // needed on the new transfer op.
- if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
- xferOp.getMaskType().getRank() > 1)) {
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPoint(newXfer); // Insert load before newXfer.
-
- SmallVector<Value, 8> loadIndices;
- Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
- // In case of broadcast: Use same indices to load from memref
- // as before.
- if (!xferOp.isBroadcastDim(0))
- loadIndices.push_back(iv);
-
- auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
- loadIndices);
- rewriter.updateRootInPlace(newXfer, [&]() {
- newXfer.getMaskMutable().assign(mask);
- });
- }
-
- return loopState.empty() ? Value() : newXfer->getResult(0);
+ return createNewTransferOp(b, xferOp, castedDataBuffer, iv,
+ loopState, castedMaskBuffer,
+ rewriter, loc);
},
/*outOfBoundsCase=*/
[&](OpBuilder &b, Location /*loc*/) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/75855
More information about the Mlir-commits
mailing list