[Mlir-commits] [mlir] [mlir][vector] Add support for vector.maskedstore sub-type emulation. (PR #73871)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 30 07:03:22 PST 2023


================
@@ -99,6 +171,94 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorMaskedStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorMaskedStore final
+    : OpConversionPattern<vector::MaskedStoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getValueToStore().getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    int scale = dstBits / srcBits;
+    auto origElements = op.getValueToStore().getType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+    OpFoldResult linearizedIndicesOfr;
+    std::tie(std::ignore, linearizedIndicesOfr) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+    Value linearizedIndices =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
+
+    // Load the whole data and use arith.select to handle the corner cases.
+    // E.g., given these input values:
+    //
+    //   %mask = [1, 1, 1, 0, 0, 0]
+    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
+    //   %value_to_store = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
+    //
+    // we'll have
+    //
+    //    expected output: [0x7, 0x8, 0x9, 0x4, 0x5, 0x6]
+    //
+    //    %new_mask = [1, 1, 0]
+    //    %maskedload = [0x12, 0x34, 0x0]
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
+    //    %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
+    //    %packed_data = [0x78, 0x94, 0x00]
+    //
+    // Using the new mask to store %packed_data results in expected output.
----------------
Max191 wrote:

I think this can still result in a race condition, like the one in `memref.store` emulation. For example, suppose the following stores run on different threads:
```
//    %memref = [0x12, 0x34, 0x56]
// Masked store
//    %new_mask = [1, 1, 0]
//    %maskedload = [0x12, 0x34, 0x0]
//    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x0, 0x0]
//    %select_using_original_mask = [0x7, 0x8, 0x9, 0x4, 0x0, 0x0]
//    %packed_data = [0x78, 0x94, 0x00]
//    vector.maskedstore %memref, %new_mask, %packed_data
// Other store
//    %rmw_mask = [0xF0]
//    %rmw_val = [0x0D]
//    memref.atomic_rmw andi %rmw_mask, %memref[1]
//    memref.atomic_rmw ori %rmw_val, %memref[1]
```
If the `memref.atomic_rmw` ops execute after the masked load but before the masked store, the masked store will overwrite what the `atomic_rmw` ops wrote to `%memref[1]`.
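To make the window concrete, here is a small single-threaded C++ model of the emulation in the diff. `prepare` does the masked load, bitcast, and select; `commit` does the masked store with `%new_mask`; `otherStore` is the `andi`/`ori` pair from the example. All names, the fixed 3-byte buffer, and the high-nibble-first packing (which follows the notation in the diff comment, not necessarily the bit layout the pass actually produces) are assumptions for illustration only.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Toy model (assumption): a memref<3xi8> holding six i4 elements, with element
// 2*i in the high nibble of byte i, matching the notation in the diff comment.
using Bytes = std::array<uint8_t, 3>;
using Nibbles = std::array<uint8_t, 6>;
using ByteMask = std::array<bool, 3>;
using NibbleMask = std::array<bool, 6>;

// "bitcast" vector<3xi8> -> vector<6xi4>.
Nibbles unpack(const Bytes &b) {
  Nibbles n{};
  for (int i = 0; i < 3; ++i) {
    n[2 * i] = b[i] >> 4;
    n[2 * i + 1] = b[i] & 0xF;
  }
  return n;
}

// "bitcast" vector<6xi4> -> vector<3xi8>.
Bytes pack(const Nibbles &n) {
  Bytes b{};
  for (int i = 0; i < 3; ++i) {
    assert(n[2 * i] <= 0xF && n[2 * i + 1] <= 0xF);
    b[i] = static_cast<uint8_t>((n[2 * i] << 4) | n[2 * i + 1]);
  }
  return b;
}

// Result of the "read" half of the emulation: %packed_data plus %new_mask.
struct Pending {
  Bytes packed;
  ByteMask newMask;
};

// maskedload (pass-through 0) + bitcast + select with the original mask.
Pending prepare(const Bytes &mem, const NibbleMask &mask, const Nibbles &value) {
  ByteMask newMask{};
  Bytes loaded{};  // %maskedload: 0 where %new_mask is off
  for (int i = 0; i < 3; ++i) {
    newMask[i] = mask[2 * i] || mask[2 * i + 1];
    loaded[i] = newMask[i] ? mem[i] : 0;
  }
  Nibbles merged = unpack(loaded);  // %bitcast
  for (int i = 0; i < 6; ++i)
    if (mask[i])
      merged[i] = value[i];  // %select_using_original_mask
  return {pack(merged), newMask};
}

// The "write" half: vector.maskedstore of %packed_data under %new_mask.
void commit(Bytes &mem, const Pending &p) {
  for (int i = 0; i < 3; ++i)
    if (p.newMask[i])
      mem[i] = p.packed[i];
}

// The other thread's andi/ori pair on byte 1 (writes element 3 = 0xD).
void otherStore(Bytes &mem) {
  mem[1] &= 0xF0;
  mem[1] |= 0x0D;
}

// Replays the scenario. With raceWindow set, the other store lands between
// the masked load and the masked store; otherwise it runs afterwards.
Bytes replay(bool raceWindow) {
  Bytes mem = {0x12, 0x34, 0x56};
  NibbleMask mask = {true, true, true, false, false, false};
  Nibbles value = {0x7, 0x8, 0x9, 0xA, 0xB, 0xC};
  Pending p = prepare(mem, mask, value);
  if (raceWindow)
    otherStore(mem);
  commit(mem, p);
  if (!raceWindow)
    otherStore(mem);
  return mem;
}
```

`replay(false)` ends with bytes `{0x78, 0x9D, 0x56}`, i.e., both stores survive. `replay(true)` ends with `{0x78, 0x94, 0x56}`: the masked store writes back the stale `0x4` for element 3, so the other thread's `0xD` is lost.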

A potential solution would be to split the corner cases off from the masked store and rewrite them the same way as the `memref.store` emulation (i.e., with two `atomic_rmw` ops, as above). That said, I don't expect this to come up often, since masked stores are mostly used for tiling with masking, but it is possible. At a minimum, a TODO here seems warranted.
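A rough sketch of the suggested fix, again as a hypothetical C++ model rather than MLIR code: only the partially covered byte (byte 1 in the example, whose high nibble receives element 2 = `0x9`) goes through a clear-then-set pair of atomic RMWs, mirroring the `andi`/`ori` pair in `memref.store` emulation; fully covered bytes could presumably keep using the plain masked store. `std::atomic<uint8_t>::fetch_and`/`fetch_or` stand in for `memref.atomic_rmw andi`/`ori`, and `atomicSteps`/`allInterleavings` are illustrative names, not existing APIs. The model replays every ordering of the four atomic steps that keeps each store's own two steps in order.

```cpp
#include <array>
#include <atomic>
#include <bitset>
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Toy model (assumption): memref<3xi8> as three atomic bytes.
using Mem = std::array<std::atomic<uint8_t>, 3>;
// One indivisible step: a single atomic RMW on one byte.
using Step = std::function<void(Mem &)>;

// Hypothetical description of a single i4 write into one byte.
struct NibbleWrite {
  int byte;
  bool highNibble;  // element 2*i lives in the high nibble (diff's notation)
  uint8_t value;    // must fit in 4 bits
};

// Clear-then-set pair, standing in for memref.atomic_rmw andi / ori.
std::array<Step, 2> atomicSteps(NibbleWrite w) {
  assert(w.value <= 0xF && "i4 value must fit in 4 bits");
  uint8_t keep = w.highNibble ? 0x0F : 0xF0;  // bits owned by the other element
  uint8_t bits = w.highNibble ? static_cast<uint8_t>(w.value << 4) : w.value;
  Step clear = [=](Mem &mem) { mem[w.byte].fetch_and(keep); };
  Step set = [=](Mem &mem) { mem[w.byte].fetch_or(bits); };
  return {clear, set};
}

// Runs every interleaving of `ours` and `theirs` that keeps each store's own
// steps in program order, starting from {0x12, 0x34, 0x56}, and returns the
// final value of `byte` for each interleaving.
std::vector<uint8_t> allInterleavings(const std::array<Step, 2> &ours,
                                      const std::array<Step, 2> &theirs,
                                      int byte) {
  std::vector<uint8_t> results;
  for (unsigned pick = 0; pick < 16; ++pick) {
    if (std::bitset<4>(pick).count() != 2)
      continue;  // bit k set => slot k runs the next step of `ours`
    Mem mem;
    mem[0].store(0x12);
    mem[1].store(0x34);
    mem[2].store(0x56);
    int i = 0, j = 0;
    for (int slot = 0; slot < 4; ++slot) {
      if (pick & (1u << slot))
        ours[i++](mem);
      else
        theirs[j++](mem);
    }
    results.push_back(mem[byte].load());
  }
  return results;
}
```

With `ours = atomicSteps({1, true, 0x9})` and `theirs = atomicSteps({1, false, 0xD})`, all six orderings leave byte 1 at `0x9D`: each step is a single atomic RMW and the two writers touch disjoint bits, so no update is lost regardless of where the other store lands.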

https://github.com/llvm/llvm-project/pull/73871