[Mlir-commits] [mlir] Rewrites for I2 to I8 signed and unsigned extension (PR #121298)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Jan 5 09:30:01 PST 2025


================
@@ -1233,6 +1233,117 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(2) &&
+         "Expected i2 type");
+
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i2Toi8BitwidthFactor = 4;
+  i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // Element 0 (bits 0-1): shl by 6, then arithmetic shr by 6 to sign-extend.
+  constexpr int8_t shiftConst6 = 6;
+  auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
+  auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
+  Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
+  Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
+
+  // Element 1 (bits 2-3): shl by 4, then arithmetic shr by 6 to sign-extend.
+  constexpr int8_t shiftConst4 = 4;
+  auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
+  auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
+  Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
+  Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
+
+  // Element 2 (bits 4-5): shl by 2, then arithmetic shr by 6 to sign-extend.
+  constexpr int8_t shiftConst2 = 2;
+  auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
+  auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+  Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
+  Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
+
+  // Element 3 (bits 6-7): already in the top bits; arithmetic shr by 6 only.
+  Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
+
+  // Interleave all 4 elements by first interleaving even elements and then odd:
+  // elem0 = [0,0,0,0]
+  // elem1 = [1,1,1,1]
+  // elem2 = [2,2,2,2]
+  // elem3 = [3,3,3,3]
+  // 02    = [0,2,0,2]
+  // 13    = [1,3,1,3]
+  // 0213  = [0,1,2,3]
+  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
+/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(2) &&
+         "Expected i2 type");
+
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i2Toi8BitwidthFactor = 4;
+  i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extract each i2 element using shifts and masks
+  constexpr uint8_t mask = 3; // Mask for 2 bits: [0000 0011]
+  auto maskAttr = DenseElementsAttr::get(i8VecType, mask);
+  auto maskValues = rewriter.create<arith::ConstantOp>(loc, maskAttr);
+
+  // Element 0 (bits 0-1)
+  Value elem0 = rewriter.create<arith::AndIOp>(loc, i8Vector, maskValues);
+
+  // Element 1 (bits 2-3)
+  constexpr int8_t shift1 = 2;
+  auto shiftAttr1 = DenseElementsAttr::get(i8VecType, shift1);
+  auto shiftValues1 = rewriter.create<arith::ConstantOp>(loc, shiftAttr1);
+  Value shifted1 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues1);
+  Value elem1 = rewriter.create<arith::AndIOp>(loc, shifted1, maskValues);
----------------
banach-space wrote:

This block is repeated multiple times. Also, similar blocks are present in other functions. 

Could you try introducing a utility function instead? For example, `Value extractNBitsFromVal(int bitIdxStart, int bitIdxEnd, Value src)`?

https://github.com/llvm/llvm-project/pull/121298


More information about the Mlir-commits mailing list