[Mlir-commits] [mlir] d928a67 - [mlir][Vector] Refactor VectorEmulateNarrowType.cpp (#123529)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 16 05:22:49 PDT 2025


Author: Andrzej Warzyński
Date: 2025-03-16T12:22:46Z
New Revision: d928a671b84afb9c2ad64353694537a198f04651

URL: https://github.com/llvm/llvm-project/commit/d928a671b84afb9c2ad64353694537a198f04651
DIFF: https://github.com/llvm/llvm-project/commit/d928a671b84afb9c2ad64353694537a198f04651.diff

LOG: [mlir][Vector] Refactor VectorEmulateNarrowType.cpp (#123529)

This PR refactors `alignedConversionPrecondition` from
VectorEmulateNarrowType.cpp and adds a new helper hook.

**Update `alignedConversionPrecondition` (1)**

This method doesn't require a vector type for the "container" argument; the
underlying element type is sufficient. The corresponding argument has been
renamed to `containerTy` - this is meant as the multi-byte container element
type (`i8`, `i16`, `i32`, etc). With this change, the updated invocations of
`alignedConversionPrecondition` (in e.g. `RewriteAlignedSubByteIntExt`) make it
clear that the container element type is assumed to be `i8`.

**Update `alignedConversionPrecondition` (2)**

The final check in `alignedConversionPrecondition` has been replaced with a new
helper method, `fitsInMultiByteContainerTy`. This helper hook is now also used
in `ConvertVectorTransferRead` to improve code re-use.

**Other updates**

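Extended and unified the comments. Renamed the local `isFullyAligned` flag to
`isDivisibleInSize` in the load/store/transfer_read emulation patterns.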

**Implements**: https://github.com/llvm/llvm-project/issues/123630

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 51e72753ff162..cf6efaa04ae44 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -519,7 +519,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto origElements = valueToStore.getType().getNumElements();
     // Note, per-element-alignment was already verified above.
-    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -535,8 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        isFullyAligned ? 0
-                       : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isDivisibleInSize ? 0
+                          : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -554,7 +554,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // need unaligned emulation because the store address is aligned and the
     // source is a whole byte.
     bool emulationRequiresPartialStores =
-        !isFullyAligned || *foldedNumFrontPadElems != 0;
+        !isDivisibleInSize || *foldedNumFrontPadElems != 0;
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
       auto numElements = origElements / emulatedPerContainerElem;
@@ -881,7 +881,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 
     auto origElements = op.getVectorType().getNumElements();
     // Note, per-element-alignment was already verified above.
-    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -897,8 +897,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isFullyAligned ? 0
-                       : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isDivisibleInSize ? 0
+                          : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset =
@@ -915,7 +915,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       result = dynamicallyExtractSubVector(
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (!isFullyAligned) {
+    } else if (!isDivisibleInSize) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -1002,7 +1002,7 @@ struct ConvertVectorMaskedLoad final
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
     // Note, per-element-alignment was already verified above.
-    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+    bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -1017,8 +1017,8 @@ struct ConvertVectorMaskedLoad final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isFullyAligned ? 0
-                       : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isDivisibleInSize ? 0
+                          : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1042,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
       passthru = dynamicallyInsertSubVector(
           rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
           origElements);
-    } else if (!isFullyAligned) {
+    } else if (!isDivisibleInSize) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -1070,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
       mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                         linearizedInfo.intraDataOffset,
                                         origElements);
-    } else if (!isFullyAligned) {
+    } else if (!isDivisibleInSize) {
       mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
@@ -1081,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
       result = dynamicallyExtractSubVector(
           rewriter, loc, result, op.getPassThru(),
           linearizedInfo.intraDataOffset, origElements);
-    } else if (!isFullyAligned) {
+    } else if (!isDivisibleInSize) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -1091,6 +1091,38 @@ struct ConvertVectorMaskedLoad final
   }
 };
 
+/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`.
+///
+/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
+/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
+/// (a multi-byte scalar, e.g. i16), where N is some integer.
+///
+/// Put differently, this method checks whether this would be valid:
+///
+///   vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
+///
+/// EXAMPLES:
+///   * vector<4xi4> -> i16 - yes (N = 1)
+///   * vector<4xi4> -> i8 - yes (N = 2)
+///   * vector<3xi4> -> i8 - no (N would have to be 1.5)
+///   * vector<3xi2> -> i16 - no (N would have to be 0.375)
+static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
+                                       Type multiByteScalarTy) {
+  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
+
+  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
+  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
+
+  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
+  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+  assert(multiByteBits % subByteBits == 0 && "Unalagined element types!");
+
+  int elemsPerMultiByte = multiByteBits / subByteBits;
+
+  // TODO: This is a bit too restrictive for vectors rank > 1.
+  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
@@ -1127,7 +1159,8 @@ struct ConvertVectorTransferRead final
     auto origElements = op.getVectorType().getNumElements();
 
     // Note, per-element-alignment was already verified above.
-    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
+    bool isDivisibleInSize =
+        fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
@@ -1146,8 +1179,8 @@ struct ConvertVectorTransferRead final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isFullyAligned ? 0
-                       : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isDivisibleInSize ? 0
+                          : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1171,7 +1204,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (!isFullyAligned) {
+    } else if (!isDivisibleInSize) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -1428,41 +1461,69 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
   return commonConversionPrecondition(rewriter, preconditionType, op);
 }
 
-/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
-/// means that:
-///   1. The `dstType` element type is a multiple of the
-///   `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
-///   is not supported). Let this multiple be `N`.
-///   2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
-///   multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
-///   not supported).
+/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
+///
+/// Alignment means that `subByteVecTy` can be packed into a vector of
+/// `containerTy` elements. More specifically:
+///   1. The bit-width of `containerTy` is a multiple of the
+///      bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
+///      this multiple is 4.
+///   2. The multiple from 1. above divides evenly the number of the (trailing)
+///      elements in `subByteVecTy`.
+///
+/// EXAMPLE 1:
+///   `subByteVecTy = vector<4xi4>`, and
+///   `containerTy = i16`
+///
+/// 4 (= 16 / 4) divides 4 evenly, so both conditions are _met_.
+///
+/// EXAMPLE 2:
+///   `subByteVecTy = vector<3xi4>`, and
+///   `containerTy = i16`
+///
+/// 4 (= 16 / 4) _does not_ divide 3 evenly, so the conditions are _not met_.
+///
+/// EXAMPLE 3:
+///   `subByteVecTy = vector<3xi3>`, and
+///   `containerTy = i16`
+///
+/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
 ///
 /// NOTE: This method assumes that common conversion preconditions are met. In
-/// particular, the element type of `dstType` is assumed to be a multi-byte
-/// type (e.g. i8, i16, i32).
+/// particular, `containerTy` is assumed to be a
+/// multi-byte scalar type (e.g., i8, i16, i32).
 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
-                                                   VectorType subByteVecType,
-                                                   VectorType dstType,
+                                                   VectorType subByteVecTy,
+                                                   Type containerTy,
                                                    Operation *op) {
-  if (!subByteVecType || !dstType)
-    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
-  unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
-  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  assert(containerTy.isIntOrFloat() &&
+         "container element type is not a scalar");
 
-  if (dstElemBitwidth < 8)
-    return rewriter.notifyMatchFailure(
-        op, "the bitwidth of dstType must be greater than or equal to 8");
-  if (dstElemBitwidth % srcElemBitwidth != 0)
-    return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
-  if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+  // TODO: This is validating the inputs rather than checking the conditions
+  // documented above. Replace with an assert.
+  if (!subByteVecTy)
+    return rewriter.notifyMatchFailure(op, "not a vector!");
+
+  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
+  unsigned containerBits = containerTy.getIntOrFloatBitWidth();
+
+  // Enforced by the common pre-conditions.
+  assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!");
+
+  // TODO: Add support other widths (when/if needed)
+  if (subByteBits != 2 && subByteBits != 4)
     return rewriter.notifyMatchFailure(
-        op, "only src bitwidth of 2 or 4 is supported at this moment");
+        op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
+
+  // Condition 1 ("per-element" alignment)
+  if (containerBits % subByteBits != 0)
+    return rewriter.notifyMatchFailure(op, "unalagined element types");
 
-  const int numSrcElemsPerByte = 8 / srcElemBitwidth;
-  if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
+  // Condition 2 ("full" alignment)
+  if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy))
     return rewriter.notifyMatchFailure(
-        op, "the trailing dimension of the input vector of sub-bytes must be a "
-            "multiple of 8 / <sub-byte-width>");
+        op, "not possible to fit this sub-byte vector type into a vector of "
+            "the given multi-byte type");
 
   return success();
 }
@@ -1899,8 +1960,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
       return failure();
 
     // Check general alignment preconditions.
-    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
-                                             conversionOp)))
+    if (failed(alignedConversionPrecondition(
+            rewriter, srcVecType,
+            /*containerTy=*/rewriter.getI8Type(), conversionOp)))
       return failure();
 
     // Perform the rewrite.
@@ -1964,8 +2026,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
     // Check general alignment preconditions. We invert the src/dst type order
     // to reuse the existing precondition logic.
-    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
-                                             truncOp)))
+    if (failed(alignedConversionPrecondition(
+            rewriter, dstVecType,
+            /*containerTy=*/rewriter.getI8Type(), truncOp)))
       return failure();
 
     // Create a new iX -> i8 truncation op.


        


More information about the Mlir-commits mailing list