[Mlir-commits] [mlir] [MLIR] VectorEmulateNarrowType to support loading of unaligned vectors (PR #113411)

Han-Chung Wang llvmlistbot at llvm.org
Wed Oct 23 13:36:52 PDT 2024


================
@@ -474,26 +524,47 @@ struct ConvertVectorTransferRead final
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+
+    if (!foldedFrontPaddingSize) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    auto numElements =
+        (*foldedFrontPaddingSize + origElements + scale - 1) / scale;
     auto newReadType = VectorType::get(numElements, newElementType);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, newReadType, adaptor.getSource(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
+    auto bitCastType = VectorType::get(numElements * scale, oldElementType);
     auto bitCast =
-        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
+        rewriter.create<vector::BitCastOp>(loc, bitCastType, newRead);
+
+    if (isUnalignedEmulation) {
+      // we only extract a portion of the vector.
+      rewriter.replaceOpWithNewOp<vector::ExtractStridedSliceOp>(
+          op, op.getType(), bitCast,
+          rewriter.getI64ArrayAttr({*foldedFrontPaddingSize}),
+          rewriter.getI64ArrayAttr({origElements}),
+          rewriter.getI64ArrayAttr({1}));
+    } else {
+      rewriter.replaceOp(op, bitCast->getResult(0));
+    }
----------------
hanhanW wrote:

Can we refactor this part into a method and use it for both sides?

https://github.com/llvm/llvm-project/pull/113411


More information about the Mlir-commits mailing list