[Mlir-commits] [mlir] f22a573 - [mlir][vector] Clean up use of `llvm::zip` in `VectorOps.cpp`
Jakub Kuderski
llvmlistbot at llvm.org
Wed Nov 30 14:13:14 PST 2022
Author: Jakub Kuderski
Date: 2022-11-30T17:13:04-05:00
New Revision: f22a573b2b8afaee88001168eeeb70c77f28a03e
URL: https://github.com/llvm/llvm-project/commit/f22a573b2b8afaee88001168eeeb70c77f28a03e
DIFF: https://github.com/llvm/llvm-project/commit/f22a573b2b8afaee88001168eeeb70c77f28a03e.diff
LOG: [mlir][vector] Clean up use of `llvm::zip` in `VectorOps.cpp`
- Use `zip_equal` where iteratees are supposed to have equal length.
- Use `zip_first` where the first iteratee is supposed to be the
shortest.
- Use `llvm::enumerate` instead of calculating index manually.
- Use structured bindings to unpack tuples where appropriate.
- Fix a bug in a comparison in `intersectsWhereNonNegative`: it checked the first element for negativity twice instead of checking both elements.
Both `zip_first` (after D138858) and `zip_equal` (introduced in D138865)
assert iteratee lengths, which allows us to convey more precisely
whether we want to iterate over the common prefix (`zip`) or expect all
lengths to be the same (`zip_equal`).
Reviewed By: dcaballe, antiagainst
Differential Revision: https://reviews.llvm.org/D139022
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 328601c7f94f..f8c10bdb1411 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -87,10 +87,9 @@ static MaskFormat getMaskFormat(Value mask) {
auto shape = m.getType().getShape();
bool allTrue = true;
bool allFalse = true;
- for (auto pair : llvm::zip(masks, shape)) {
- int64_t i = std::get<0>(pair).cast<IntegerAttr>().getInt();
- int64_t u = std::get<1>(pair);
- if (i < u)
+ for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
+ int64_t i = maskIdx.cast<IntegerAttr>().getInt();
+ if (i < dimSize)
allTrue = false;
if (i > 0)
allFalse = false;
@@ -1178,10 +1177,10 @@ class ExtractFromInsertTransposeChainState {
/// Comparison is on the common prefix (i.e. zip).
template <typename ContainerA, typename ContainerB>
bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
- for (auto it : llvm::zip(a, b)) {
- if (std::get<0>(it) < 0 || std::get<0>(it) < 0)
+ for (auto [elemA, elemB] : llvm::zip(a, b)) {
+ if (elemA < 0 || elemB < 0)
continue;
- if (std::get<0>(it) != std::get<1>(it))
+ if (elemA != elemB)
return false;
}
return true;
@@ -1729,7 +1728,8 @@ computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
int64_t rankDiff = dstShape.size() - srcShape.size();
int64_t dstDim = rankDiff;
llvm::SetVector<int64_t> res;
- for (auto [s1, s2] : llvm::zip(srcShape, dstShape.drop_front(rankDiff))) {
+ for (auto [s1, s2] :
+ llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
if (s1 != s2) {
assert(s1 == 1 && "expected dim-1 broadcasting");
res.insert(dstDim);
@@ -2384,18 +2384,16 @@ static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
ArrayRef<int64_t> shape, StringRef attrName,
bool halfOpen = true, int64_t min = 0) {
- assert(arrayAttr.size() <= shape.size());
- unsigned index = 0;
- for (auto it : llvm::zip(arrayAttr, shape)) {
- auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
- auto max = std::get<1>(it);
+ for (auto [index, attrDimPair] :
+ llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
+ int64_t val = std::get<0>(attrDimPair).cast<IntegerAttr>().getInt();
+ int64_t max = std::get<1>(attrDimPair);
if (!halfOpen)
max += 1;
if (val < min || val >= max)
return op.emitOpError("expected ")
<< attrName << " dimension " << index << " to be confined to ["
<< min << ", " << max << ")";
- ++index;
}
return success();
}
@@ -2410,8 +2408,8 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
bool halfOpen = true, int64_t min = 1) {
assert(arrayAttr1.size() <= shape.size());
assert(arrayAttr2.size() <= shape.size());
- unsigned index = 0;
- for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
+ for (auto [index, it] :
+ llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
auto max = std::get<2>(it);
@@ -2421,7 +2419,6 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
return op.emitOpError("expected sum(")
<< attrName1 << ", " << attrName2 << ") dimension " << index
<< " to be confined to [" << min << ", " << max << ")";
- ++index;
}
return success();
}
@@ -2962,11 +2959,9 @@ class StridedSliceConstantMaskFolder final
// Compute slice of vector mask region.
SmallVector<int64_t, 4> sliceMaskDimSizes;
- assert(sliceOffsets.size() == maskDimSizes.size());
- for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
- int64_t maskDimSize = std::get<0>(it);
- int64_t sliceOffset = std::get<1>(it);
- int64_t sliceSize = std::get<2>(it);
+ sliceMaskDimSizes.reserve(maskDimSizes.size());
+ for (auto [maskDimSize, sliceOffset, sliceSize] :
+ llvm::zip_equal(maskDimSizes, sliceOffsets, sliceSizes)) {
int64_t sliceMaskDimSize = std::max(
static_cast<int64_t>(0),
std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
@@ -4236,9 +4231,9 @@ struct SwapExtractSliceOfTransferWrite
}
// Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes
diff er.
- for (const auto &it :
- llvm::zip(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
- if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it))) {
+ for (auto [insertSize, extractSize] :
+ llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
+ if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
return rewriter.notifyMatchFailure(
insertOp, "InsertSliceOp and ExtractSliceOp sizes
diff er");
}
@@ -5208,10 +5203,10 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
// Gather constant mask dimension sizes.
SmallVector<int64_t, 4> maskDimSizes;
- for (auto it : llvm::zip(createMaskOp.operands(),
- createMaskOp.getType().getShape())) {
- auto *defOp = std::get<0>(it).getDefiningOp();
- int64_t maxDimSize = std::get<1>(it);
+ maskDimSizes.reserve(createMaskOp->getNumOperands());
+ for (auto [operand, maxDimSize] : llvm::zip_equal(
+ createMaskOp.operands(), createMaskOp.getType().getShape())) {
+ Operation *defOp = operand.getDefiningOp();
int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
dimSize = std::min(dimSize, maxDimSize);
// If one of dim sizes is zero, set all dims to zero.
@@ -5438,10 +5433,7 @@ LogicalResult ScanOp::verify() {
if (i != reductionDim)
expectedShape.push_back(srcShape[i]);
}
- if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
- [](std::tuple<int64_t, int64_t> s) {
- return std::get<0>(s) != std::get<1>(s);
- })) {
+ if (!llvm::equal(initialValueShapes, expectedShape)) {
return emitOpError("incompatible input/initial value shapes");
}
@@ -5588,8 +5580,8 @@ void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
OpBuilder::InsertionGuard guard(builder);
Region *warpRegion = result.addRegion();
Block *block = builder.createBlock(warpRegion);
- for (auto it : llvm::zip(blockArgTypes, args))
- block->addArgument(std::get<0>(it), std::get<1>(it).getLoc());
+ for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
+ block->addArgument(type, arg.getLoc());
}
/// Helper check if the distributed vector type is consistent with the expanded
@@ -5636,16 +5628,16 @@ LogicalResult WarpExecuteOnLane0Op::verify() {
return emitOpError(
"expected same number of yield operands and return values.");
int64_t warpSize = getWarpSize();
- for (auto it : llvm::zip(getWarpRegion().getArguments(), getArgs())) {
- if (failed(verifyDistributedType(std::get<0>(it).getType(),
- std::get<1>(it).getType(), warpSize,
- getOperation())))
+ for (auto [regionArg, arg] :
+ llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
+ if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
+ warpSize, getOperation())))
return failure();
}
- for (auto it : llvm::zip(yield.getOperands(), getResults())) {
- if (failed(verifyDistributedType(std::get<0>(it).getType(),
- std::get<1>(it).getType(), warpSize,
- getOperation())))
+ for (auto [yieldOperand, result] :
+ llvm::zip_equal(yield.getOperands(), getResults())) {
+ if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
+ warpSize, getOperation())))
return failure();
}
return success();
More information about the Mlir-commits
mailing list