[Mlir-commits] [mlir] [mlir][vector] Add support for vector.maskedstore sub-type emulation. (PR #73871)

Han-Chung Wang llvmlistbot at llvm.org
Thu Nov 30 10:24:45 PST 2023


================
@@ -99,6 +171,94 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedStore final
+    : OpConversionPattern<vector::MaskedStoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getValueToStore().getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    int scale = dstBits / srcBits;
+    auto origElements = op.getValueToStore().getType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+    OpFoldResult linearizedIndicesOfr;
+    std::tie(std::ignore, linearizedIndicesOfr) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+    Value linearizedIndices =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
+
+    // Load the whole data and use arith.select to handle the corner cases.
+    // E.g., given these input values:
+    //
+    //   %mask = [1, 1, 1, 0, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+    //   %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //
+    // we'll have
+    //
+    //    expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
+    //
+    //    %new_mask = [1, 1, 0]
+    //    %maskedload = [0x12, 0x34, 0x0]
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
+    //    %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
+    //    %packed_data = [0x78, 0x94, 0x00]
+    //
+    // Using the new mask to store %packed_data results in expected output.
----------------
hanhanW wrote:

I'm not sure that is the case; I need to think about it more. When we distribute the work across multiple threads, the indices of the store ops do not overlap with each other. A race condition arises only when store ops from different threads index into the same byte. After the emulation, different `memref.store` ops could index into the same byte, so in that context we need atomic ops.
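
To make the scalar case concrete, here is a minimal C++ sketch of why a sub-byte store needs an atomic read-modify-write once two logical elements share a byte. `atomicStoreNibble` is a hypothetical helper for illustration only, not what the emulation pass emits, and the nibble order is an assumption.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// Hypothetical helper (not an MLIR API): writes one 4-bit element into a
// shared byte without disturbing the neighbouring nibble, even if another
// thread updates that neighbour at the same time.
void atomicStoreNibble(std::atomic<uint8_t> &byte, bool highNibble,
                       uint8_t value) {
  assert(value <= 0xF && "value must fit in 4 bits");
  const int shift = highNibble ? 4 : 0;
  // Bits of the byte that belong to the other element and must be preserved.
  const uint8_t keepMask = static_cast<uint8_t>(~(0xF << shift));
  const uint8_t bits = static_cast<uint8_t>(value << shift);

  uint8_t expected = byte.load();
  // Retry until no other thread changed the byte between our read and our
  // write; on failure, `expected` is refreshed with the current contents.
  while (!byte.compare_exchange_weak(
      expected, static_cast<uint8_t>((expected & keepMask) | bits))) {
  }
}
```

A plain load, modify, store sequence here could lose the other thread's nibble; the compare-exchange loop is one way to avoid that.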

For `vector.maskedstore` and `vector.store`, the emulation operates on whole bytes, so the index ranges of each thread do not overlap with those of other threads. In that context, we don't need atomic ops.
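
The steps in the code comment above can be simulated with a small C++ sketch for the i4-in-i8 case. `emulateMaskedStoreI4` is a hypothetical function, not part of MLIR; it assumes the store starts at byte 0 and follows the comment's notation, where the first element sits in the high nibble.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical simulation (not an MLIR API) of the emulated masked store for
// 4-bit elements packed two per byte. Returns the memory contents after the
// store.
std::vector<uint8_t> emulateMaskedStoreI4(std::vector<uint8_t> memory,
                                          const std::vector<uint8_t> &values,
                                          const std::vector<bool> &mask) {
  const std::size_t scale = 2; // dstBits / srcBits = 8 / 4
  assert(values.size() % scale == 0 && values.size() == mask.size());
  const std::size_t numBytes = values.size() / scale;
  assert(numBytes <= memory.size());

  for (std::size_t b = 0; b < numBytes; ++b) {
    // %new_mask: a byte is active if any of its nibbles is selected.
    const bool newMask = mask[2 * b] || mask[2 * b + 1];
    // %maskedload: inactive bytes are not read and yield zero.
    const uint8_t loaded = newMask ? memory[b] : 0;
    // %bitcast: unpack the byte into two 4-bit elements.
    uint8_t hi = loaded >> 4;
    uint8_t lo = loaded & 0xF;
    // Select with the original mask: take the new value where it is set.
    if (mask[2 * b])
      hi = values[2 * b] & 0xF;
    if (mask[2 * b + 1])
      lo = values[2 * b + 1] & 0xF;
    // %packed_data, stored under the new mask: inactive bytes are untouched.
    if (newMask)
      memory[b] = static_cast<uint8_t>((hi << 4) | lo);
  }
  return memory;
}
```

With the comment's inputs (memory `{0x12, 0x34, 0x56}`, values `0x7` through `0xC`, mask `[1, 1, 1, 0, 0, 0]`), the third byte is never read or written, which is the property that keeps the per-thread byte ranges disjoint.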

Does it make sense?

https://github.com/llvm/llvm-project/pull/73871

