[Mlir-commits] [mlir] [mlir] Rewrites for I2 to I8 signed and unsigned extension (PR #121298)

Han-Chung Wang llvmlistbot at llvm.org
Sun Jan 12 22:31:18 PST 2025


================
@@ -1179,70 +1184,166 @@ Value BitCastRewriter::genericRewriteStep(
   return runningResult;
 }
 
-/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
-                                    Value srcValue) {
-  VectorType srcVecType = cast<VectorType>(srcValue.getType());
-  assert(srcVecType.getElementType().isSignlessInteger(4) &&
-         "Expected i4 type");
+/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
+/// Where aligned means it satisfies the alignedConversionPreconditions.
+///
+/// Example:
+/// vector<16x16xi2> -> vector<16x4xi8>
+/// vector<16x16xi4> -> vector<16x8xi8>
+static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
+                                      Value subByteVec) {
+  auto srcVecType = cast<VectorType>(subByteVec.getType());
+  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
+  assert(8 % srcBitwidth == 0 &&
+         "Unsupported sub-byte type (not a divisor of i8)");
+  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
+  SmallVector<int64_t> vecShape(srcVecType.getShape());
+  // Adjust last dimension of the vector, so the total size remains the same.
+  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
+  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
+  return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
+}
 
-  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
-  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
-  constexpr int64_t i4Toi8BitwidthFactor = 2;
-  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
-  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
-  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+/// Extracts a signed N-bit sequence from each element of a vector of bytes,
+/// starting at the specified bit index.
+/// The `bitIdx` starts at 0 from the LSB and moves to the left.
+///
+/// Example for a single element:
+/// Extract numBits=2 starting at bitIdx=2
+/// src     = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
+/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
+/// target  = [.   .   .   .   ^   ^   .   .]
+///
+/// The target sequence is [11](decimal=-1) as signed 2-bit integer.
+/// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
+///
+/// src     =                         [01 01 11 10]
+/// shl     = arith.shl(src, 4)    -> [11 10 00 00]
+/// result  = arith.shrsi(shl, 6)  -> [11 11 11 11]
+static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
+                                                  Location loc, Value src,
+                                                  int bitIdx, int numBits) {
+  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
+         "Invalid bitIdx range");
+  auto srcType = cast<VectorType>(src.getType());
+  Value shl = src;
+  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
----------------
hanhanW wrote:

Optional nit: I think we can move the `bitsToShiftLeft` computation before the assertion and replace the `bitIdx <= 8 - numBits` check with `bitsToShiftLeft >= 0`. That condition looks more straightforward to me.

https://github.com/llvm/llvm-project/pull/121298


More information about the Mlir-commits mailing list