[Mlir-commits] [mlir] Rewrites for I2 to I8 signed and unsigned extension (PR #121298)
Han-Chung Wang
llvmlistbot at llvm.org
Mon Jan 6 23:46:23 PST 2025
================
@@ -1172,70 +1172,167 @@ Value BitCastRewriter::genericRewriteStep(
return runningResult;
}
+/// Bitcasts a vector of sub-byte integers (e.g. i2 or i4 elements) to a vector
+/// of i8. All leading dimensions are kept; the innermost dimension is divided
+/// by the number of source elements packed into one byte (8 / bitwidth).
+/// Requires the source bitwidth to evenly divide 8, a vector of rank >= 1, and
+/// an innermost dimension that is a multiple of that packing factor.
+static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
+  // Checking only `srcBitwidth % 8 != 0` is not enough: a width such as 3
+  // would silently produce a truncated shape, and a width above 8 (e.g. 12)
+  // would make the packing factor 0 and divide by zero below.
+  assert(srcBitwidth > 0 && srcBitwidth < 8 && 8 % srcBitwidth == 0 &&
+         "Invalid source bitwidth");
+  assert(srcVecType.getRank() > 0 && "Expected a vector of rank >= 1");
+  int64_t bitwidthFactor = 8 / srcBitwidth;
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  assert(i8VecShape.back() % bitwidthFactor == 0 &&
+         "Innermost dimension must be a multiple of the packing factor");
+  i8VecShape.back() = i8VecShape.back() / bitwidthFactor;
+  // NOTE(review): scalable-dimension flags of srcVecType are not forwarded to
+  // i8VecType; confirm that callers only pass fixed-length vectors.
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  return rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+}
+
+/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
+/// starting at the specified bit index, and sign-extends it to 8 bits.
+/// For example, with `bitIdx = 2` and `numBits = 2`, each byte b7..b0 becomes
+/// the sign-extended value of the field b3b2.
+///
+/// A left shift first moves the field's top bit to bit 7. An arithmetic right
+/// shift then brings the field down to bit 0 while replicating its sign bit.
+/// Zero-amount shifts are not emitted.
+static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
+                                          Location loc, Value src, int bitIdx,
+                                          int numBits) {
+  auto srcType = cast<VectorType>(src.getType());
+  assert(srcType.getElementType().isInteger(8) && "Expected an i8 vector");
+  assert(numBits > 0 && numBits <= 8 && "Invalid numBits");
+  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && "Invalid bitIdx range");
+
+  Value shl = src;
+  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
+  if (bitsToShiftLeft != 0) {
+    Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
+    shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
+  }
+
+  int8_t bitsToShiftRight = 8 - numBits;
+  // A full-byte field (numBits == 8) already is its own sign extension.
+  if (bitsToShiftRight == 0)
+    return shl;
+  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
+  return rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
+}
+
/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
+/// The two nibbles of every bitcast source byte are sign-extended into
+/// separate i8 vectors and then interleaved. As a result, the low nibble of
+/// byte j becomes result element 2*j and the high nibble becomes element
+/// 2*j + 1.
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
assert(srcVecType.getElementType().isSignlessInteger(4) &&
"Expected i4 type");
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
- SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
- constexpr int64_t i4Toi8BitwidthFactor = 2;
- i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
- auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
- Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+ Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
// 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
// byte are placed in one vector and the high i4 elements in another vector.
- constexpr int8_t bitsToShift = 4;
- auto shiftValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(i8VecType, bitsToShift));
- Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
- Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
- Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+ // Low nibble = bits 0-3, high nibble = bits 4-7 of each byte.
+ Value low = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 4);
+ Value high = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 4);
// 3. Interleave low and high i8 elements.
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
+/// Rewrite the i2 -> i8 signed extension into a sequence of shifts and
+/// interleaves to avoid leaving LLVM to scramble with peephole optimizations.
+///
+/// Each byte of the bitcast source packs four i2 elements. The rewrite maps
+/// bits [2k, 2k+1] of byte j to result element 4*j + k.
+static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(2) &&
+         "Expected i2 type");
+
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
+
+  // 2. Sign-extend each 2-bit field of every byte into its own i8 vector.
+  // Element 0 (bits 0-1)
+  Value elem0 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 2);
+  // Element 1 (bits 2-3)
+  Value elem1 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 2, 2);
+  // Element 2 (bits 4-5)
+  Value elem2 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 2);
+  // Element 3 (bits 6-7)
+  Value elem3 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 6, 2);
+
+  // 3. Restore the original element order with two levels of interleaving:
+  // first the even elements (0, 2) and the odd elements (1, 3), then the two
+  // results. Labelling every lane by the field it came from:
+  //   elem0 = [0, 0, ...]  elem1 = [1, 1, ...]
+  //   elem2 = [2, 2, ...]  elem3 = [3, 3, ...]
+  //   interleave(elem0, elem2)     = [0, 2, 0, 2, ...]
+  //   interleave(elem1, elem3)     = [1, 3, 1, 3, ...]
+  //   interleave of the two above  = [0, 1, 2, 3, 0, 1, 2, 3, ...]
+  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
+/// Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
+/// starting at the specified bit index.
+Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter, Location loc,
----------------
hanhanW wrote:
nit: add `static` keyword
https://github.com/llvm/llvm-project/pull/121298
More information about the Mlir-commits
mailing list