[Mlir-commits] [mlir] [MLIR] VectorEmulateNarrowType to support loading of unaligned vectors (PR #113411)

Han-Chung Wang llvmlistbot at llvm.org
Thu Oct 24 12:57:32 PDT 2024


================
@@ -396,29 +486,49 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
+    auto foldedFrontPaddingSize = getFrontPaddingSize(
+        rewriter, loc, linearizedInfo, isUnalignedEmulation);
+    if (!foldedFrontPaddingSize) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
     FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
+                            *foldedFrontPaddingSize);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements =
+        llvm::alignTo(*foldedFrontPaddingSize + origElements, scale) / scale;
     auto newType = VectorType::get(numElements, newElementType);
+
+    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+
+    Value passthru = op.getPassThru();
+    if (isUnalignedEmulation) {
+      // create an empty vector of the new type
+      auto emptyVector = rewriter.create<arith::ConstantOp>(
+          loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+      passthru = insertSubvectorInto(rewriter, loc, op.getPassThru(),
----------------
hanhanW wrote:

nit: you can replace `op.getPassThru()` with `passthru` here, since `passthru` was just initialized from it.

https://github.com/llvm/llvm-project/pull/113411


More information about the Mlir-commits mailing list