[Mlir-commits] [mlir] [mlir][vector] Add support for vector.maskedstore sub-type emulation. (PR #73871)
llvmlistbot at llvm.org
Thu Nov 30 10:55:55 PST 2023
================
@@ -99,6 +171,94 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   }
 };
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedStore final
+    : OpConversionPattern<vector::MaskedStoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getValueToStore().getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    int scale = dstBits / srcBits;
+    auto origElements = op.getValueToStore().getType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+    OpFoldResult linearizedIndicesOfr;
+    std::tie(std::ignore, linearizedIndicesOfr) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+    Value linearizedIndices =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
+
+    // Load the whole data and use arith.select to handle the corner cases.
+    // E.g., given these input values:
+    //
+    //   %mask = [1, 1, 1, 0, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+    //   %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //
+    // we'll have
+    //
+    //    expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
+    //
+    //    %new_mask = [1, 1, 0]
+    //    %maskedload = [0x12, 0x34, 0x0]
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
+    //    %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
+    //    %packed_data = [0x78, 0x94, 0x00]
+    //
+    // Using the new mask to store %packed_data results in expected output.
----------------
Max191 wrote:
Maybe we would never distribute this way, but it is possible for two threads with non-overlapping indices to end up with IR that has a problem similar to the above. For example, suppose a tensor of shape `6xi4` is distributed across 2 threads, with one thread storing the first three `i4` values and the other storing the last three. These could be lowered into 2 `vector.maskedstore`/`vector.store` ops that overlap on the middle byte after narrow type emulation, since the `6xi4` would become `3xi8`.
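To make the overlap concrete, here is a rough sketch of the kind of IR this could produce after emulation. Everything below is hypothetical: the SSA names, constants, and masks are made up, the `6xi4` buffer is assumed to have already been rewritten to `memref<3xi8>`, and it assumes an emulation that accepts the unaligned `3xi4` halves (the current pattern bails out when `origElements % scale != 0`). The point is just that both threads end up doing a non-atomic read-modify-write of byte 1:

```mlir
// Shared buffer after narrow-type emulation: 6 x i4 packed into 3 x i8.
%mem = memref.alloc() : memref<3xi8>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%pad = arith.constant dense<0> : vector<2xi8>
%byte_mask = arith.constant dense<true> : vector<2xi1>

// Thread 0: i4 elements [0, 3) live in bytes 0..1, so the emulation
// read-modify-writes byte 1 to preserve its other nibble.
%vals0 = arith.constant dense<[1, 2, 3, 0]> : vector<4xi4>
%i4_mask0 = arith.constant dense<[true, true, true, false]> : vector<4xi1>
%old0 = vector.maskedload %mem[%c0], %byte_mask, %pad
    : memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
%old0_i4 = vector.bitcast %old0 : vector<2xi8> to vector<4xi4>
%merged0 = arith.select %i4_mask0, %vals0, %old0_i4 : vector<4xi1>, vector<4xi4>
%new0 = vector.bitcast %merged0 : vector<4xi4> to vector<2xi8>
vector.maskedstore %mem[%c0], %byte_mask, %new0
    : memref<3xi8>, vector<2xi1>, vector<2xi8>

// Thread 1: i4 elements [3, 6) live in bytes 1..2, so it also
// read-modify-writes byte 1.
%vals1 = arith.constant dense<[0, 4, 5, 6]> : vector<4xi4>
%i4_mask1 = arith.constant dense<[false, true, true, true]> : vector<4xi1>
%old1 = vector.maskedload %mem[%c1], %byte_mask, %pad
    : memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
%old1_i4 = vector.bitcast %old1 : vector<2xi8> to vector<4xi4>
%merged1 = arith.select %i4_mask1, %vals1, %old1_i4 : vector<4xi1>, vector<4xi4>
%new1 = vector.bitcast %merged1 : vector<4xi4> to vector<2xi8>
vector.maskedstore %mem[%c1], %byte_mask, %new1
    : memref<3xi8>, vector<2xi1>, vector<2xi8>
```

If the two threads run concurrently, one thread's maskedstore can clobber the nibble the other thread just merged into byte 1, even though the original `i4` index ranges never overlapped.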
This is probably a moot point, though, because I don't think we would distribute into non-power-of-2 chunks in this way.
https://github.com/llvm/llvm-project/pull/73871
More information about the Mlir-commits mailing list