[Mlir-commits] [mlir] [mlir][vector][nfc] Update `alignedConversionPrecondition` (PR #122136)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jan 8 08:30:54 PST 2025


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/122136

Adds some comments and renames variables to clarify the usage.


From ffd50ccf09eda8a196def71fb45fb5599faadd9e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 8 Jan 2025 16:28:19 +0000
Subject: [PATCH] [mlir][vector][nfc] Update `alignedConversionPrecondition`

Adds some comments and renames variables to clarify the usage.
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 25 ++++++++++++-------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 181c394edc1d20..d04f302200519e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1069,19 +1069,25 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
   return commonConversionPrecondition(rewriter, preconditionType, op);
 }
 
-/// Verify that source and destination element types meet the precondition for
-/// the supported aligned conversion cases. Alignment means that the either the
-/// source element type is multiple of the destination element type or the other
-/// way around.
+/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
+/// means that:
+///   1. The bit width of the `dstType` element type is a multiple of the
+///   bit width of the `subByteVecType` element type (e.g. i4 vs i8 is OK, but
+///   i3 vs i8 is not supported). Let this multiple be `N`.
+///   2. The number of (trailing) elements in `subByteVecType` is a
+///   multiple of `N` from point 1 (e.g. when targeting i8, 2xi4 is OK, but
+///   3xi4 is not supported).
 ///
-/// NOTE: This method assumes that common conversion preconditions are met.
+/// NOTE: This method assumes that common conversion preconditions are met. In
+/// particular, the element type of `dstType` is assumed to be at least one
+/// byte wide (e.g. i8, i16, i32).
 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
-                                                   VectorType srcType,
+                                                   VectorType subByteVecType,
                                                    VectorType dstType,
                                                    Operation *op) {
-  if (!srcType || !dstType)
+  if (!subByteVecType || !dstType)
     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
-  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+  unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
   unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
 
   // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
@@ -1089,7 +1095,8 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
       (dstElemBitwidth % srcElemBitwidth) != 0)
     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
 
-  if ((srcType.getShape().back() % 2) != 0)
+  const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
+  if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
     return rewriter.notifyMatchFailure(
         op, "Not an even number of i4 elements in trailing dim");
 

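For illustration, the two alignment rules described in the new doc comment can be sketched as a standalone C++ function. This is a minimal sketch, not the MLIR implementation: the function name `isAlignedSubByteConversion` and its plain-integer parameters are hypothetical stand-ins for the `VectorType` queries used in the patch (`getElementTypeBitWidth()` and `getShape().back()`), and the i4-only restriction mentioned in the code comment is deliberately left out.

```cpp
#include <cstdint>

// Hypothetical, MLIR-free sketch of the two alignment rules documented for
// `alignedConversionPrecondition`. Bit widths and the trailing dimension are
// plain integers here rather than queries on mlir::VectorType.
//
// Rule 1: the destination element bit width must be a multiple N of the
//         (sub-byte) source element bit width, e.g. i4 -> i8 gives N = 2,
//         while i3 -> i8 is rejected.
// Rule 2: the trailing dimension of the sub-byte source vector must be a
//         multiple of N, e.g. 2xi4 -> i8 is accepted, 3xi4 -> i8 is not.
bool isAlignedSubByteConversion(unsigned srcElemBitwidth,
                                unsigned dstElemBitwidth,
                                int64_t trailingDimSize) {
  // Guard against division by zero for malformed input.
  if (srcElemBitwidth == 0)
    return false;

  // Rule 1: the dst element width must be an exact multiple of the src width.
  if (dstElemBitwidth % srcElemBitwidth != 0)
    return false;

  // N: how many source elements are packed into one destination element.
  const int64_t numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;

  // Rule 2: the trailing source elements must fill whole destination
  // elements, with no partially filled destination element left over.
  return trailingDimSize % numSrcElemsPerDestElem == 0;
}
```

With these rules, `isAlignedSubByteConversion(4, 8, 2)` accepts `vector<2xi4>` targeting i8, while `isAlignedSubByteConversion(4, 8, 3)` rejects `vector<3xi4>`. The difference from the removed `% 2` check shows up for wider destinations: for i4 -> i32, N is 8, so a trailing dimension of 4 passes the old check but fails the new one (assuming such a type pair reaches this point).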

