[Mlir-commits] [mlir] [MLIR] Fix VectorEmulateNarrowType constant op mask bug (PR #116064)
llvmlistbot at llvm.org
Thu Nov 14 19:02:51 PST 2024
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/116064
From 934405556ca50cd3f167c10b00ec1769e57efd75 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 12 Nov 2024 22:43:58 -0500
Subject: [PATCH] [MLIR] Fix VectorEmulateNarrowType constant op mask bug
This commit adds support for handling mask constants generated by the
`arith.constant` op in the `VectorEmulateNarrowType` pattern. Previously, this
pattern would not match due to the lack of mask constant handling in
`getCompressedMaskOp`.
The changes include:
1. Updating `getCompressedMaskOp` to recognize and handle `arith.constant` ops as
mask value sources.
2. Handling cases where the mask is not aligned with the emulated load width.
The compressed mask is adjusted to account for the offset.
Limitations:
- The `arith.constant` op path only supports 1-dimensional constant mask values.
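For readers skimming the diff, here is a minimal standalone sketch of the padding-and-OR-reduction that the new `arith.constant` handling performs. It is plain C++, deliberately independent of the MLIR pattern itself; the helper name `compressMask` is illustrative only.

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // Pad the mask with `numFrontPadElems` leading 0s, extend it to a multiple
    // of `numSrcElemsPerDest`, then OR-reduce every group of
    // `numSrcElemsPerDest` source bits into one destination mask bit.
    static std::vector<bool> compressMask(const std::vector<bool> &mask,
                                          int numSrcElemsPerDest,
                                          int numFrontPadElems = 0) {
      assert(numFrontPadElems < numSrcElemsPerDest);
      std::vector<bool> padded(numFrontPadElems, false);
      padded.insert(padded.end(), mask.begin(), mask.end());
      size_t numDestElems =
          (padded.size() + numSrcElemsPerDest - 1) / numSrcElemsPerDest;
      padded.resize(numDestElems * numSrcElemsPerDest, false);

      std::vector<bool> compressed;
      for (size_t i = 0; i < padded.size(); i += numSrcElemsPerDest) {
        bool combined = false;
        for (int j = 0; j < numSrcElemsPerDest; ++j)
          combined = combined || padded[i + j];
        compressed.push_back(combined);
      }
      return compressed;
    }

    int main() {
      // Mirrors the i2-in-i8 example in the patch comments: 4 i2 elements per
      // byte, a front offset of 1, and the original mask [0, 1, 0, 1, 0, 0].
      std::vector<bool> mask = {false, true, false, true, false, false};
      for (bool b : compressMask(mask, /*numSrcElemsPerDest=*/4,
                                 /*numFrontPadElems=*/1))
        std::printf("%d ", b ? 1 : 0); // prints: 1 1
      std::printf("\n");
      return 0;
    }

With a zero front offset the padding step is a no-op and the loop reduces to the straightforward per-group OR exercised by the aligned tests below.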
Resolves: #115742
Signed-off-by: Alan Li <me at alanli.org>
---
.../Transforms/VectorEmulateNarrowType.cpp | 165 +++++++++++-------
.../vector-emulate-narrow-type-unaligned.mlir | 38 ++++
.../Vector/vector-emulate-narrow-type.mlir | 50 ++++++
3 files changed, 194 insertions(+), 59 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index e5f2a847994aee..7006dc0be5904f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -75,83 +75,132 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
int numSrcElemsPerDest,
int numFrontPadElems = 0) {
- assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+ assert(numFrontPadElems < numSrcElemsPerDest &&
+ "numFrontPadElems must be less than numSrcElemsPerDest");
- auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
+ auto numDestElems = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
numSrcElemsPerDest;
Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
+ // TODO: add support for `vector.splat`.
// Finding the mask creation operation.
- while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+ while (maskOp &&
+ !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+ maskOp)) {
if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
maskOp = extractOp.getVector().getDefiningOp();
extractOps.push_back(extractOp);
}
}
- auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
- auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
- if (!createMaskOp && !constantMaskOp)
+
+ if (!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+ maskOp))
return failure();
// Computing the "compressed" mask. All the emulation logic (i.e. computing
// new mask index) only happens on the last dimension of the vectors.
- Operation *newMask = nullptr;
- SmallVector<int64_t> shape(
+ SmallVector<int64_t> maskShape(
cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
- shape.back() = numElements;
- auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
- if (createMaskOp) {
- OperandRange maskOperands = createMaskOp.getOperands();
- size_t numMaskOperands = maskOperands.size();
- AffineExpr s0;
- bindSymbols(rewriter.getContext(), s0);
- s0 = s0 + numSrcElemsPerDest - 1;
- s0 = s0.floorDiv(numSrcElemsPerDest);
- OpFoldResult origIndex =
- getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
- OpFoldResult maskIndex =
- affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
- SmallVector<Value> newMaskOperands(maskOperands.drop_back());
- newMaskOperands.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
- newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
- newMaskOperands);
- } else if (constantMaskOp) {
- ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
- size_t numMaskOperands = maskDimSizes.size();
- int64_t origIndex = maskDimSizes[numMaskOperands - 1];
- int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
- int64_t maskIndex =
- llvm::divideCeil(numFrontPadElems + origIndex, numSrcElemsPerDest);
-
- // TODO: we only want the mask between [startIndex, maskIndex] to be true,
- // the rest are false.
- if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
- return failure();
-
- SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
- newMaskDimSizes.push_back(maskIndex);
-
- if (numFrontPadElems == 0) {
- newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
- newMaskDimSizes);
- } else {
- SmallVector<bool> newMaskValues;
- for (int64_t i = 0; i < numElements; ++i)
- newMaskValues.push_back(i >= startIndex && i < maskIndex);
- auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
- newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
- }
- }
+ maskShape.back() = numDestElems;
+ auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
+ std::optional<Operation *> newMask =
+ TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
+ .Case<vector::CreateMaskOp>(
+ [&](auto createMaskOp) -> std::optional<Operation *> {
+ OperandRange maskOperands = createMaskOp.getOperands();
+ size_t numMaskOperands = maskOperands.size();
+ AffineExpr s0;
+ bindSymbols(rewriter.getContext(), s0);
+ s0 = s0 + numSrcElemsPerDest - 1;
+ s0 = s0.floorDiv(numSrcElemsPerDest);
+ OpFoldResult origIndex =
+ getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+ OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0, origIndex);
+ SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+ newMaskOperands.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+ return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+ newMaskOperands);
+ })
+ .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
+ -> std::optional<Operation *> {
+ ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
+ size_t numMaskOperands = maskDimSizes.size();
+ int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+ int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
+ int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
+ numSrcElemsPerDest);
+
+ // TODO: we only want the mask elements in [startIndex, maskIndex)
+ // to be true; the rest should be false.
+ if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
+ return std::nullopt;
+
+ SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+ newMaskDimSizes.push_back(maskIndex);
+
+ if (numFrontPadElems == 0)
+ return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+ newMaskDimSizes);
+
+ SmallVector<bool> newMaskValues;
+ for (int64_t i = 0; i < numDestElems; ++i)
+ newMaskValues.push_back(i >= startIndex && i < maskIndex);
+ auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
+ return rewriter.create<arith::ConstantOp>(loc, newMaskType,
+ denseAttr);
+ })
+ .Case<arith::ConstantOp>([&](auto constantOp)
+ -> std::optional<Operation *> {
+ // TODO: Support multiple dimensions.
+ if (maskShape.size() != 1)
+ return std::nullopt;
+ // Rearrange the original mask values to cover the whole potential
+ // loading region. For example, in the case of using byte-size for
+ // emulation, given the following mask:
+ //
+ // %mask = [0, 1, 0, 1, 0, 0]
+ //
+ // With a front offset of 1, the mask will be padded with 0s in the
+ // front and back so that:
+ // 1. It is aligned with the effective loading bits.
+ // 2. Its length is a multiple of `numSrcElemsPerDest` (and the total
+ // coverage size is a multiple of bytes). The new mask will look like
+ // this before compressing:
+ //
+ // %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
+ auto denseAttr =
+ cast<DenseIntElementsAttr>(constantOp.getValue());
+ SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
+ paddedMaskValues.append(denseAttr.template value_begin<bool>(),
+ denseAttr.template value_end<bool>());
+ paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);
+
+ // Compressing by combining every `numSrcElemsPerDest` elements:
+ SmallVector<bool> compressedMaskValues;
+ for (size_t i = 0; i < paddedMaskValues.size(); i += numSrcElemsPerDest) {
+ bool combinedValue = false;
+ for (int j = 0; j < numSrcElemsPerDest; ++j) {
+ combinedValue |= paddedMaskValues[i + j];
+ }
+ compressedMaskValues.push_back(combinedValue);
+ }
+ return rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
+ });
+
+ if (!newMask)
+ return failure();
while (!extractOps.empty()) {
newMask = rewriter.create<vector::ExtractOp>(
- loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+ loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
extractOps.pop_back();
}
- return newMask;
+ return *newMask;
}
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
@@ -185,12 +234,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
/// `vector.insert_strided_slice`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
Value src, Value dest, int64_t offset) {
- auto srcType = cast<VectorType>(src.getType());
- auto destType = cast<VectorType>(dest.getType());
+ [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+ [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
"expected source and dest to be vector type");
- (void)srcType;
- (void)destType;
auto offsets = rewriter.getI64ArrayAttr({offset});
auto strides = rewriter.getI64ArrayAttr({1});
return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..b1a0d4f924f3cf 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -249,3 +249,41 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+ %0 = memref.alloc() : memref<3x5xi2>
+ %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+ memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+ return %1 : vector<5xi2>
+}
+
+// CHECK: func @vector_maskedload_i2_constant_mask_unaligned(
+// CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+
+// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
+
+// Emulated masked load from alloc:
+// CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
+// CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
+
+// Select from emulated loaded vector and passthru vector:
+// TODO: fold this part if possible.
+// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[MASK_PADDED]], %[[MASKLOAD_DOWNCAST]], %[[PTH_PADDED]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+// CHECK: return %[[RESULT]] : vector<5xi2>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 034bd47f6163e6..53a60c86c5f5e2 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -275,6 +275,30 @@ func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passt
// -----
+func.func @vector_maskedload_i4_arith_constant(%passthru: vector<8xi4>) -> vector<8xi4> {
+ %0 = memref.alloc() : memref<3x8xi4>
+ %cst = arith.constant dense<0> : vector<8xi4>
+ %mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+ %c0 = arith.constant 0 : index
+ %1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
+ memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+ return %1 : vector<8xi4>
+}
+
+// CHECK: func @vector_maskedload_i4_arith_constant(
+// CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+
+// Emit a new, compressed mask for emulated maskedload:
+// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: %[[PTHU_UPCAST:.+]] = vector.bitcast %[[PASSTHRU]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C0]]], %[[COMPRESSED_MASK]], %[[PTHU_UPCAST]]
+// CHECK: %[[LOAD_DOWNCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[MASK]], %[[LOAD_DOWNCAST]], %[[PASSTHRU]] : vector<8xi1>, vector<8xi4>
+// CHECK: return %[[SELECT]] : vector<8xi4>
+
///----------------------------------------------------------------------------------------
/// vector.extract -> vector.masked_load
///----------------------------------------------------------------------------------------
@@ -624,3 +648,29 @@ func.func @vector_maskedstore_i4_constant_mask(
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+
+// -----
+
+func.func @vector_maskedstore_i4_arith_constant(%val_to_store: vector<8xi4>) {
+ %0 = memref.alloc() : memref<5x8xi4>
+ %cst = arith.constant dense<0> : vector<8xi4>
+ %mask = arith.constant dense<[false, true, true, true, true, true, false, false]> : vector<8xi1>
+ %c0 = arith.constant 0 : index
+ %c3 = arith.constant 3 : index
+ vector.maskedstore %0[%c3, %c0], %mask, %val_to_store :
+ memref<5x8xi4>, vector<8xi1>, vector<8xi4>
+ return
+}
+
+// CHECK-LABEL: func @vector_maskedstore_i4_arith_constant
+// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]:
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<20xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, true, true, false, false]> : vector<8xi1>
+// CHECK: %[[C12:.+]] = arith.constant 12 : index
+// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<4xi8>
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C12]]], %[[COMPRESSED_MASK]], %[[EMPTY]]
+// CHECK: %[[LOAD_UPCAST:.+]] = vector.bitcast %[[MASKEDLOAD]]
+// CHECK: %[[SELECT:.+]] = arith.select %[[MASK]], %[[VAL_TO_STORE]], %[[LOAD_UPCAST]]
+// CHECK: %[[SELECT_DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
+// CHECK: vector.maskedstore %[[ALLOC]][%[[C12]]], %[[COMPRESSED_MASK]], %[[SELECT_DOWNCAST]]