[Mlir-commits] [mlir] [mlir][Vector] Refactor VectorEmulateNarrowType.cpp (PR #123529)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Mar 10 09:59:50 PDT 2025


================
@@ -1428,41 +1461,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
   return commonConversionPrecondition(rewriter, preconditionType, op);
 }
 
-/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
-/// means that:
-///   1. The `dstType` element type is a multiple of the
-///   `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
-///   is not supported). Let this multiple be `N`.
-///   2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
-///   multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
-///   not supported).
+/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
+///
+/// Alignment means that `subByteVecTy` can be packed into a vector of
+/// `containerTy` elements. More specifically:
+///   1. The bit-width of `containerTy` is a multiple of the
+///      bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
+///      this multiple is 4.
+///   2. The multiple from 1. above divides evenly the number of the (trailing)
+///      elements in `subByteVecTy`.
+///
+/// EXAMPLE 1:
+///   `subByteVecTy = vector<2xi4>`, and
+///   `containerTy = i16`
+///
+/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
+///
+/// EXAMPLE 2:
+///   `subByteVecTy = vector<3xi4>`, and
+///   `containerTy = i16`
+///
+/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
+///
+/// EXAMPLE 3:
+///   `subByteVecTy = vector<3xi3>`, and
+///   `containerTy = i16`
+///
+/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
 ///
 /// NOTE: This method assumes that common conversion preconditions are met. In
-/// particular, the element type of `dstType` is assumed to be a multi-byte
-/// type (e.g. i8, i16, i32).
+/// particular, `containerTy` is assumed to be a
+/// multi-byte scalar type (e.g., i8, i16, i32).
 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
-                                                   VectorType subByteVecType,
-                                                   VectorType dstType,
+                                                   VectorType subByteVecTy,
+                                                   Type containerTy,
                                                    Operation *op) {
-  if (!subByteVecType || !dstType)
-    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
-  unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
-  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  // TODO: This is validating the inputs rather than checking the conditions
+  // documented above. Replace with an assert.
+  if (!subByteVecTy)
+    return rewriter.notifyMatchFailure(op, "not a vector!");
 
-  if (dstElemBitwidth < 8)
-    return rewriter.notifyMatchFailure(
-        op, "the bitwidth of dstType must be greater than or equal to 8");
-  if (dstElemBitwidth % srcElemBitwidth != 0)
-    return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
-  if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+  // TODO: This is validating the inputs rather than checking the conditions
+  // documented above. Replace with an assert.
+  if (!containerTy.isIntOrFloat())
+    return rewriter.notifyMatchFailure(op, "not a scalar!");
+
+  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
+  unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
+
+  // Enforced by the common pre-conditions.
+  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+
+  // TODO: Remove this condition - the assert above (and
+  // commonConversionPrecondtion) takes care of that.
+  if (multiByteBits < 8)
+    return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
+
+  // TODO: Add support other widths (when/if needed)
+  if (subByteBits != 2 && subByteBits != 4)
     return rewriter.notifyMatchFailure(
-        op, "only src bitwidth of 2 or 4 is supported at this moment");
+        op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
----------------
banach-space wrote:

> I know some downstream projects try to use 1bit, and I think upstream shouldn't trivially block it in this way.

Oh, definitely not trying to block anyone. This is merely trying to document the existing assumptions. Note that this condition is already present: https://github.com/llvm/llvm-project/blob/5ce4045384d7c2544185b0dbcb6222d06beb47dc/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp#L1457-L1459

> They can contribute i1 tests for sure but overall the code here should support 1-bit scenarios without problem.

They would be welcome with praise and gratitude :) 
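
For readers following along, the numbered conditions in the doc comment above can be sketched as a standalone check. This is only an illustration, not the code in VectorEmulateNarrowType.cpp: `SubByteVec` and `isAlignedForContainer` are hypothetical names, the vector type is reduced to its trailing dimension and element bit-width, and the sketch applies conditions 1 and 2 literally (the multiple must divide the trailing dimension), along with the existing 2-bit/4-bit restriction discussed here.

```cpp
#include <cstdint>

// Hypothetical stand-in for a sub-byte vector type. Only the trailing
// dimension and the element bit-width matter for the alignment check.
struct SubByteVec {
  int64_t trailingDim; // number of elements in the trailing dimension
  unsigned elemBits;   // bit-width of each element, e.g. 2 or 4
};

// Returns true if `vec` can be packed into whole containers that are
// `containerBits` wide, following the documented conditions literally.
bool isAlignedForContainer(const SubByteVec &vec, unsigned containerBits) {
  // The container is assumed to be a multi-byte scalar (i8, i16, i32, ...).
  if (containerBits < 8 || containerBits % 8 != 0)
    return false;

  // Existing assumption under discussion: only 2-bit and 4-bit elements.
  if (vec.elemBits != 2 && vec.elemBits != 4)
    return false;

  // Condition 1: the container width is a multiple of the element width.
  if (containerBits % vec.elemBits != 0)
    return false;

  // Condition 2: that multiple evenly divides the trailing dimension.
  int64_t elemsPerContainer = containerBits / vec.elemBits;
  return vec.trailingDim % elemsPerContainer == 0;
}
```

Under these assumptions, a trailing dimension of 4 with `i4` elements fits an `i16` container, while 3 `i4` elements or any `i3` elements do not. Relaxing the 2-bit/4-bit check is the single line a 1-bit experiment would need to change in this sketch.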

https://github.com/llvm/llvm-project/pull/123529

