[Mlir-commits] [mlir] [mlir][vector] Refactor parts of `VectorToSCF.cpp` (NFC) (PR #75855)
Rik Huijzer
llvmlistbot at llvm.org
Mon Dec 18 13:13:34 PST 2023
https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/75855
From 898b3b05960cecc32c1c4e15be4278c4f01ee22a Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 18 Dec 2023 21:57:08 +0100
Subject: [PATCH 1/3] [mlir][vector] Refactor parts of `VectorToSCF.cpp`
---
.../Conversion/VectorToSCF/VectorToSCF.cpp | 135 +++++++++---------
1 file changed, 70 insertions(+), 65 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e9fedfe3..39bf5fb07f1a70 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -321,7 +321,7 @@ static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
// It may be possible to support these in future by using dynamic memref dims.
if (vectorType.getScalableDims().front())
return failure();
- auto memrefShape = type.getShape();
+ ArrayRef<int64_t> memrefShape = type.getShape();
SmallVector<int64_t, 8> newMemrefShape;
newMemrefShape.append(memrefShape.begin(), memrefShape.end());
newMemrefShape.push_back(vectorType.getDimSize(0));
@@ -350,7 +350,7 @@ struct Strategy<TransferReadOp> {
/// result to the temporary buffer allocation.
static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
- auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
+ auto storeOp = cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
assert(storeOp && "Expected TransferReadOp result used by StoreOp");
return storeOp;
}
@@ -369,8 +369,8 @@ struct Strategy<TransferReadOp> {
/// Retrieve the indices of the current StoreOp that stores into the buffer.
static void getBufferIndices(TransferReadOp xferOp,
SmallVector<Value, 8> &indices) {
- auto storeOp = getStoreOp(xferOp);
- auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
+ memref::StoreOp storeOp = getStoreOp(xferOp);
+ ValueRange prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
indices.append(prevIndices.begin(), prevIndices.end());
}
@@ -413,8 +413,8 @@ struct Strategy<TransferReadOp> {
getXferIndices(b, xferOp, iv, xferIndices);
Location loc = xferOp.getLoc();
- auto bufferType = dyn_cast<ShapedType>(buffer.getType());
- auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
+ auto bufferType = cast<ShapedType>(buffer.getType());
+ auto vecType = cast<VectorType>(bufferType.getElementType());
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
auto newXferOp = b.create<vector::TransferReadOp>(
loc, vecType, xferOp.getSource(), xferIndices,
@@ -437,8 +437,8 @@ struct Strategy<TransferReadOp> {
storeIndices.push_back(iv);
Location loc = xferOp.getLoc();
- auto bufferType = dyn_cast<ShapedType>(buffer.getType());
- auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
+ auto bufferType = cast<ShapedType>(buffer.getType());
+ auto vecType = cast<VectorType>(bufferType.getElementType());
auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
@@ -476,7 +476,7 @@ struct Strategy<TransferWriteOp> {
static void getBufferIndices(TransferWriteOp xferOp,
SmallVector<Value, 8> &indices) {
auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
- auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
+ ValueRange prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
indices.append(prevIndices.begin(), prevIndices.end());
}
@@ -591,11 +591,11 @@ struct PrepareTransferReadConversion
if (checkPrepareXferOp(xferOp, options).failed())
return failure();
- auto buffers = allocBuffers(rewriter, xferOp);
+ BufferAllocs buffers = allocBuffers(rewriter, xferOp);
auto *newXfer = rewriter.clone(*xferOp.getOperation());
newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
if (xferOp.getMask()) {
- dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
+ cast<TransferReadOp>(newXfer).getMaskMutable().assign(
buffers.maskBuffer);
}
@@ -641,7 +641,7 @@ struct PrepareTransferWriteConversion
return failure();
Location loc = xferOp.getLoc();
- auto buffers = allocBuffers(rewriter, xferOp);
+ BufferAllocs buffers = allocBuffers(rewriter, xferOp);
rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
buffers.dataBuffer);
auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
@@ -707,8 +707,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() > 1 && vectorType.isScalable())
return failure();
- auto loc = printOp.getLoc();
- auto value = printOp.getSource();
+ Location loc = printOp.getLoc();
+ Value value = printOp.getSource();
if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
// Oddly sized integers are (somewhat) buggy on a lot of backends, so to
@@ -719,27 +719,27 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
intTy.getSignedness());
// arith can only take signless integers, so we must cast back and forth.
- auto signlessSourceVectorType =
+ VectorType signlessSource =
vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
- auto signlessTargetVectorType =
+ VectorType signlessTarget =
vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
- auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
- value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
+ VectorType target = vectorType.cloneWith({}, legalIntTy);
+ value = rewriter.create<vector::BitCastOp>(loc, signlessSource,
value);
- if (value.getType() != signlessTargetVectorType) {
+ if (value.getType() != signlessTarget) {
if (width == 1 || intTy.isUnsigned())
- value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
+ value = rewriter.create<arith::ExtUIOp>(loc, signlessTarget,
value);
else
- value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
+ value = rewriter.create<arith::ExtSIOp>(loc, signlessTarget,
value);
}
- value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
- vectorType = targetVectorType;
+ value = rewriter.create<vector::BitCastOp>(loc, target, value);
+ vectorType = target;
}
- auto scalableDimensions = vectorType.getScalableDims();
- auto shape = vectorType.getShape();
+ ArrayRef<bool> scalableDimensions = vectorType.getScalableDims();
+ ArrayRef<int64_t> shape = vectorType.getShape();
constexpr int64_t singletonShape[] = {1};
if (vectorType.getRank() == 0)
shape = singletonShape;
@@ -748,11 +748,11 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
// non-constant value (which can currently only be done via
// vector.extractelement for 1D vectors).
- auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
+ int flatLength = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
- auto flatVectorType =
+ VectorType flat =
VectorType::get({flatLength}, vectorType.getElementType());
- value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
+ value = rewriter.create<vector::ShapeCastOp>(loc, flat, value);
}
vector::PrintOp firstClose;
@@ -866,6 +866,37 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
this->setHasBoundedRewriteRecursion();
}
+ Value createNewTransferOp(OpBuilder &b, OpTy xferOp, Value castedDataBuffer,
+ Value iv, ValueRange loopState,
+ Value castedMaskBuffer, PatternRewriter &rewriter,
+ Location loc) const {
+ // Create new transfer op.
+ OpTy newXfer = Strategy<OpTy>::rewriteOp(b, this->options, xferOp,
+ castedDataBuffer, iv, loopState);
+
+ // If old transfer op has a mask: Set mask on new transfer op.
+ // Special case: If the mask of the old transfer op is 1D and
+ // the unpacked dim is not a broadcast, no mask is needed on the
+ // new transfer op.
+ if (xferOp.getMask() &&
+ (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() > 1)) {
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPoint(newXfer); // Insert load before newXfer.
+
+ SmallVector<Value, 8> loadIndices;
+ Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
+ // In case of broadcast: Use same indices to load from memref
+ // as before.
+ if (!xferOp.isBroadcastDim(0))
+ loadIndices.push_back(iv);
+
+ auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer, loadIndices);
+ rewriter.updateRootInPlace(
+ newXfer, [&]() { newXfer.getMaskMutable().assign(mask); });
+ }
+ return loopState.empty() ? Value() : newXfer->getResult(0);
+ }
+
LogicalResult matchAndRewrite(OpTy xferOp,
PatternRewriter &rewriter) const override {
if (!xferOp->hasAttr(kPassLabel))
@@ -873,9 +904,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// Find and cast data buffer. How the buffer can be found depends on OpTy.
ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
- auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
- auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
- auto castedDataType = unpackOneDim(dataBufferType);
+ Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+ auto dataBufferType = cast<MemRefType>(dataBuffer.getType());
+ FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
if (failed(castedDataType))
return failure();
@@ -885,8 +916,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// If the xferOp has a mask: Find and cast mask buffer.
Value castedMaskBuffer;
if (xferOp.getMask()) {
- auto maskBuffer = getMaskBuffer(xferOp);
- auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+ Value maskBuffer = getMaskBuffer(xferOp);
if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
// Do not unpack a dimension of the mask, if:
// * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +927,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
} else {
// It's safe to assume the mask buffer can be unpacked if the data
// buffer was unpacked.
- auto castedMaskType = *unpackOneDim(maskBufferType);
+ auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
+ MemRefType castedMaskType = *unpackOneDim(maskBufferType);
castedMaskBuffer =
locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
}
@@ -910,7 +941,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
auto step = locB.create<arith::ConstantIndexOp>(1);
// TransferWriteOps that operate on tensors return the modified tensor and
// require a loop state.
- auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
+ Value loopState = Strategy<OpTy>::initialLoopState(xferOp);
// Generate for loop.
auto result = locB.create<scf::ForOp>(
@@ -918,40 +949,14 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
Type stateType = loopState.empty() ? Type() : loopState[0].getType();
- auto result = generateInBoundsCheck(
+ Value result = generateInBoundsCheck(
b, xferOp, iv, unpackedDim(xferOp),
stateType ? TypeRange(stateType) : TypeRange(),
/*inBoundsCase=*/
[&](OpBuilder &b, Location loc) {
- // Create new transfer op.
- OpTy newXfer = Strategy<OpTy>::rewriteOp(
- b, this->options, xferOp, castedDataBuffer, iv, loopState);
-
- // If old transfer op has a mask: Set mask on new transfer op.
- // Special case: If the mask of the old transfer op is 1D and
- // the
- // unpacked dim is not a broadcast, no mask is
- // needed on the new transfer op.
- if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
- xferOp.getMaskType().getRank() > 1)) {
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPoint(newXfer); // Insert load before newXfer.
-
- SmallVector<Value, 8> loadIndices;
- Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
- // In case of broadcast: Use same indices to load from memref
- // as before.
- if (!xferOp.isBroadcastDim(0))
- loadIndices.push_back(iv);
-
- auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
- loadIndices);
- rewriter.updateRootInPlace(newXfer, [&]() {
- newXfer.getMaskMutable().assign(mask);
- });
- }
-
- return loopState.empty() ? Value() : newXfer->getResult(0);
+ return createNewTransferOp(b, xferOp, castedDataBuffer, iv,
+ loopState, castedMaskBuffer,
+ rewriter, loc);
},
/*outOfBoundsCase=*/
[&](OpBuilder &b, Location /*loc*/) {
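A note for context on the dyn_cast-to-cast changes in this patch: LLVM's cast<> asserts that the cast succeeds and never returns null, while dyn_cast<> returns null on a type mismatch and must be null-checked. The patch switches to cast<> exactly where success is already guaranteed. A minimal standalone sketch of that contract (the Shape/Circle hierarchy is hypothetical; only llvm::cast and llvm::dyn_cast are real APIs):

    #include "llvm/Support/Casting.h"

    // Hypothetical LLVM-style RTTI hierarchy, following the classof pattern
    // from the LLVM Programmer's Manual.
    struct Shape {
      enum Kind { K_Circle, K_Square };
      Kind kind;
      Shape(Kind k) : kind(k) {}
    };
    struct Circle : Shape {
      Circle() : Shape(K_Circle) {}
      static bool classof(const Shape *s) { return s->kind == K_Circle; }
    };

    void example(Shape *s) {
      // dyn_cast<> returns nullptr on mismatch; use it when the cast may fail.
      if (auto *c = llvm::dyn_cast<Circle>(s))
        (void)c; // safe to use as a Circle here
      // cast<> asserts success (in debug builds) and never returns null; use
      // it when the type is already guaranteed, as this patch does after the
      // surrounding checks.
      Circle *c2 = llvm::cast<Circle>(s);
      (void)c2;
    }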
From 91032d429262d0c1c5b665d58002fb72d239c653 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 18 Dec 2023 22:08:04 +0100
Subject: [PATCH 2/3] Remove outdated comment
---
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 39bf5fb07f1a70..87e22703982e93 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -870,7 +870,6 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
Value iv, ValueRange loopState,
Value castedMaskBuffer, PatternRewriter &rewriter,
Location loc) const {
- // Create new transfer op.
OpTy newXfer = Strategy<OpTy>::rewriteOp(b, this->options, xferOp,
castedDataBuffer, iv, loopState);
From c66dd8f5cca9357a70ebe23a746c6c3a39129858 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Mon, 18 Dec 2023 22:13:22 +0100
Subject: [PATCH 3/3] Apply `clang-format`
---
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 14 +++++---------
1 file changed, 5 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 87e22703982e93..d6509844215337 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -595,8 +595,7 @@ struct PrepareTransferReadConversion
auto *newXfer = rewriter.clone(*xferOp.getOperation());
newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
if (xferOp.getMask()) {
- cast<TransferReadOp>(newXfer).getMaskMutable().assign(
- buffers.maskBuffer);
+ cast<TransferReadOp>(newXfer).getMaskMutable().assign(buffers.maskBuffer);
}
Location loc = xferOp.getLoc();
@@ -724,15 +723,12 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
VectorType signlessTarget =
vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
VectorType target = vectorType.cloneWith({}, legalIntTy);
- value = rewriter.create<vector::BitCastOp>(loc, signlessSource,
- value);
+ value = rewriter.create<vector::BitCastOp>(loc, signlessSource, value);
if (value.getType() != signlessTarget) {
if (width == 1 || intTy.isUnsigned())
- value = rewriter.create<arith::ExtUIOp>(loc, signlessTarget,
- value);
+ value = rewriter.create<arith::ExtUIOp>(loc, signlessTarget, value);
else
- value = rewriter.create<arith::ExtSIOp>(loc, signlessTarget,
- value);
+ value = rewriter.create<arith::ExtSIOp>(loc, signlessTarget, value);
}
value = rewriter.create<vector::BitCastOp>(loc, target, value);
vectorType = target;
@@ -749,7 +745,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
// non-constant value (which can currently only be done via
// vector.extractelement for 1D vectors).
int flatLength = std::accumulate(shape.begin(), shape.end(), 1,
- std::multiplies<int64_t>());
+ std::multiplies<int64_t>());
VectorType flat =
VectorType::get({flatLength}, vectorType.getElementType());
value = rewriter.create<vector::ShapeCastOp>(loc, flat, value);
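Similarly, on the auto-to-explicit-type changes: ArrayRef<int64_t> and ValueRange are cheap, non-owning views, so spelling the type out documents that no copy is made and that the view must not outlive its backing storage. A small sketch under that assumption (the function name is illustrative, not from the patch), mirroring how unpackOneDim builds newMemrefShape:

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"

    // Copy a viewed shape into owned storage and append one dimension -- the
    // same pattern unpackOneDim uses to build newMemrefShape.
    llvm::SmallVector<int64_t, 8> appendDim(llvm::ArrayRef<int64_t> shape,
                                            int64_t dim) {
      llvm::SmallVector<int64_t, 8> newShape(shape.begin(), shape.end());
      newShape.push_back(dim);
      return newShape;
    }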