[Mlir-commits] [mlir] [MLIR] VectorEmulateNarrowType to support loading of unaligned vectors (PR #113411)
Han-Chung Wang
llvmlistbot at llvm.org
Thu Oct 24 12:57:32 PDT 2024
================
@@ -396,29 +486,49 @@ struct ConvertVectorMaskedLoad final
// subvector at the proper offset after bit-casting.
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
- if (origElements % scale != 0)
- return failure();
+ bool isUnalignedEmulation = origElements % scale != 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
OpFoldResult linearizedIndices;
- std::tie(std::ignore, linearizedIndices) =
+ memref::LinearizedMemRefInfo linearizedInfo;
+ std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
+ auto foldedFrontPaddingSize = getFrontPaddingSize(
+ rewriter, loc, linearizedInfo, isUnalignedEmulation);
+ if (!foldedFrontPaddingSize) {
+ // unimplemented case for dynamic front padding size
+ return failure();
+ }
+
FailureOr<Operation *> newMask =
- getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+ getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
+ *foldedFrontPaddingSize);
if (failed(newMask))
return failure();
- auto numElements = (origElements + scale - 1) / scale;
+ auto numElements =
+ llvm::alignTo(*foldedFrontPaddingSize + origElements, scale) / scale;
auto newType = VectorType::get(numElements, newElementType);
+
+ auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+
+ Value passthru = op.getPassThru();
+ if (isUnalignedEmulation) {
+ // create an empty vector of the new type
+ auto emptyVector = rewriter.create<arith::ConstantOp>(
+ loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+ passthru = insertSubvectorInto(rewriter, loc, op.getPassThru(),
----------------
hanhanW wrote:
nit: you can replace `op.getPassThru()` here with `passthru`, since `passthru` was just initialized from `op.getPassThru()`.
https://github.com/llvm/llvm-project/pull/113411
More information about the Mlir-commits
mailing list