[Mlir-commits] [mlir] 0d9b439 - [mlir][vector] Use `DenseI64ArrayAttr` for constant_mask dim sizes (#100997)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 29 10:08:41 PDT 2024
Author: Benjamin Maxwell
Date: 2024-07-29T18:08:37+01:00
New Revision: 0d9b4394081df959b3752283ed9ca303759dda61
URL: https://github.com/llvm/llvm-project/commit/0d9b4394081df959b3752283ed9ca303759dda61
DIFF: https://github.com/llvm/llvm-project/commit/0d9b4394081df959b3752283ed9ca303759dda61.diff
LOG: [mlir][vector] Use `DenseI64ArrayAttr` for constant_mask dim sizes (#100997)
This removes a lot of boilerplate conversions to and from `IntegerAttr`s
and `int64_t`s. Other than that, this is an NFC (non-functional change).
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 39ad03c801140..3cdbd21874567 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2443,7 +2443,7 @@ def Vector_TypeCastOp :
def Vector_ConstantMaskOp :
Vector_Op<"constant_mask", [Pure]>,
- Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
+ Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a constant vector mask";
let description = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d297c40760cd8..669ae586e5786 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) {
// Inspect constant mask index. If the index exceeds the
// dimension size, all bits are set. If the index is zero
// or less, no bits are set.
- ArrayAttr masks = m.getMaskDimSizes();
+ ArrayRef<int64_t> masks = m.getMaskDimSizes();
auto shape = m.getType().getShape();
bool allTrue = true;
bool allFalse = true;
for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
- int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
- if (i < dimSize)
+ if (maskIdx < dimSize)
allTrue = false;
- if (i > 0)
+ if (maskIdx > 0)
allFalse = false;
}
if (allTrue)
@@ -3593,8 +3592,7 @@ class StridedSliceConstantMaskFolder final
if (extractStridedSliceOp.hasNonUnitStrides())
return failure();
// Gather constant mask dimension sizes.
- SmallVector<int64_t, 4> maskDimSizes;
- populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
+ ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
// Gather strided slice offsets and sizes.
SmallVector<int64_t, 4> sliceOffsets;
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
@@ -3625,7 +3623,7 @@ class StridedSliceConstantMaskFolder final
// region.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
- vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
+ sliceMaskDimSizes);
return success();
}
};
@@ -5410,21 +5408,19 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
if (constantMaskOp) {
- auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+ auto maskDimSizes = constantMaskOp.getMaskDimSizes();
auto numMaskOperands = maskDimSizes.size();
// Check every mask dim size to see whether it can be dropped
for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
--i) {
- if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
+ if (maskDimSizes[i] != 1)
return failure();
}
auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
- ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
-
rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
- newMaskOperandsAttr);
+ newMaskOperands);
return success();
}
@@ -5804,12 +5800,10 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
// ConstantMaskOp case.
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
- SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
- applyPermutationToVector(newMaskDimSizes, permutation);
+ auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
- transpOp, transpOp.getResultVectorType(),
- ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
+ transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
return success();
}
};
@@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() {
if (resultType.getRank() == 0) {
if (getMaskDimSizes().size() != 1)
return emitError("array attr must have length 1 for 0-D vectors");
- auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
+ auto dim = getMaskDimSizes()[0];
if (dim != 0 && dim != 1)
return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
return success();
@@ -5846,9 +5840,8 @@ LogicalResult ConstantMaskOp::verify() {
// result dimension size.
auto resultShape = resultType.getShape();
auto resultScalableDims = resultType.getScalableDims();
- SmallVector<int64_t, 4> maskDimSizes;
- for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
- int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
+ ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
+ for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
if (maskDimSize < 0 || maskDimSize > resultShape[index])
return emitOpError(
"array attr of size out of bounds of vector result dimension size");
@@ -5856,7 +5849,6 @@ LogicalResult ConstantMaskOp::verify() {
maskDimSize != resultShape[index])
return emitOpError(
"only supports 'none set' or 'all set' scalable dimensions");
- maskDimSizes.push_back(maskDimSize);
}
// Verify that if one mask dim size is zero, they all should be zero (because
// the mask region is a conjunction of each mask dimension interval).
@@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() {
// Check the corner case of 0-D vectors first.
if (resultType.getRank() == 0) {
assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
- return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
+ return getMaskDimSizes()[0] == 1;
}
- for (const auto [resultSize, intAttr] :
+ for (const auto [resultSize, maskDimSize] :
llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
- int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
if (maskDimSize < resultSize)
return false;
}
@@ -6007,9 +5998,8 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
}
// Replace 'createMaskOp' with ConstantMaskOp.
- rewriter.replaceOpWithNewOp<ConstantMaskOp>(
- createMaskOp, retTy,
- vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
+ rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
+ maskDimSizes);
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index dfeb7bc53adad..bfc05c71f5340 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -111,7 +111,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
if (rank == 0) {
assert(dimSizes.size() == 1 &&
"Expected exactly one dim size for a 0-D vector");
- bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
+ bool value = dimSizes.front() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
@@ -119,7 +119,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
return success();
}
- int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
+ int64_t trueDimSize = dimSizes.front();
if (rank == 1) {
if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
@@ -147,7 +147,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
VectorType lowType = VectorType::Builder(dstType).dropDim(0);
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
- loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
+ loc, lowType, dimSizes.drop_front());
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < trueDimSize; d++)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 7ed3dea42b771..42ac717b44c4b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -550,9 +550,7 @@ struct CastAwayConstantMaskLeadingOneDim
return failure();
int64_t dropDim = oldType.getRank() - newType.getRank();
- SmallVector<int64_t> dimSizes;
- for (auto attr : mask.getMaskDimSizes())
- dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+ ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();
// If any of the dropped unit dims has a size of `0`, the entire mask is a
// zero mask, else the unit dim has no effect on the mask.
@@ -563,7 +561,7 @@ struct CastAwayConstantMaskLeadingOneDim
newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
auto newMask = rewriter.create<vector::ConstantMaskOp>(
- mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
+ mask.getLoc(), newType, newDimSizes);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ac2a4d3abcc68..d3296ee38c249 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -83,17 +83,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
newMaskOperands);
} else if (constantMaskOp) {
- ArrayRef<Attribute> maskDimSizes =
- constantMaskOp.getMaskDimSizes().getValue();
+ ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
size_t numMaskOperands = maskDimSizes.size();
- auto origIndex =
- cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
- IntegerAttr maskIndexAttr =
- rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
- SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
- newMaskDimSizes.push_back(maskIndexAttr);
- newMask = rewriter.create<vector::ConstantMaskOp>(
- loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+ int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+ int64_t maskIndex = (origIndex + scale - 1) / scale;
+ SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+ newMaskDimSizes.push_back(maskIndex);
+ newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+ newMaskDimSizes);
}
while (!extractOps.empty()) {
More information about the Mlir-commits
mailing list