[Mlir-commits] [mlir] 7f82c90 - [mlir][vector] Add support for vector.maskedstore sub-type emulation. (#73871)
llvmlistbot at llvm.org
Thu Nov 30 11:27:12 PST 2023
Author: Han-Chung Wang
Date: 2023-11-30T11:27:06-08:00
New Revision: 7f82c90621770919dc457d8bb5d12d7f3b29e2e1
URL: https://github.com/llvm/llvm-project/commit/7f82c90621770919dc457d8bb5d12d7f3b29e2e1
DIFF: https://github.com/llvm/llvm-project/commit/7f82c90621770919dc457d8bb5d12d7f3b29e2e1.diff
LOG: [mlir][vector] Add support for vector.maskedstore sub-type emulation. (#73871)
The idea is similar to the vector.maskedload + vector.store emulation. The
emulation does the following (an IR sketch follows the list):
1. Compute a compressed mask and use it to load the existing data from the
destination.
2. Bitcast the loaded data to the original vector type.
3. Select between `op.valueToStore` and the loaded data using the original
mask.
4. Bitcast the selected value and store it to the destination using the
compressed mask.
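
For illustration, a rough sketch of the i8-into-i32 case exercised by the new
tests below (SSA names are illustrative; the actual pattern also linearizes
the store indices and compresses the mask bound via affine.apply, so
%compressed_n is roughly ceildiv(%n, 4) and %alloc is the converted
memref<6xi32> view of %base):

  // Original op: 8 x i8 stored into a memref that the type converter
  // rewrites to memref<6xi32>.
  %mask = vector.create_mask %n : vector<8xi1>
  vector.maskedstore %base[%i, %j], %mask, %value
      : memref<3x8xi8>, vector<8xi1>, vector<8xi8>

  // Emulated form: masked load, select, masked store on the i32 view.
  %new_mask = vector.create_mask %compressed_n : vector<2xi1>
  %zero = arith.constant dense<0> : vector<2xi32>
  %loaded = vector.maskedload %alloc[%lin_idx], %new_mask, %zero
      : memref<6xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
  %as_i8 = vector.bitcast %loaded : vector<2xi32> to vector<8xi8>
  %merged = arith.select %mask, %value, %as_i8 : vector<8xi1>, vector<8xi8>
  %packed = vector.bitcast %merged : vector<8xi8> to vector<2xi32>
  vector.maskedstore %alloc[%lin_idx], %new_mask, %packed
      : memref<6xi32>, vector<2xi1>, vector<2xi32>

The select against the original (uncompressed) mask is what keeps bytes that
share a word with masked-off narrow elements intact in memory, while the
compressed mask skips words whose elements are all masked off.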
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 6aea0343bfc9327..ead7d645cb5bb3d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -32,6 +32,79 @@ using namespace mlir;
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+/// Returns a compressed mask. A result bit is set only if at least one bit is
+/// set in the corresponding `scale`-sized group of the input mask. E.g., if
+/// `scale` is 2, the following mask:
+///
+/// %mask = [1, 1, 1, 0, 0, 0]
+///
+/// will return the following new compressed mask:
+///
+/// %mask = [1, 1, 0]
+static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
+ Location loc, Value mask,
+ int origElements, int scale) {
+ auto numElements = (origElements + scale - 1) / scale;
+
+ Operation *maskOp = mask.getDefiningOp();
+ SmallVector<vector::ExtractOp, 2> extractOps;
+ // Finding the mask creation operation.
+ while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+ if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
+ maskOp = extractOp.getVector().getDefiningOp();
+ extractOps.push_back(extractOp);
+ }
+ }
+ auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
+ auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
+ if (!createMaskOp && !constantMaskOp)
+ return failure();
+
+ // Computing the "compressed" mask. All the emulation logic (i.e. computing
+ // new mask index) only happens on the last dimension of the vectors.
+ Operation *newMask = nullptr;
+ SmallVector<int64_t> shape(
+ maskOp->getResultTypes()[0].cast<VectorType>().getShape());
+ shape.back() = numElements;
+ auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
+ if (createMaskOp) {
+ OperandRange maskOperands = createMaskOp.getOperands();
+ size_t numMaskOperands = maskOperands.size();
+ AffineExpr s0;
+ bindSymbols(rewriter.getContext(), s0);
+ s0 = s0 + scale - 1;
+ s0 = s0.floorDiv(scale);
+ OpFoldResult origIndex =
+ getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+ OpFoldResult maskIndex =
+ affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
+ SmallVector<Value> newMaskOperands(maskOperands.drop_back());
+ newMaskOperands.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+ newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+ newMaskOperands);
+ } else if (constantMaskOp) {
+ ArrayRef<Attribute> maskDimSizes =
+ constantMaskOp.getMaskDimSizes().getValue();
+ size_t numMaskOperands = maskDimSizes.size();
+ auto origIndex =
+ cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
+ IntegerAttr maskIndexAttr =
+ rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
+ SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
+ newMaskDimSizes.push_back(maskIndexAttr);
+ newMask = rewriter.create<vector::ConstantMaskOp>(
+ loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+ }
+
+ while (!extractOps.empty()) {
+ newMask = rewriter.create<vector::ExtractOp>(
+ loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+ extractOps.pop_back();
+ }
+
+ return newMask;
+}
+
namespace {
//===----------------------------------------------------------------------===//
@@ -99,6 +172,94 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedStore final
+ : OpConversionPattern<vector::MaskedStoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto loc = op.getLoc();
+ auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+ Type oldElementType = op.getValueToStore().getType().getElementType();
+ Type newElementType = convertedType.getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = newElementType.getIntOrFloatBitWidth();
+
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ int scale = dstBits / srcBits;
+ int origElements = op.getValueToStore().getType().getNumElements();
+ if (origElements % scale != 0)
+ return failure();
+
+ auto stridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+ OpFoldResult linearizedIndicesOfr;
+ std::tie(std::ignore, linearizedIndicesOfr) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, dstBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(),
+ getAsOpFoldResult(adaptor.getIndices()));
+ Value linearizedIndices =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
+
+ // Load the whole data and use arith.select to handle the corner cases.
+ // E.g., given these input values:
+ //
+ // %mask = [1, 1, 1, 0, 0, 0]
+ // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+ // %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+ //
+ // we'll have
+ //
+ // expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
+ //
+ // %new_mask = [1, 1, 0]
+ // %maskedload = [0x12, 0x34, 0x0]
+ // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
+ // %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
+ // %packed_data = [0x78, 0x94, 0x00]
+ //
+ // Using the new mask to store %packed_data results in expected output.
+ FailureOr<Operation *> newMask =
+ getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+ if (failed(newMask))
+ return failure();
+
+ auto numElements = (origElements + scale - 1) / scale;
+ auto newType = VectorType::get(numElements, newElementType);
+ auto passThru = rewriter.create<arith::ConstantOp>(
+ loc, newType, rewriter.getZeroAttr(newType));
+
+ auto newLoad = rewriter.create<vector::MaskedLoadOp>(
+ loc, newType, adaptor.getBase(), linearizedIndices,
+ newMask.value()->getResult(0), passThru);
+
+ Value valueToStore = rewriter.create<vector::BitCastOp>(
+ loc, op.getValueToStore().getType(), newLoad);
+ valueToStore = rewriter.create<arith::SelectOp>(
+ loc, op.getMask(), op.getValueToStore(), valueToStore);
+ valueToStore =
+ rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
+
+ rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+ op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
+ valueToStore);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertVectorLoad
//===----------------------------------------------------------------------===//
@@ -236,7 +397,6 @@ struct ConvertVectorMaskedLoad final
// TODO: Currently, only the even number of elements loading is supported.
// To deal with the odd number of elements, one has to extract the
// subvector at the proper offset after bit-casting.
-
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
if (origElements % scale != 0)
@@ -244,7 +404,6 @@ struct ConvertVectorMaskedLoad final
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
-
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
@@ -254,66 +413,13 @@ struct ConvertVectorMaskedLoad final
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto numElements = (origElements + scale - 1) / scale;
- auto newType = VectorType::get(numElements, newElementType);
-
- auto maskOp = op.getMask().getDefiningOp();
- SmallVector<vector::ExtractOp, 2> extractOps;
- // Finding the mask creation operation.
- while (maskOp &&
- !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
- if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
- maskOp = extractOp.getVector().getDefiningOp();
- extractOps.push_back(extractOp);
- }
- }
- auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
- auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
- if (!createMaskOp && !constantMaskOp)
+ FailureOr<Operation *> newMask =
+ getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+ if (failed(newMask))
return failure();
- // Computing the "compressed" mask. All the emulation logic (i.e. computing
- // new mask index) only happens on the last dimension of the vectors.
- Operation *newMask = nullptr;
- auto shape = llvm::to_vector(
- maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
- shape.push_back(numElements);
- auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
- if (createMaskOp) {
- auto maskOperands = createMaskOp.getOperands();
- auto numMaskOperands = maskOperands.size();
- AffineExpr s0;
- bindSymbols(rewriter.getContext(), s0);
- s0 = s0 + scale - 1;
- s0 = s0.floorDiv(scale);
- OpFoldResult origIndex =
- getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
- OpFoldResult maskIndex =
- affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
- auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
- newMaskOperands.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
- newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
- newMaskOperands);
- } else if (constantMaskOp) {
- auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
- auto numMaskOperands = maskDimSizes.size();
- auto origIndex =
- cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
- auto maskIndex =
- rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
- auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
- newMaskDimSizes.push_back(maskIndex);
- newMask = rewriter.create<vector::ConstantMaskOp>(
- loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
- }
-
- while (!extractOps.empty()) {
- newMask = rewriter.create<vector::ExtractOp>(
- loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
- extractOps.pop_back();
- }
-
+ auto numElements = (origElements + scale - 1) / scale;
+ auto newType = VectorType::get(numElements, newElementType);
auto newPassThru =
rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
@@ -321,7 +427,7 @@ struct ConvertVectorMaskedLoad final
auto newLoad = rewriter.create<vector::MaskedLoadOp>(
loc, newType, adaptor.getBase(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
- newMask->getResult(0), newPassThru);
+ newMask.value()->getResult(0), newPassThru);
// Setting the part that originally was not effectively loaded from memory
// to pass through.
@@ -821,7 +927,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
// Populate `vector.*` conversion patterns.
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
- ConvertVectorTransferRead>(typeConverter, patterns.getContext());
+ ConvertVectorMaskedStore, ConvertVectorTransferRead>(
+ typeConverter, patterns.getContext());
}
void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index af0f98a1c447de0..cba299b2a1d9567 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -428,3 +428,75 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
// CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
// CHECK32: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi32>, vector<1xi32>
+
+// -----
+
+func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
+ %0 = memref.alloc() : memref<3x8xi8>
+ %mask = vector.create_mask %arg2 : vector<8xi1>
+ vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+ return
+}
+// Expect no conversions, i8 is supported.
+// CHECK: func @vector_maskedstore_i8(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x8xi8>
+// CHECK-NEXT: %[[MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+// CHECK-NEXT: vector.maskedstore %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[VAL]]
+// CHECK-NEXT: return
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+// CHECK32: func @vector_maskedstore_i8(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<6xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+// CHECK32: %[[LIDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[MASK_IDX_MAP]]()[%[[ARG2]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<2xi1>
+// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]]
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<8xi8>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
+// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
+// CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]
+
+// -----
+
+func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
+ %0 = memref.alloc() : memref<3x8xi8>
+ %mask = vector.constant_mask [4] : vector<8xi1>
+ vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+ return
+}
+// Expect no conversions, i8 is supported.
+// CHECK: func @vector_cst_maskedstore_i8(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x8xi8>
+// CHECK-NEXT: %[[MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK-NEXT: vector.maskedstore %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[VAL]]
+// CHECK-NEXT: return
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+// CHECK32: func @vector_cst_maskedstore_i8(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<6xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK32: %[[LIDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<2xi1>
+// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]]
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<8xi8>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
+// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
+// CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]