[Mlir-commits] [mlir] [mlir] Rewrites for I2 to I8 signed and unsigned extension (PR #121298)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Jan 10 07:55:32 PST 2025
================
@@ -1179,70 +1183,166 @@ Value BitCastRewriter::genericRewriteStep(
return runningResult;
}
-/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
- Value srcValue) {
- VectorType srcVecType = cast<VectorType>(srcValue.getType());
- assert(srcVecType.getElementType().isSignlessInteger(4) &&
- "Expected i4 type");
+/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
+/// Where aligned means it satisfies the alignedConversionPreconditions.
+///
+/// Example:
+/// vector<16x16xi2> -> vector<16x2xi8>
+/// vector<16x16xi4> -> vector<16x4xi8>
+static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
+ Value subByteVec) {
+ auto srcVecType = cast<VectorType>(subByteVec.getType());
+ int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
+ assert(8 % srcBitwidth == 0 &&
+ "Unsupported sub-byte type (not a divisor of i8)");
+ int64_t bitwidthFactor = 8 / srcBitwidth;
+ SmallVector<int64_t> vecShape(srcVecType.getShape());
+ // Adjust last dimension of the vector, so the total size remains the same.
+ vecShape.back() = vecShape.back() / bitwidthFactor;
+ auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
+ return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
+}
- // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
- SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
- constexpr int64_t i4Toi8BitwidthFactor = 2;
- i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
- auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
- Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
----------------
banach-space wrote:
[nit] This might be my "hardware" bias, but to me an 8-bit vector is literally a vector that holds 8 bits (for comparison, NEON has 128-bit vectors).
```suggestion
/// Extracts a signed N-bit sequence from each element of a vector of bytes,
```
https://github.com/llvm/llvm-project/pull/121298
More information about the Mlir-commits
mailing list