[Mlir-commits] [mlir] [mlir][vector][nfc] Update `alignedConversionPrecondition` (PR #122136)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Jan 8 08:30:54 PST 2025
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/122136
Adds some comments and renames variables to clarify the usage.
From ffd50ccf09eda8a196def71fb45fb5599faadd9e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 8 Jan 2025 16:28:19 +0000
Subject: [PATCH] [mlir][vector][nfc] Update `alignedConversionPrecondition`
Adds some comments and renames variables to clarify the usage.
---
.../Transforms/VectorEmulateNarrowType.cpp | 25 ++++++++++++-------
1 file changed, 16 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 181c394edc1d20..d04f302200519e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1069,19 +1069,25 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
return commonConversionPrecondition(rewriter, preconditionType, op);
}
-/// Verify that source and destination element types meet the precondition for
-/// the supported aligned conversion cases. Alignment means that the either the
-/// source element type is multiple of the destination element type or the other
-/// way around.
+/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
+/// means that:
+/// 1. The `dstType` element bit width is a multiple of the
+/// `subByteVecType` element bit width (e.g. i4 vs i8 is OK, but i3 vs i8
+/// is not supported). Let this multiple be `N`.
+/// 2. The number of elements in the trailing dimension of `subByteVecType`
+/// is a multiple of `N` from point 1 (e.g. when targeting i8, 2xi4 is OK,
+/// but 3xi4 is not supported).
///
-/// NOTE: This method assumes that common conversion preconditions are met.
+/// NOTE: This method assumes that common conversion preconditions are met. In
+/// particular, the element type of `dstType` is assumed to be at least one
+/// byte wide (e.g. i8, i16, i32).
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
- VectorType srcType,
+ VectorType subByteVecType,
VectorType dstType,
Operation *op) {
- if (!srcType || !dstType)
+ if (!subByteVecType || !dstType)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
- unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+ unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
// Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
@@ -1089,7 +1095,8 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
(dstElemBitwidth % srcElemBitwidth) != 0)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
- if ((srcType.getShape().back() % 2) != 0)
+ const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
+ if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
return rewriter.notifyMatchFailure(
op, "Not an even number of i4 elements in trailing dim");