[Mlir-commits] [mlir] [mlir] Rewrites for I2 to I8 signed and unsigned extension (PR #121298)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jan 8 09:03:40 PST 2025


================
@@ -1172,70 +1176,144 @@ Value BitCastRewriter::genericRewriteStep(
   return runningResult;
 }
 
-/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
-                                    Value srcValue) {
-  VectorType srcVecType = cast<VectorType>(srcValue.getType());
-  assert(srcVecType.getElementType().isSignlessInteger(4) &&
-         "Expected i4 type");
+/// Takes an aligned sub-byte vector as input and bitcasts it to a vector of i8.
+///
+/// Example:
+/// vector<16x16xi2> -> vector<16x2xi8>
+/// vector<16x16xi4> -> vector<16x4xi8>
+static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  auto srcVecType = cast<VectorType>(srcValue.getType());
+  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
+  assert(8 % srcBitwidth == 0 && "Invalid source bitwidth");
+  int64_t bitwidthFactor = 8 / srcBitwidth;
+  SmallVector<int64_t> vecShape(srcVecType.getShape());
+  // adjust last dimension of the vector so the total size remains the same.
+  vecShape.back() = vecShape.back() / bitwidthFactor;
+  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
+  return rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+}
 
-  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
-  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
-  constexpr int64_t i4Toi8BitwidthFactor = 2;
-  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
-  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
-  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
+/// starting at the specified bit index.
+///
+/// Example:
+/// extract numBits=2 starting at bitIdx=2
+/// src    =               [0101|11|10]
+/// shl    = src << 4    -> [11100000]
+/// result = shl >> 6    -> [11111111]
+static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
+                                          Location loc, Value src, int bitIdx,
+                                          int numBits) {
+  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
+         "Invalid bitIdx range");
+  auto srcType = cast<VectorType>(src.getType());
+  Value shl = src;
+  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
+  if (bitsToShiftLeft != 0) {
+    Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
+    shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
+  }
 
-  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
-  // byte are place in one vector and the high i4 elements in another vector.
-  constexpr int8_t bitsToShift = 4;
-  auto shiftValues = rewriter.create<arith::ConstantOp>(
-      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
-  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
-  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
-  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+  int8_t bitsToShiftRight = 8 - numBits;
+  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
+  Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
+  return shr;
+}
 
-  // 3. Interleave low and high i8 elements.
-  return rewriter.create<vector::InterleaveOp>(loc, low, high);
+/// Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
+/// starting at the specified bit index.
+///
+/// Example:
+/// extract numBits=2 starting at bitIdx=2
+/// src                 = [0101|10|10]
+/// mask                = [00000011]
+/// shr    = src >> 2   = [00010110]
+/// result = shr & mask = [00000010]
+static Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter,
+                                            Location loc, Value src, int bitIdx,
+                                            int numBits) {
+  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
+         "Invalid bitIdx range");
+  auto srcType = cast<VectorType>(src.getType());
+  int8_t bitsToShiftRight = bitIdx;
+  Value shr = src;
+  if (bitsToShiftRight != 0) {
+    Value shiftRightValues = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
+    shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
+  }
+  if (bitIdx + numBits == 8) {
+    return shr;
+  }
+  uint8_t lowBitsMask = (1 << numBits) - 1;
+  Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(srcType, lowBitsMask));
+  return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
 }
 
-/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
-                                      Value srcValue) {
-  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+using ExtractNBitsFn =
+    std::function<Value(PatternRewriter &, Location, Value, int, int)>;
+
+/// Rewrite the i4 -> i8  extension into a sequence of shuffles and
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
+                              Value srcValue, const ExtractNBitsFn &extFn) {
+  auto srcVecType = cast<VectorType>(srcValue.getType());
   assert(srcVecType.getElementType().isSignlessInteger(4) &&
          "Expected i4 type");
 
   // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
-  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
-  constexpr int64_t i4Toi8BitwidthFactor = 2;
-  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
-  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
-  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
-
-  // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
-  //  byte are placed in one vector and the high i4 elements in another vector.
-  constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
-  auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
-      loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
-  Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
-                                             lowBitsMaskValues);
-  constexpr int8_t highBitsToShift = 4;
-  auto highShiftValues = rewriter.create<arith::ConstantOp>(
-      loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
-  Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
+
+  // 2. Extend i4 elements to i8 elements. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
+  Value high = extFn(rewriter, loc, i8Vector, 4, 4);
 
   // 3. Interleave low and high i8 elements.
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i2 -> i8  extension into a sequence of shuffles and
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
+                              Value srcValue, const ExtractNBitsFn &extFn) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(2) &&
+         "Expected i2 type");
+
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
+
+  // 2. Extract each i2 element
+  // Element 0 (bits 0-1)
+  Value elem0 = extFn(rewriter, loc, i8Vector, 0, 2);
+  // Element 1 (bits 2-3)
+  Value elem1 = extFn(rewriter, loc, i8Vector, 2, 2);
+  // Element 2 (bits 4-5)
+  Value elem2 = extFn(rewriter, loc, i8Vector, 4, 2);
+  // Element 3 (bits 6-7)
+  Value elem3 = extFn(rewriter, loc, i8Vector, 6, 2);
----------------
banach-space wrote:

Are these vectors or scalars?

https://github.com/llvm/llvm-project/pull/121298


More information about the Mlir-commits mailing list