[Mlir-commits] [mlir] [mlir][vector] Add support for vector.maskedstore sub-type emulation. (PR #73871)
Han-Chung Wang
llvmlistbot at llvm.org
Wed Nov 29 16:00:05 PST 2023
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/73871
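At a glance, the new ConvertVectorMaskedStore pattern rewrites a narrow-type masked store as a masked read-modify-write on the wider element type, reusing the original mask to merge the new values into the loaded data. A rough sketch of the resulting IR for a vector<8xi8> store emulated on an i32-backed memref (function and value names are illustrative; the exact expectations are in the new tests at the end of the first patch):

  func.func @emulated_maskedstore_i8_on_i32(%alloc: memref<6xi32>, %lidx: index,
                                            %n: index, %value: vector<8xi8>) {
    // Original mask over the 8 narrow elements.
    %orig_mask = vector.create_mask %n : vector<8xi1>
    // Compressed mask bound: ceildiv(%n, 4), since 4 x i8 fit in one i32.
    %mask_idx = affine.apply affine_map<()[s0] -> ((s0 + 3) floordiv 4)>()[%n]
    %new_mask = vector.create_mask %mask_idx : vector<2xi1>
    // Load the i32 words the store may touch.
    %pass_thru = arith.constant dense<0> : vector<2xi32>
    %load = vector.maskedload %alloc[%lidx], %new_mask, %pass_thru
        : memref<6xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
    // Unpack, merge the values to store under the original mask, repack.
    %unpacked = vector.bitcast %load : vector<2xi32> to vector<8xi8>
    %merged = arith.select %orig_mask, %value, %unpacked : vector<8xi1>, vector<8xi8>
    %packed = vector.bitcast %merged : vector<8xi8> to vector<2xi32>
    // Store back with the compressed mask.
    vector.maskedstore %alloc[%lidx], %new_mask, %packed
        : memref<6xi32>, vector<2xi1>, vector<2xi32>
    return
  }

The masked load plus the select preserve the sub-elements the original mask does not cover, so a partially covered i32 word is written back with those bytes unchanged.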
From 46ae780e274a4b3a0e9ea5217b3d6121a1a83d91 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 29 Nov 2023 15:44:41 -0800
Subject: [PATCH 1/2] [mlir][vector] Add support for vector.maskedstore
sub-type emulation.
---
.../Transforms/VectorEmulateNarrowType.cpp | 230 +++++++++++++-----
.../Vector/vector-emulate-narrow-type.mlir | 72 ++++++
2 files changed, 240 insertions(+), 62 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 6aea0343bfc9327..05c98b89e8a94c1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -32,6 +32,78 @@ using namespace mlir;
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+/// Returns a compressed mask. The mask value is set only if any mask is present
+/// in the scale range. E.g., if `scale` equals 2, the following mask:
+///
+/// %mask = [1, 1, 1, 0, 0, 0]
+///
+/// will return the following new compressed mask:
+///
+/// %mask = [1, 1, 0]
+static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
+ Location loc, Value mask,
+ int origElements, int scale) {
+ auto numElements = (origElements + scale - 1) / scale;
+
+ auto maskOp = mask.getDefiningOp();
+ SmallVector<vector::ExtractOp, 2> extractOps;
+ // Finding the mask creation operation.
+ while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+ if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
+ maskOp = extractOp.getVector().getDefiningOp();
+ extractOps.push_back(extractOp);
+ }
+ }
+ auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
+ auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
+ if (!createMaskOp && !constantMaskOp)
+ return failure();
+
+ // Computing the "compressed" mask. All the emulation logic (i.e. computing
+ // new mask index) only happens on the last dimension of the vectors.
+ Operation *newMask = nullptr;
+ auto shape = llvm::to_vector(
+ maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
+ shape.push_back(numElements);
+ auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
+ if (createMaskOp) {
+ auto maskOperands = createMaskOp.getOperands();
+ auto numMaskOperands = maskOperands.size();
+ AffineExpr s0;
+ bindSymbols(rewriter.getContext(), s0);
+ s0 = s0 + scale - 1;
+ s0 = s0.floorDiv(scale);
+ OpFoldResult origIndex =
+ getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
+ OpFoldResult maskIndex =
+ affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
+ auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
+ newMaskOperands.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
+ newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
+ newMaskOperands);
+ } else if (constantMaskOp) {
+ auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+ auto numMaskOperands = maskDimSizes.size();
+ auto origIndex =
+ cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
+ auto maskIndex =
+ rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
+ auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
+ newMaskDimSizes.push_back(maskIndex);
+ newMask = rewriter.create<vector::ConstantMaskOp>(
+ loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+ }
+
+ while (!extractOps.empty()) {
+ newMask = rewriter.create<vector::ExtractOp>(
+ loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
+ extractOps.pop_back();
+ }
+
+ return newMask;
+}
+
namespace {
//===----------------------------------------------------------------------===//
@@ -99,6 +171,94 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedStore final
+ : OpConversionPattern<vector::MaskedStoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto loc = op.getLoc();
+ auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+ Type oldElementType = op.getValueToStore().getType().getElementType();
+ Type newElementType = convertedType.getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = newElementType.getIntOrFloatBitWidth();
+
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ int scale = dstBits / srcBits;
+ auto origElements = op.getValueToStore().getType().getNumElements();
+ if (origElements % scale != 0)
+ return failure();
+
+ auto stridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+ OpFoldResult linearizedIndicesOfr;
+ std::tie(std::ignore, linearizedIndicesOfr) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, dstBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(),
+ getAsOpFoldResult(adaptor.getIndices()));
+ Value linearizedIndices =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
+
+    // Load all the data and use arith.select to handle the corner cases.
+ // E.g., given these input values:
+ //
+ // %mask = [1, 1, 1, 0, 0, 0]
+ // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+ // %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+ //
+ // we'll have
+ //
+ // expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
+ //
+ // %new_mask = [1, 1, 0]
+ // %maskedload = [0x12, 0x34, 0x0]
+ // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
+ // %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
+ // %packed_data = [0x78, 0x94, 0x0, 0x0]
+ //
+ // Using the new mask to store %packed_data results in expected output.
+ FailureOr<Operation *> newMask =
+ getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+ if (failed(newMask))
+ return failure();
+
+ auto numElements = (origElements + scale - 1) / scale;
+ auto newType = VectorType::get(numElements, newElementType);
+ auto passThru = rewriter.create<arith::ConstantOp>(
+ loc, newType, rewriter.getZeroAttr(newType));
+
+ auto newLoad = rewriter.create<vector::MaskedLoadOp>(
+ loc, newType, adaptor.getBase(), linearizedIndices,
+ newMask.value()->getResult(0), passThru);
+
+ Value valueToStore = rewriter.create<vector::BitCastOp>(
+ loc, op.getValueToStore().getType(), newLoad);
+ valueToStore = rewriter.create<arith::SelectOp>(
+ loc, op.getMask(), op.getValueToStore(), valueToStore);
+ valueToStore =
+ rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);
+
+ rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
+ op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
+ valueToStore);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertVectorLoad
//===----------------------------------------------------------------------===//
@@ -236,7 +396,6 @@ struct ConvertVectorMaskedLoad final
// TODO: Currently, only the even number of elements loading is supported.
// To deal with the odd number of elements, one has to extract the
// subvector at the proper offset after bit-casting.
-
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
if (origElements % scale != 0)
@@ -244,7 +403,6 @@ struct ConvertVectorMaskedLoad final
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
-
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
@@ -254,66 +412,13 @@ struct ConvertVectorMaskedLoad final
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto numElements = (origElements + scale - 1) / scale;
- auto newType = VectorType::get(numElements, newElementType);
-
- auto maskOp = op.getMask().getDefiningOp();
- SmallVector<vector::ExtractOp, 2> extractOps;
- // Finding the mask creation operation.
- while (maskOp &&
- !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
- if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
- maskOp = extractOp.getVector().getDefiningOp();
- extractOps.push_back(extractOp);
- }
- }
- auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
- auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
- if (!createMaskOp && !constantMaskOp)
+ FailureOr<Operation *> newMask =
+ getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+ if (failed(newMask))
return failure();
- // Computing the "compressed" mask. All the emulation logic (i.e. computing
- // new mask index) only happens on the last dimension of the vectors.
- Operation *newMask = nullptr;
- auto shape = llvm::to_vector(
- maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
- shape.push_back(numElements);
- auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
- if (createMaskOp) {
- auto maskOperands = createMaskOp.getOperands();
- auto numMaskOperands = maskOperands.size();
- AffineExpr s0;
- bindSymbols(rewriter.getContext(), s0);
- s0 = s0 + scale - 1;
- s0 = s0.floorDiv(scale);
- OpFoldResult origIndex =
- getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
- OpFoldResult maskIndex =
- affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
- auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
- newMaskOperands.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
- newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
- newMaskOperands);
- } else if (constantMaskOp) {
- auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
- auto numMaskOperands = maskDimSizes.size();
- auto origIndex =
- cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
- auto maskIndex =
- rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
- auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
- newMaskDimSizes.push_back(maskIndex);
- newMask = rewriter.create<vector::ConstantMaskOp>(
- loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
- }
-
- while (!extractOps.empty()) {
- newMask = rewriter.create<vector::ExtractOp>(
- loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
- extractOps.pop_back();
- }
-
+ auto numElements = (origElements + scale - 1) / scale;
+ auto newType = VectorType::get(numElements, newElementType);
auto newPassThru =
rewriter.create<vector::BitCastOp>(loc, newType, op.getPassThru());
@@ -321,7 +426,7 @@ struct ConvertVectorMaskedLoad final
auto newLoad = rewriter.create<vector::MaskedLoadOp>(
loc, newType, adaptor.getBase(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
- newMask->getResult(0), newPassThru);
+ newMask.value()->getResult(0), newPassThru);
// Setting the part that originally was not effectively loaded from memory
// to pass through.
@@ -821,7 +926,8 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
// Populate `vector.*` conversion patterns.
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
- ConvertVectorTransferRead>(typeConverter, patterns.getContext());
+ ConvertVectorMaskedStore, ConvertVectorTransferRead>(
+ typeConverter, patterns.getContext());
}
void vector::populateVectorNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index af0f98a1c447de0..cba299b2a1d9567 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -428,3 +428,75 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
// CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
// CHECK32: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi32>, vector<1xi32>
+
+// -----
+
+func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
+ %0 = memref.alloc() : memref<3x8xi8>
+ %mask = vector.create_mask %arg2 : vector<8xi1>
+ vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+ return
+}
+// Expect no conversions, i8 is supported.
+// CHECK: func @vector_maskedstore_i8(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x8xi8>
+// CHECK-NEXT: %[[MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+// CHECK-NEXT: vector.maskedstore %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[VAL]]
+// CHECK-NEXT: return
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+// CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 3) floordiv 4)>
+// CHECK32: func @vector_maskedstore_i8(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<6xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[ARG2]] : vector<8xi1>
+// CHECK32: %[[LIDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[MASK_IDX_MAP]]()[%[[ARG2]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<2xi1>
+// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]]
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<8xi8>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
+// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
+// CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]
+
+// -----
+
+func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
+ %0 = memref.alloc() : memref<3x8xi8>
+ %mask = vector.constant_mask [4] : vector<8xi1>
+ vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+ return
+}
+// Expect no conversions, i8 is supported.
+// CHECK: func @vector_cst_maskedstore_i8(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x8xi8>
+// CHECK-NEXT: %[[MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK-NEXT: vector.maskedstore %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[VAL]]
+// CHECK-NEXT: return
+
+// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+// CHECK32: func @vector_cst_maskedstore_i8(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK32-SAME: %[[VAL:[a-zA-Z0-9]+]]
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<6xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK32: %[[LIDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<2xi1>
+// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<2xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]]
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<8xi8>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
+// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
+// CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]
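Both masked-load and masked-store emulation now share getCompressedMaskOp, which shrinks the mask by `scale` and rounds the bound up, so a packed word stays enabled whenever any of its narrow elements is masked in. A minimal sketch of the vector.create_mask case, assuming scale = 4 as in the CHECK32 runs above (the function name and %n are illustrative):

  func.func @compress_mask(%n: index) -> vector<2xi1> {
    // Mask being compressed; e.g. %n = 5 gives [1, 1, 1, 1, 1, 0, 0, 0].
    %mask = vector.create_mask %n : vector<8xi1>
    // ceildiv(%n, 4); %n = 5 gives 2 because both packed words contain active lanes.
    %mask_idx = affine.apply affine_map<()[s0] -> ((s0 + 3) floordiv 4)>()[%n]
    // Compressed mask over the 2 packed elements; [1, 1] for %n = 5.
    %new_mask = vector.create_mask %mask_idx : vector<2xi1>
    return %new_mask : vector<2xi1>
  }

For vector.constant_mask the bound is folded directly, e.g. [4] over vector<8xi1> becomes [1] over vector<2xi1>, as checked in the constant-mask test above.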
From a326e740a8b9b26100d5d6f698115e4c40ac349d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 29 Nov 2023 15:59:05 -0800
Subject: [PATCH 2/2] fix comment
---
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 05c98b89e8a94c1..5b1a329b2a6c393 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -228,7 +228,7 @@ struct ConvertVectorMaskedStore final
// %maskedload = [0x12, 0x34, 0x0]
// %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
// %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
- // %packed_data = [0x78, 0x94, 0x0, 0x0]
+ // %packed_data = [0x78, 0x94, 0x00]
//
// Using the new mask to store %packed_data results in expected output.
FailureOr<Operation *> newMask =
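The corrected value follows from the packing in the example: the six i4 results occupy three i8 words after the final bitcast, not four. A small sketch of that step, with %merged standing in for %select_using_original_mask from the comment:

  func.func @packed_width(%merged: vector<6xi4>) -> vector<3xi8> {
    // %merged corresponds to [0x7, 0x8, 0x9, 0x4, 0x0, 0x0] in the comment's example.
    %packed_data = vector.bitcast %merged : vector<6xi4> to vector<3xi8>
    // For that input, %packed_data holds [0x78, 0x94, 0x00] in the comment's
    // packing order: three i8 words, hence the corrected line.
    return %packed_data : vector<3xi8>
  }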