[Mlir-commits] [mlir] [MLIR] VectorEmulateNarrowType to support loading of unaligned vectors (PR #113411)

Han-Chung Wang llvmlistbot at llvm.org
Wed Oct 23 13:36:52 PDT 2024


================
@@ -294,35 +312,67 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
     // %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
     //
-    // TODO: Currently, only the even number of elements loading is supported.
-    // To deal with the odd number of elements, one has to extract the
-    // subvector at the proper offset after bit-casting.
+    // There are cases where the number of elements to load is not byte-aligned,
+    // for example:
+    //
+    // %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
+    //
+    // we will have to load extra bytes and extract the exact slice in between.
+    //
+    // %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
+    // %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
+    // %3 = vector.extract_strided_slice %2 {offsets = [2], sizes = [3], strides
+    // = [1]}
+    //        : vector<8xi2> to vector<3xi2>
+    //
+    // TODO: Currently the extract_strided_slice's attributes must be known at
+    // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+
+    if (!foldedFrontPaddingSize) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    auto numElements =
+        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
+    auto loadVectorType = VectorType::get(numElements, newElementType);
     auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
+        loc, loadVectorType, adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
+    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
----------------
hanhanW wrote:

There is no new/old concept for the bitcast op. Perhaps just use `bitcastType`?

https://github.com/llvm/llvm-project/pull/113411


More information about the Mlir-commits mailing list