[Mlir-commits] [mlir] [MLIR] Implement emulation of static indexing subbyte type vector stores (PR #115922)

Han-Chung Wang llvmlistbot at llvm.org
Tue Dec 3 16:29:54 PST 2024


================
@@ -309,6 +314,76 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
+/// Atomically store a subbyte-sized value to memory, with a mask.
+static void atomicStore(OpBuilder &builder, Location loc,
+                        MemRefValue emulatedMemref, Value linearizedIndex,
+                        VectorValue value, Value mask,
+                        int64_t numSrcElemsPerDest) {
+  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{linearizedIndex});
+  Value origValue = atomicOp.getCurrentValue();
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(atomicOp.getBody());
+
+  // i8 -> <1xi8> -> <numSrcElemsPerDest x i.>
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
----------------
hanhanW wrote:

Here is my understanding of the snippet; please correct me if I'm wrong. (I went through the same exercise in the previous review and forgot it, so this is also a note for myself.)

So the `GenericAtomicRMWOp` operation takes a memref and indices and updates the particular scalar at that position. Here, we use `vector.bitcast` ops to do the type conversion. My first question was why we can't use `arith.bitcast` ops, and it turns out that `arith.bitcast` only converts between integer and floating-point types of the same bit width, so it cannot reinterpret an `i8` as a vector of sub-byte elements. I couldn't find any other bitcast op that does this on scalars, which seems reasonable to me.
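To make the atomic part concrete, here is a plain C++ model (not MLIR API) of what `GenericAtomicRMWOp` guarantees for one byte of the emulated `i8` memref: the region runs on the current value, and its yielded result is stored only if no one else changed the byte in between; otherwise it reruns on the fresh value. This is a sketch that assumes a compare-and-swap retry loop as the implementation strategy. `genericAtomicRMW` and `storeI2Lane` are hypothetical names, and lane 0 is assumed to be the low-order bits.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <functional>

// Hypothetical model of memref.generic_atomic_rmw on a single byte of an
// i8 buffer: `body` plays the role of the op's region, receiving the
// current byte and returning the byte to yield (memref.atomic_yield).
void genericAtomicRMW(std::atomic<uint8_t> &cell,
                      const std::function<uint8_t(uint8_t)> &body) {
  uint8_t current = cell.load();
  uint8_t desired = body(current);
  // On failure, compare_exchange_weak writes the value it actually saw into
  // `current`, so the body is re-evaluated on fresh data and retried.
  while (!cell.compare_exchange_weak(current, desired)) {
    desired = body(current);
  }
}

// Hypothetical helper: overwrite one 2-bit lane of the byte, leaving the
// other three lanes untouched. Lane 0 is assumed to be the low-order bits.
void storeI2Lane(std::atomic<uint8_t> &cell, unsigned lane, uint8_t bits) {
  assert(lane < 4 && bits < 4);
  const unsigned shift = 2 * lane;
  genericAtomicRMW(cell, [&](uint8_t orig) {
    const uint8_t laneMask = static_cast<uint8_t>(0x3u << shift);
    return static_cast<uint8_t>((orig & ~laneMask) | (bits << shift));
  });
}
```

The retry matters because neighbouring `i2` lanes share a byte: with a plain load/modify/store, two threads storing different lanes of the same byte could overwrite each other's bits.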

Then we use the mask to construct the final data, which is now of type `vector<numSrcElemsPerDest x i.>`. Then we bitcast it back to the original type (i.e., `vector<1xi8>`) and use `vector.extract` to get the scalar.
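The region body itself can be modelled the same way. The sketch below is plain C++ with hypothetical helper names, using `std::array` stand-ins for `vector<1xi8>`, `vector<4xi2>` and the `vector<4xi1>` mask (i.e., assuming `numSrcElemsPerDest == 4`). It mirrors the steps above: bitcast the loaded byte into lanes, select per lane, and bitcast back. The lane order (lane 0 = least-significant bits) is an assumption matching a little-endian target.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Illustrative stand-ins for the types in the snippet:
// vector<1xi8> is one byte, vector<4xi2> is four 2-bit lanes.
using I2x4 = std::array<uint8_t, 4>;
using Mask4 = std::array<bool, 4>;

// Models vector.bitcast vector<1xi8> -> vector<4xi2>. Assumes lane 0 holds
// the least-significant bits, as on a little-endian target.
I2x4 bitcastToLanes(uint8_t byte) {
  I2x4 lanes{};
  for (unsigned i = 0; i < 4; ++i)
    lanes[i] = static_cast<uint8_t>((byte >> (2 * i)) & 0x3);
  return lanes;
}

// Models vector.bitcast vector<4xi2> -> vector<1xi8>, the inverse of above.
uint8_t bitcastToByte(const I2x4 &lanes) {
  uint8_t byte = 0;
  for (unsigned i = 0; i < 4; ++i) {
    assert(lanes[i] < 4 && "each lane holds only 2 bits");
    byte |= static_cast<uint8_t>(lanes[i] << (2 * i));
  }
  return byte;
}

// Body of the atomic region: take the new lanes where the mask is set and
// keep the original lanes elsewhere (arith.select mask, value, orig).
uint8_t mergeMasked(uint8_t origByte, const I2x4 &value, const Mask4 &mask) {
  const I2x4 orig = bitcastToLanes(origByte);
  I2x4 merged{};
  for (unsigned i = 0; i < 4; ++i)
    merged[i] = mask[i] ? value[i] : orig[i];
  return bitcastToByte(merged);
}
```

For example, with an original byte `0xE4` (lanes `0, 1, 2, 3`), storing `3` into lanes 1 and 3 yields `0xEC`; lanes 0 and 2 keep their old bits.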

It'd be awesome if you could add an example in the comment. It would be easier to understand even if it is just pseudo IR. It could be just me, but I find it easier to follow when I look at the IR.

[optional] It'd be super great if you could draw it with simple ASCII art, like:

```
Construct the vector<1xi8> element from the original value:
  |xxxxxxxx|

Cast the vector<1xi8> to match `valueToStore`, using vector<4xi2> as an example:
  |xx|xx|xx|xx|

Use the mask to construct the final data:
  arith.select mask, |xx|xx|xx|xx|, |yy|yy|yy|yy|

Convert the ...
```

https://github.com/llvm/llvm-project/blob/ea6cdb9a0708330089d583ce20aeaf81eec94ff7/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td#L1024-L1029

https://github.com/llvm/llvm-project/pull/115922

