[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)

Han-Chung Wang llvmlistbot at llvm.org
Tue Dec 3 16:29:54 PST 2024


================
@@ -309,6 +314,76 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
+/// Atomically store a subbyte-sized value to memory, with a mask.
+static void atomicStore(OpBuilder &builder, Location loc,
+                        MemRefValue emulatedMemref, Value linearizedIndex,
+                        VectorValue value, Value mask,
+                        int64_t numSrcElemsPerDest) {
+  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{linearizedIndex});
+  Value origValue = atomicOp.getCurrentValue();
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(atomicOp.getBody());
+
+  // i8 -> <1xi8> -> <numSrcElemsPerDest x i.>
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+}
+
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     MemRefValue emulatedMemref, Value linearizedIndex,
+                     VectorValue value, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{linearizedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, value);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   linearizedIndex);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+              "`atomicStore` and `rmwStore` must have same signature, as per "
+              "the design to keep the code clean, which one to call is "
+              "determined by the `useAtomicWrites` flag.");
+
+// Extract a slice of a vector, and insert it into a byte vector.
----------------
hanhanW wrote:

"...insert it into a byte vector" is an implementation detail, which makes the function comment a little ambiguous. The first question I had when I read the comment was: what is the `byte vector`? There is only one vector in the function arguments. How about rephrasing it like:

```
/// Returns a vector with the same type that only has data for the given range. Additionally,
/// the data is offset by `byteOffset`. E.g.,
/// Inputs:
///   vector = |01|23|45|67| : vector<4xi2>
///   sliceOffset = 1
///   sliceNumElements = 2
///   byteOffset = 1
/// Output:
///   vector = |00|00|23|45| : vector<4xi2>
```

We can also consider renaming the function to something like `offsetSubvector`. (I'm not good at naming, but I'd like to point out that the original function name is not straightforward to me. Also, please correct me if I misunderstand the code.)

Minor nit: should we assert that (1) it is a 1-D vector and (2) the slice does not access out of bounds?
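
To make the proposed semantics concrete, here is a minimal sketch in plain C++ rather than MLIR builder code. It is not the PR's implementation. It assumes the offset counts elements of the result vector (called `insertOffset` here, a hypothetical name), models each sub-byte element as one `uint8_t`, and borrows the suggested name `offsetSubvector`. The two asserts stand in for the out-of-bounds check; the 1-D check has no counterpart because `std::vector` is always one-dimensional.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of the helper under review: each uint8_t stands for one
// sub-byte element (e.g. an i2 value). The result has the same length as the
// input and is zero everywhere except for the relocated slice.
std::vector<uint8_t> offsetSubvector(const std::vector<uint8_t> &vec,
                                     std::size_t sliceOffset,
                                     std::size_t sliceNumElements,
                                     std::size_t insertOffset) {
  // The slice must lie inside the source vector ...
  assert(sliceOffset + sliceNumElements <= vec.size() &&
         "slice reads out of bounds");
  // ... and must still fit after being shifted to its new position.
  assert(insertOffset + sliceNumElements <= vec.size() &&
         "slice writes out of bounds");

  std::vector<uint8_t> result(vec.size(), 0); // zero-filled, same shape
  for (std::size_t i = 0; i < sliceNumElements; ++i)
    result[insertOffset + i] = vec[sliceOffset + i];
  return result;
}
```

Under these assumptions, `offsetSubvector({0, 1, 2, 3}, 1, 2, 2)` takes elements 1 and 2 of the input and places them in the last two slots, giving `{0, 0, 1, 2}`.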

https://github.com/llvm/llvm-project/pull/115922

