[Mlir-commits] [mlir] [mlir][vector][nfc] Update `alignedConversionPrecondition` (PR #122136)

llvmlistbot at llvm.org
Wed Jan 8 08:31:34 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

Adds some comments and renames variables to clarify how `alignedConversionPrecondition` is used.


---
Full diff: https://github.com/llvm/llvm-project/pull/122136.diff


1 File Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+16-9) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 181c394edc1d20..d04f302200519e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1069,19 +1069,25 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
   return commonConversionPrecondition(rewriter, preconditionType, op);
 }
 
-/// Verify that source and destination element types meet the precondition for
-/// the supported aligned conversion cases. Alignment means that the either the
-/// source element type is multiple of the destination element type or the other
-/// way around.
+/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
+/// means that:
+///   1. The `dstType` element type is a multiple of the
+///   `subByteVecType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
+///   is not supported). Let this multiple be `N`.
+///   2. The number of the (trailing) elements in `subByteVecType` is a
+///   multiple of `N` (see 1.), e.g. when targeting i8, 2xi4 is OK, but 3xi4 is
+///   not supported.
 ///
-/// NOTE: This method assumes that common conversion preconditions are met.
+/// NOTE: This method assumes that common conversion preconditions are met. In
+/// particular, the element type of `dstType` is assumed to be a multi-byte
+/// type (e.g. i8, i16, i32).
 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
-                                                   VectorType srcType,
+                                                   VectorType subByteVecType,
                                                    VectorType dstType,
                                                    Operation *op) {
-  if (!srcType || !dstType)
+  if (!subByteVecType || !dstType)
     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
-  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+  unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
   unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
 
   // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
@@ -1089,7 +1095,8 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
       (dstElemBitwidth % srcElemBitwidth) != 0)
     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
 
-  if ((srcType.getShape().back() % 2) != 0)
+  const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
+  if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
     return rewriter.notifyMatchFailure(
         op, "Not an even number of i4 elements in trailing dim");
 

``````````

</details>
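To make the two alignment conditions from the updated doc comment concrete, here is a minimal, MLIR-free C++ sketch. It is an illustration under stated assumptions, not the upstream implementation: bit widths and the trailing dimension size are passed as plain integers instead of `mlir::VectorType`, the function name `isAlignedSubByteConversion` is hypothetical, and the first half of the bit-width check (source must be 4-bit, destination at least 8-bit) is assumed from the `// Only {s}i4 -> (size_of({{s}i/f}) >= 8)` comment, because that condition is elided from the hunk.

```cpp
#include <cstdint>

// Hypothetical, standalone sketch of the checks in
// alignedConversionPrecondition; not the actual MLIR code.
constexpr bool isAlignedSubByteConversion(unsigned srcElemBitwidth,
                                          unsigned dstElemBitwidth,
                                          int64_t trailingDimSize) {
  // ASSUMPTION: mirrors the "Only {s}i4 -> (size_of(...) >= 8)" comment;
  // the exact upstream condition is not visible in the diff.
  if (srcElemBitwidth != 4 || dstElemBitwidth < 8)
    return false;

  // Condition 1: the destination bit width is a multiple of the source one.
  if (dstElemBitwidth % srcElemBitwidth != 0)
    return false;

  // N: how many sub-byte source elements fit into one destination element.
  const int64_t numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;

  // Condition 2: the trailing dimension holds a whole number of groups of N.
  return trailingDimSize % numSrcElemsPerDestElem == 0;
}

// i4 -> i8: N = 2, so 2xi4 is accepted and 3xi4 is rejected.
static_assert(isAlignedSubByteConversion(4, 8, 2), "2xi4 -> i8 is aligned");
static_assert(!isAlignedSubByteConversion(4, 8, 3), "3xi4 -> i8 is not");
// i3 is not a supported source element type.
static_assert(!isAlignedSubByteConversion(3, 8, 8), "i3 is rejected");
```

For an i8 destination, `N` is 2, so the new `% numSrcElemsPerDestElem` test coincides with the previous hard-coded `% 2` check; the rename to `numSrcElemsPerDestElem` simply makes the dependence on the destination bit width explicit.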


https://github.com/llvm/llvm-project/pull/122136

