[Mlir-commits] [mlir] [MLIR] Fix VectorEmulateNarrowType constant op mask bug (PR #116064)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Nov 14 08:20:40 PST 2024
================
@@ -75,83 +77,134 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
int numSrcElemsPerDest,
int numFrontPadElems = 0) {
- assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+ assert(numFrontPadElems < numSrcElemsPerDest &&
+ "numFrontPadElems must be less than numSrcElemsPerDest");
auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
numSrcElemsPerDest;
Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
// Finding the mask creation operation.
- while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+ while (maskOp &&
+ !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+ maskOp)) {
if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
maskOp = extractOp.getVector().getDefiningOp();
extractOps.push_back(extractOp);
}
}
- auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
- auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
- if (!createMaskOp && !constantMaskOp)
+
+ if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+ maskOp))
return failure();
// Computing the "compressed" mask. All the emulation logic (i.e. computing
// new mask index) only happens on the last dimension of the vectors.
- Operation *newMask = nullptr;
- SmallVector<int64_t> shape(
+ SmallVector<int64_t> maskShape(
cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
- shape.back() = numElements;
- auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
- if (createMaskOp) {
- OperandRange maskOperands = createMaskOp.getOperands();
- size_t numMaskOperands = maskOperands.size();
- AffineExpr s0;
- bindSymbols(rewriter.getContext(), s0);
- s0 = s0 + numSrcElemsPerDest - 1;
- s0 = s0.floorDiv(numSrcElemsPerDest);
- OpFoldResult origIndex =
- getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
- OpFoldResult maskIndex =
- affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
- SmallVector<Value> newMaskOperands(maskOperands.drop_back());
- newMaskOperands.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
- newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
- newMaskOperands);
- } else if (constantMaskOp) {
- ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
- size_t numMaskOperands = maskDimSizes.size();
- int64_t origIndex = maskDimSizes[numMaskOperands - 1];
- int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
- int64_t maskIndex =
- llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
-
- // TODO: we only want the mask between [startIndex, maskIndex] to be true,
- // the rest are false.
- if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
- return failure();
-
- SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
- newMaskDimSizes.push_back(maskIndex);
-
- if (numFrontPadElems == 0) {
- newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
- newMaskDimSizes);
- } else {
- SmallVector<bool> newMaskValues;
- for (int64_t i = 0; i < numElements; ++i)
- newMaskValues.push_back(i >= startIndex && i < maskIndex);
- auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
- newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
- }
- }
+ maskShape.back() = numElements;
+ auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
+ std::optional<Operation *> newMask =
+ TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
+ .Case<vector::CreateMaskOp>(
+ [&](auto createMaskOp) -> std::optional<Operation *> {
+ OperandRange maskOperands = createMaskOp.getOperands();
+ size_t numMaskOperands = maskOperands.size();
+ AffineExpr s0;
+ bindSymbols(rewriter.getContext(), s0);
+ s0 = s0 + numSrcElemsPerDest - 1;
+ s0 = s0.floorDiv(numSrcElemsPerDest);
+ OpFoldResult origIndex =
+ getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+ OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0, origIndex);
+ SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+ newMaskOperands.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+ return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+ newMaskOperands);
+ })
+ .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
+ -> std::optional<Operation *> {
+ ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+ size_t numMaskOperands = maskDimSizes.size();
+ int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+ int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+ int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
+ numSrcElemsPerDest);
+
+ // TODO: we only want the mask between [startIndex, maskIndex]
+ // to be true, the rest are false.
+ if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
+ return std::nullopt;
+
+ SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+ newMaskDimSizes.push_back(maskIndex);
+
+ if (numFrontPadElems == 0)
+ return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+ newMaskDimSizes);
+
+ SmallVector<bool> newMaskValues;
+ for (int64_t i = 0; i < numElements; ++i)
+ newMaskValues.push_back(i >= startIndex && i < maskIndex);
+ auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
+ return rewriter.create<arith::ConstantOp>(loc, newMaskType,
+ denseAttr);
+ })
+ .Case<arith::ConstantOp>([&](auto constantOp)
+ -> std::optional<Operation *> {
+ // TODO: Support multiple dimensions.
+ if (maskShape.size() != 1)
+ return std::nullopt;
+ // Rearrange the original mask values to cover the whole potential
+ // loading region. For example, in the case of using byte-size for
+ // emulation, given the following mask:
+ //
+ // %mask = [false, true, false, true, false, false]
+ //
+ // With front offset of 1, the mask will be padded 0s in the front
+ // and back so that:
+ // 1. It is aligned with the effective loading bits
+ // 2. Its length is multiple of `numSrcElemPerDest` (and the total
+ // coverage size is mulitiple of bytes). The new mask will be like
+ // this before compressing:
+ //
+ // %new_mask = [false, false, true, false, true, false, false,
+ // false]
+ auto denseAttr =
+ dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
+ if (!denseAttr)
+ return std::nullopt;
+ SmallVector<bool> maskValues(numFrontPadElems, false);
+ maskValues.append(denseAttr.template value_begin<bool>(),
+ denseAttr.template value_end<bool>());
----------------
banach-space wrote:
What's this meant to do?
https://github.com/llvm/llvm-project/pull/116064
More information about the Mlir-commits
mailing list