[Mlir-commits] [mlir] add vector subbyte store support (PR #70293)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 30 23:46:08 PDT 2023


================
@@ -33,6 +33,70 @@ using namespace mlir;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getValueToStore().getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    // Adjust the number of elements to store when emulating narrow types.
+    // Here only the 1-D vector load is considered, and the N-D memref types
+    // should be linearized.
+    // For example, to emulate i4 to i8, the following op:
+    //
+    // vector.store %arg1, %0[%arg2, %arg3] :memref<4x8xi4>, vector<8xi4>
+    //
+    // can be replaced with
+    //
+    // vector.store %bitcast_arg1, %alloc[%linear_index] : memref<16xi8>,
+    // vector<4xi8>
+
+    auto origElements = op.getValueToStore().getType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+
+    auto numElements = (origElements + scale - 1) / scale;
----------------
saienduri wrote:

I originally had this for cases such as 27 i4 elements with a scale of 2, where plain integer division would truncate the element count. But yes, the earlier check ensures `origElements % scale == 0`, so the rounding up is not necessary. Made the change, thanks.
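
To illustrate the arithmetic, here is a standalone sketch (not code from the PR; the helper names `ceilDivElements` and `exactDivElements` are hypothetical). It assumes `scale = dstBits / srcBits`, as in the hunk above, and shows that the two formulas differ only when the element count is not a multiple of the scale, which the earlier `origElements % scale != 0` check already rejects.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: rounds up, as in `(origElements + scale - 1) / scale`.
constexpr int64_t ceilDivElements(int64_t origElements, int64_t scale) {
  return (origElements + scale - 1) / scale;
}

// Hypothetical helper: plain integer division, which truncates.
constexpr int64_t exactDivElements(int64_t origElements, int64_t scale) {
  return origElements / scale;
}

// 27 i4 elements packed into i8 (scale = 2): truncation loses one element.
static_assert(ceilDivElements(27, 2) == 14, "rounding up keeps the last i4");
static_assert(exactDivElements(27, 2) == 13, "plain division truncates");

// 8 i4 elements packed into i8 (scale = 2): both formulas agree, which is
// the only case reachable after the `origElements % scale` check.
static_assert(ceilDivElements(8, 2) == exactDivElements(8, 2),
              "formulas agree when the count is a multiple of scale");
```

Since the pattern bails out before this point whenever the count is not divisible by the scale, plain division is sufficient there.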

https://github.com/llvm/llvm-project/pull/70293

