[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jan 16 09:52:21 PST 2025


================
@@ -292,15 +296,119 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
+/// Bitcasts both values to `downcastType` (if needed), selects between them
+/// per `mask`, and bitcasts the selected value to `upcastType`.
+static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
+                                     VectorType downcastType,
+                                     VectorType upcastType, Value mask,
+                                     Value trueValue, Value falseValue) {
+  assert(
+      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
+          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
+      "expected input and output number of bits to match");
+  if (trueValue.getType() != downcastType) {
+    trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
+  }
+  if (falseValue.getType() != downcastType) {
+    falseValue =
+        builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
+  }
+  Value selectedType =
+      builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
+  // Bitcast the selected value to `upcastType`.
+  return builder.create<vector::BitCastOp>(loc, upcastType, selectedType);
+}
+
+/// Emits a `memref.generic_atomic_rmw` op that stores sub-byte values into a
+/// byte of `linearizedMemref`, under a mask. `valueToStore` is a vector of
+/// sub-byte elements with a total size of 8 bits, and `mask` selects which of
+/// its elements are stored; unselected elements keep the value in memory.
+///
+/// Inputs:
+///   linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
+///   storeIdx = 2
+///   valueToStore = |3|3|3|3| : vector<4xi2>
+///   mask = |0|0|1|1| : vector<4xi1>
+///
+/// Result:
+///   linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
+static void atomicStore(OpBuilder &builder, Location loc,
+                        MemRefValue linearizedMemref, Value storeIdx,
+                        VectorValue valueToStore, Value mask) {
+  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+  // Create an atomic read-modify-write region using
+  // `memref.generic_atomic_rmw`.
+  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
+      loc, linearizedMemref, ValueRange{storeIdx});
+  Value origValue = atomicOp.getCurrentValue();
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(atomicOp.getBody());
+
+  // Wrap the current value from memory in a one-element vector so that it
+  // can be bitcast to the type of `valueToStore` for the masked select.
+  auto oneElemVecType = VectorType::get({1}, origValue.getType());
+  Value origVecValue = builder.create<vector::FromElementsOp>(
+      loc, oneElemVecType, ValueRange{origValue});
+
+  // Construct the final masked value and yield it.
+  Value maskedValue =
+      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
+                              oneElemVecType, mask, valueToStore, origVecValue);
+  auto scalarMaskedValue =
+      builder.create<vector::ExtractOp>(loc, maskedValue, 0);
+  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
+}
+
+/// Extracts `sliceNumElements` elements from `vector` at `extractOffset` and
+/// inserts them into a zero-filled, byte-sized vector at `insertOffset`.
+/// Inputs:
+///   vec_in  = |0|1|2|3| : vector<4xi2>
+///   extractOffset = 1
+///   sliceNumElements = 2
+///   insertOffset = 2
+/// Output:
+///   vec_out = |0|0|1|2| : vector<4xi2>
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, VectorValue vector,
+                                  int64_t extractOffset,
+                                  int64_t sliceNumElements,
+                                  int64_t insertOffset) {
+  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
+  auto vectorElementType = vector.getType().getElementType();
+  // TODO: update and use `alignedConversionPrecondition` in the place of
+  // these asserts.
+  assert(
+      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
+      "sliceNumElements * vector element size must be less than or equal to 8");
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              extractOffset, sliceNumElements);
+  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
+                                   insertOffset);
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+///
+///
+
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context) {}
+
----------------
banach-space wrote:

Is this needed?

https://github.com/llvm/llvm-project/pull/115922


More information about the Mlir-commits mailing list