[Mlir-commits] [mlir] [mlir][Vector] Refactor VectorEmulateNarrowType.cpp (PR #123529)
Alan Li
llvmlistbot at llvm.org
Mon Mar 10 07:53:23 PDT 2025
================
@@ -1428,41 +1461,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
return commonConversionPrecondition(rewriter, preconditionType, op);
}
-/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
-/// means that:
-/// 1. The `dstType` element type is a multiple of the
-/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
-/// is not supported). Let this multiple be `N`.
-/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
-/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
-/// not supported).
+/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
+///
+/// Alignment means that `subByteVecTy` can be packed into a vector of
+/// `containerTy` elements. More specifically:
+/// 1. The bit-width of `containerTy` is a multiple of the
+/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
+/// this multiple is 4.
+/// 2. The number of (trailing) elements in `subByteVecTy` is a multiple of
+/// the value from 1. above. For example, for `i4` and `i16`, the trailing
+/// dimension of `subByteVecTy` must be a multiple of 4.
+///
+/// EXAMPLE 1:
+/// `subByteVecTy = vector<8xi4>`, and
+/// `containerTy = i16`
+///
+/// 4 ( = 16 / 4) evenly divides 8, hence both conditions are _met_.
+///
+/// EXAMPLE 2:
+/// `subByteVecTy = vector<3xi4>`, and
+/// `containerTy = i16`
+///
+/// 4 ( = 16 / 4) _does not_ evenly divide 3, hence the conditions are _not met_.
+///
+/// EXAMPLE 3:
+/// `subByteVecTy = vector<3xi3>`, and
+/// `containerTy = i16`
+///
+/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
///
/// NOTE: This method assumes that common conversion preconditions are met. In
-/// particular, the element type of `dstType` is assumed to be a multi-byte
-/// type (e.g. i8, i16, i32).
+/// particular, `containerTy` is assumed to be a
+/// multi-byte scalar type (e.g., i8, i16, i32).
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
- VectorType subByteVecType,
- VectorType dstType,
+ VectorType subByteVecTy,
+ Type containerTy,
Operation *op) {
- if (!subByteVecType || !dstType)
- return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
- unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
- unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+ // TODO: This is validating the inputs rather than checking the conditions
+ // documented above. Replace with an assert.
+ if (!subByteVecTy)
+ return rewriter.notifyMatchFailure(op, "not a vector!");
- if (dstElemBitwidth < 8)
- return rewriter.notifyMatchFailure(
- op, "the bitwidth of dstType must be greater than or equal to 8");
- if (dstElemBitwidth % srcElemBitwidth != 0)
- return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
- if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+ // TODO: This is validating the inputs rather than checking the conditions
+ // documented above. Replace with an assert.
+ if (!containerTy.isIntOrFloat())
+ return rewriter.notifyMatchFailure(op, "not a scalar!");
+
+ unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
+ unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
+
+ // Enforced by the common pre-conditions.
+ assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+
+ // TODO: Remove this condition - the assert above (and
+ // commonConversionPrecondtion) takes care of that.
+ if (multiByteBits < 8)
+ return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
+
+ // TODO: Add support other widths (when/if needed)
+ if (subByteBits != 2 && subByteBits != 4)
return rewriter.notifyMatchFailure(
- op, "only src bitwidth of 2 or 4 is supported at this moment");
+ op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
----------------
lialan wrote:
1 bit?
https://github.com/llvm/llvm-project/pull/123529
More information about the Mlir-commits
mailing list