[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Jan 16 09:52:21 PST 2025
================
@@ -292,15 +296,119 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
newLoad);
}
+/// Bitcasts `trueValue` and `falseValue` to `downcastType` (only when they do
+/// not already have that type), selects between them element-wise using
+/// `mask`, and bitcasts the selected vector to `upcastType`.
+///
+/// `downcastType` and `upcastType` must span the same total number of bits;
+/// this is checked with an assertion.
+static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
+                                     VectorType downcastType,
+                                     VectorType upcastType, Value mask,
+                                     Value trueValue, Value falseValue) {
+  assert(downcastType.getNumElements() *
+                 downcastType.getElementTypeBitWidth() ==
+             upcastType.getNumElements() *
+                 upcastType.getElementTypeBitWidth() &&
+         "expected input and output number of bits to match");
+
+  // Materialize a bitcast only for operands that are not already of the
+  // requested narrow type.
+  auto castToDowncastType = [&](Value operand) -> Value {
+    if (operand.getType() == downcastType)
+      return operand;
+    return builder.create<vector::BitCastOp>(loc, downcastType, operand);
+  };
+  // Two separate statements keep the emission order fixed: the true operand's
+  // bitcast (if any) is created before the false operand's.
+  Value narrowTrue = castToDowncastType(trueValue);
+  Value narrowFalse = castToDowncastType(falseValue);
+
+  Value selected =
+      builder.create<arith::SelectOp>(loc, mask, narrowTrue, narrowFalse);
+  // Reinterpret the selected bits as the wide result type.
+  return builder.create<vector::BitCastOp>(loc, upcastType, selected);
+}
+
+/// Emits a `memref.generic_atomic_rmw` op that merges the sub-byte elements of
+/// `valueToStore` into the element of `linearizedMemref` at index `storeIdx`.
+/// Lanes where `mask` is set take their value from `valueToStore`; the other
+/// lanes keep the value currently held in memory.
+///
+/// Preconditions:
+///   * `valueToStore` is a 1-D vector (asserted below).
+///   * The total bit width of `valueToStore` equals the bit width of one
+///     element of `linearizedMemref` (e.g. vector<4xi2> for an i8 memref).
+///     This is asserted inside `downcastSelectAndUpcast`.
+///   * `mask` presumably has one i1 lane per element of `valueToStore`
+///     (e.g. vector<4xi1> for vector<4xi2>, as in the example below).
+///
+/// Example (only the byte addressed by `storeIdx` is shown):
+///   Inputs:
+///     linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
+///     storeIdx         = 2
+///     valueToStore     = |3|3|3|3| : vector<4xi2>
+///     mask             = |0|0|1|1| : vector<4xi1>
+///
+///   Result:
+///     linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
+static void atomicStore(OpBuilder &builder, Location loc,
+ MemRefValue linearizedMemref, Value storeIdx,
+ VectorValue valueToStore, Value mask) {
+ assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
+
+ // Create an atomic load-modify-write region using
+ // `memref.generic_atomic_rmw`.
+ auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
+ loc, linearizedMemref, ValueRange{storeIdx});
+ // The value currently stored at `storeIdx`, as exposed to the atomic region.
+ Value origValue = atomicOp.getCurrentValue();
+
+ // Emit the remaining ops inside the atomic region. The guard restores the
+ // caller's insertion point when this function returns.
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(atomicOp.getBody());
+
+ // Wrap the current memory value in a one-element vector so that it can be
+ // bitcast to the sub-byte vector type of `valueToStore`.
+ auto oneElemVecType = VectorType::get({1}, origValue.getType());
+ Value origVecValue = builder.create<vector::FromElementsOp>(
+ loc, oneElemVecType, ValueRange{origValue});
+
+ // Merge the new and original lanes according to `mask`, bitcast back to the
+ // one-element vector, then extract the scalar and yield it as the new
+ // memory value.
+ Value maskedValue =
+ downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
+ oneElemVecType, mask, valueToStore, origVecValue);
+ auto scalarMaskedValue =
+ builder.create<vector::ExtractOp>(loc, maskedValue, 0);
+ builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
+}
+
+/// Extracts `sliceNumElements` elements from the 1-D `vector`, starting at
+/// `extractOffset`, and inserts them at `insertOffset` into a zero-filled
+/// vector of `vector`'s element type that spans exactly one byte (8 bits).
+///
+/// Inputs:
+///   vec_in           = |0|1|2|3| : vector<4xi2>
+///   extractOffset    = 1
+///   sliceNumElements = 2
+///   insertOffset     = 2
+/// Output:
+///   vec_out          = |0|0|1|2| : vector<4xi2>
+///
+/// Preconditions (asserted):
+///   * `vector` is 1-D.
+///   * The element bit width is non-zero and divides 8.
+///   * The slice fits in one byte.
+///   * The slice placed at `insertOffset` stays inside the byte-sized result.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, VectorValue vector,
+                                  int64_t extractOffset,
+                                  int64_t sliceNumElements,
+                                  int64_t insertOffset) {
+  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
+  auto vectorElementType = vector.getType().getElementType();
+  unsigned elementBitWidth = vectorElementType.getIntOrFloatBitWidth();
+  // TODO: update and use `alignedConversionPrecondition` in the place of
+  // these asserts.
+  // The zero check comes first so that `8 % elementBitWidth` is well defined.
+  assert(elementBitWidth > 0 && 8 % elementBitWidth == 0 &&
+         "vector element must be a valid sub-byte type");
+  assert(
+      sliceNumElements * elementBitWidth <= 8 &&
+      "sliceNumElements * vector element size must be less than or equal to 8");
+  // Number of elements of this type that make up one byte.
+  int64_t scale = 8 / elementBitWidth;
+  // The insertion helper does not bounds-check, so verify the placement here.
+  assert(insertOffset >= 0 && insertOffset + sliceNumElements <= scale &&
+         "inserted slice must fit within a single byte");
+  auto byteVectorType = VectorType::get({scale}, vectorElementType);
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, byteVectorType, rewriter.getZeroAttr(byteVectorType));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              extractOffset, sliceNumElements);
+  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
+                                   insertOffset);
+}
+
namespace {
//===----------------------------------------------------------------------===//
// ConvertVectorStore
//===----------------------------------------------------------------------===//
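+/// Conversion pattern that rewrites `vector::StoreOp`; part of the emulation
+/// of sub-byte vector stores.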
+
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;
+ ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+ : OpConversionPattern<vector::StoreOp>(context) {}
+
----------------
banach-space wrote:
Is this needed?
https://github.com/llvm/llvm-project/pull/115922
More information about the Mlir-commits
mailing list