[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)

Han-Chung Wang llvmlistbot at llvm.org
Sun Dec 8 20:36:36 PST 2024


================
@@ -309,6 +314,99 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
+/// Emits a `memref.generic_atomic_rmw` op that performs a masked store of
+/// subbyte-sized elements into a single element of `linearizedMemref`.
+///
+/// Inside the atomic region, the value currently held in memory is bitcast to
+/// the type of `valueToStore`, merged with `valueToStore` under `mask` (lanes
+/// where the mask is set take the new element, the remaining lanes keep the
+/// element already in memory), bitcast back and yielded as the new contents.
+/// `valueToStore` must therefore be bitcast-compatible with one element of
+/// `linearizedMemref` (8 bits in the example below).
+///
+/// Note: `numSrcElemsPerDest` is not used by this function.
+///
+/// Inputs:
+///   linearizedMemref = |a|b|c|d| : <4xi2> (<1xi8>)
+///   linearizedIndex = 2
+///   valueToStore = |e|f|g|h| : vector<4xi2>
+///   mask = |0|0|1|1| : vector<4xi1>
+///
+/// Result:
+///   linearizedMemref = |a|b|g|h| : <4xi2> (<1xi8>)
+static void atomicStore(OpBuilder &builder, Location loc,
+                        MemRefValue linearizedMemref, Value linearizedIndex,
+                        VectorValue valueToStore, Value mask,
+                        int64_t numSrcElemsPerDest) {
+  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+  // Create the atomic RMW op; the value it exposes is what memory holds now.
+  auto rmwOp = builder.create<memref::GenericAtomicRMWOp>(
+      loc, linearizedMemref, ValueRange{linearizedIndex});
+  Value currentElem = rmwOp.getCurrentValue();
+
+  // The remaining ops are emitted into the body of the atomic op. The guard
+  // restores the builder's insertion point when this function returns.
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(rmwOp.getBody());
+
+  // Wrap the current element in a 1-element vector so it can be bitcast to
+  // the subbyte vector type of `valueToStore`.
+  auto wrappedType = VectorType::get({1}, currentElem.getType());
+  Value wrappedCurrent = builder.create<vector::FromElementsOp>(
+      loc, wrappedType, ValueRange{currentElem});
+  Value currentAsSubbytes = builder.create<vector::BitCastOp>(
+      loc, valueToStore.getType(), wrappedCurrent);
+
+  // Merge lane by lane: new element where `mask` is set, old one otherwise.
+  Value merged = builder.create<arith::SelectOp>(loc, mask, valueToStore,
+                                                 currentAsSubbytes);
+
+  // Convert the merged vector back to a scalar element and yield it.
+  Value wrappedMerged =
+      builder.create<vector::BitCastOp>(loc, wrappedType, merged);
+  Value mergedElem = builder.create<vector::ExtractOp>(loc, wrappedMerged, 0);
+  builder.create<memref::AtomicYieldOp>(loc, mergedElem);
+}
+
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+/// It has the same masking semantics as `atomicStore`, but without the
+/// atomicity: one element of `linearizedMemref` is loaded at
+/// `linearizedIndex`, bitcast to a vector of `numSrcElemsPerDest` subbyte
+/// elements, merged with `valueToStore` under `mask`, bitcast back and
+/// stored to the same location.
+///
+/// Lanes where `mask` is set take the element from `valueToStore`; the
+/// remaining lanes keep the element that was loaded from memory.
+///
+/// Inputs:
+///   linearizedMemref = |a|b|c|d| : <4xi2> (<1xi8>)
+///   linearizedIndex = 2
+///   valueToStore = |e|f|g|h| : vector<4xi2>
+///   mask = |0|0|1|1| : vector<4xi1>
+///
+/// Result:
+///   linearizedMemref = |a|b|g|h| : <4xi2> (<1xi8>)
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     MemRefValue linearizedMemref, Value linearizedIndex,
+                     VectorValue valueToStore, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+  // Load a single element of the memref as a 1-element vector.
+  auto emulatedIOType =
+      VectorType::get({1}, linearizedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, linearizedMemref, ValueRange{linearizedIndex});
+  // Reinterpret the loaded element as a vector of subbyte elements.
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest},
+                      valueToStore.getType().getElementType()),
+      elemLoad);
+  // Take the new element where the mask is set and keep the loaded element
+  // elsewhere. The operand order matches `atomicStore`; with the operands
+  // swapped, masked-off lanes would be overwritten and masked-on lanes would
+  // keep their old contents.
+  auto select =
+      rewriter.create<arith::SelectOp>(loc, mask, valueToStore, fromBitcast);
+  // Convert back to the memref element type and write the merged value.
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, linearizedMemref,
+                                   linearizedIndex);
+}
+
+/// Extract `sliceNumElements` from source `vector` at `sliceOffset`,
+/// and insert it into an empty vector at offset `byteOffset`.
+/// Inputs:
+///   vector = |01|23|45|67| : vector<4xi2>
+///   sliceOffset = 1
+///   sliceNumElements = 2
+///   byteOffset = 2
+/// Output:
+///   vector = |00|00|23|45| : vector<4xi2>
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, VectorValue vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
+  auto vectorElementType = vector.getType().getElementType();
+  assert(
+      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
+      "sliceNumElements * vector element size must be less than or equal to 8");
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
----------------
hanhanW wrote:

How about returning the value directly?

```suggestion
  return staticallyInsertSubvector(rewriter, loc, extracted,
                                            emptyByteVector, byteOffset);
```

https://github.com/llvm/llvm-project/pull/115922


More information about the Mlir-commits mailing list