[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)

Han-Chung Wang llvmlistbot at llvm.org
Wed Nov 20 10:59:43 PST 2024


================
@@ -306,6 +307,73 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
+/// Emits a `memref::GenericAtomicRMWOp` on `emulatedMemref[emulatedIndex]`
+/// whose body merges the subbyte vector `value` into the element currently in
+/// memory: lanes selected by `mask` take `value`, all other lanes keep the
+/// current contents. `scale` is not read by this implementation.
+static void atomicStore(OpBuilder &rewriter, Location loc,
+                        TypedValue<MemRefType> emulatedMemref,
+                        Value emulatedIndex, TypedValue<VectorType> value,
+                        Value mask, int64_t scale) {
+  auto rmwOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  // Everything below is appended to the body block of the atomic op.
+  OpBuilder bodyBuilder =
+      OpBuilder::atBlockEnd(rmwOp.getBody(), rewriter.getListener());
+  Value currentElem = rmwOp.getCurrentValue();
+
+  // Wrap the scalar read from memory into a one-element vector, then
+  // reinterpret it with the same vector type as `value`.
+  auto singleElemVecType = VectorType::get({1}, currentElem.getType());
+  auto wrappedCurrent = bodyBuilder.create<vector::FromElementsOp>(
+      loc, singleElemVecType, ValueRange{currentElem});
+  auto currentAsSubbyteVec = bodyBuilder.create<vector::BitCastOp>(
+      loc, value.getType(), wrappedCurrent);
+
+  // Take `value` where `mask` is set, the current memory contents elsewhere.
+  auto merged = bodyBuilder.create<arith::SelectOp>(loc, mask, value,
+                                                    currentAsSubbyteVec);
+
+  // Cast back to the one-element vector and yield its only element as the
+  // result of the atomic region.
+  auto mergedAsSingleElemVec =
+      bodyBuilder.create<vector::BitCastOp>(loc, singleElemVecType, merged);
+  auto yieldedElem =
+      bodyBuilder.create<vector::ExtractOp>(loc, mergedAsSingleElemVec, 0);
+  bodyBuilder.create<memref::AtomicYieldOp>(loc, yieldedElem.getResult());
+}
+
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+///
+/// Loads the single element at `emulatedMemref[emulatedIndex]` as a
+/// one-element vector, reinterprets it as `numSrcElemsPerDest` elements of
+/// `value`'s element type, merges `value` into it under `mask`, and stores the
+/// merged element back to the same location. As in `atomicStore`, lanes where
+/// `mask` is set receive `value`; the other lanes keep what was loaded.
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     TypedValue<MemRefType> emulatedMemref, Value emulatedIndex,
+                     TypedValue<VectorType> value, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  // One-element vector of the memref's element type, used for both the load
+  // and the store.
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  // The true operand must be the new `value`: with the operands the other way
+  // round the masked lanes would be left untouched and the unmasked lanes
+  // overwritten, which is the opposite of what `atomicStore` does.
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, value, fromBitcast);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   emulatedIndex);
+}
+
+// NOTE(review): presumably both helpers must share one function type so that
+// either can be selected through a single function pointer at the call site;
+// confirm against the rest of the file.
+// The message is passed as the second argument (not `&&`-ed into the
+// condition) so that the compiler reports it when the assertion fails.
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)>,
+              "`atomicStore` and `rmwStore` must have same function type.");
----------------
hanhanW wrote:

Can you add a comment about why they need to have the same function type? It is not easy to figure out without looking at the code in the whole file, IMHO.

https://github.com/llvm/llvm-project/pull/115922


More information about the Mlir-commits mailing list