[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Dec 3 16:29:54 PST 2024
================
@@ -353,32 +433,153 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
// vector<4xi8>
- auto origElements = op.getValueToStore().getType().getNumElements();
- if (origElements % scale != 0)
- return failure();
+ auto origElements = valueToStore.getType().getNumElements();
+ bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
OpFoldResult linearizedIndices;
- std::tie(std::ignore, linearizedIndices) =
+ memref::LinearizedMemRefInfo linearizedInfo;
+ std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto numElements = origElements / scale;
- auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements, newElementType),
- op.getValueToStore());
+ auto foldedNumFrontPadElems =
+ isUnalignedEmulation
+ ? getConstantIntValue(linearizedInfo.intraDataOffset)
+ : 0;
+
+ if (!foldedNumFrontPadElems) {
+ // Unimplemented case for dynamic front padding size != 0
+ return failure();
+ }
+
+ auto emulatedMemref = cast<MemRefValue>(adaptor.getBase());
+
+ // Shortcut: a sub-byte store at the front is not needed when:
+ // 1. The source vector size is a multiple of the byte size, and
+ // 2. The store address is aligned to the emulated width boundary.
+ if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+ auto numElements = origElements / numSrcElemsPerDest;
+ auto bitCast = rewriter.create<vector::BitCastOp>(
+ loc, VectorType::get(numElements, newElementType),
+ op.getValueToStore());
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ op, bitCast.getResult(), emulatedMemref,
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+ return success();
+ }
+
+ // The index into the target memref we are storing to
+ Value currentDestIndex =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+ auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto subWidthStoreMaskType =
+ VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+ // The index into the source vector we are currently processing
+ auto currentSourceIndex = 0;
+
+ // 1. Partial-width store for the first byte: when the store address is not
+ // aligned to the emulated width boundary, handle the unaligned part first
+ // so that the remaining elements are aligned to the width boundary.
+ auto frontSubWidthStoreElem =
+ (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+ if (frontSubWidthStoreElem != 0) {
+ SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
+ if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+ std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+ origElements, true);
+ frontSubWidthStoreElem = origElements;
+ } else {
+ std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+ *foldedNumFrontPadElems, true);
+ }
+ auto frontMask = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
+
+ currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+ auto value =
+ extractSliceIntoByte(rewriter, loc, valueToStore, 0,
+ frontSubWidthStoreElem, *foldedNumFrontPadElems);
+
+ subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ cast<VectorValue>(value), frontMask.getResult(),
+ numSrcElemsPerDest);
+
+ currentDestIndex = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+ }
+
+ if (currentSourceIndex >= origElements) {
+ rewriter.eraseOp(op);
+ return success();
+ }
- rewriter.replaceOpWithNewOp<vector::StoreOp>(
- op, bitCast.getResult(), adaptor.getBase(),
- getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+ // 2. Full width store. After the previous step, the store address is
+ // aligned to the emulated width boundary.
+ int64_t fullWidthStoreSize =
+ (origElements - currentSourceIndex) / numSrcElemsPerDest;
+ int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+ if (fullWidthStoreSize != 0) {
----------------
hanhanW wrote:
optional nit: I'd write `> 0`. `!= 0` sometimes makes people wonder whether the value could be negative. I think it's just how my decision tree works, not a big deal, so I'm marking it optional.
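For reference, the condition would then read as below (same lines as in the diff above, with only the comparison changed; an illustrative sketch, not the final patch):

    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / numSrcElemsPerDest;
    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
    if (fullWidthStoreSize > 0) {
      // ... full-width stores for the aligned middle part, as in the patch ...
    }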
https://github.com/llvm/llvm-project/pull/115922