[Mlir-commits] [mlir] Rewrites for I2 to I8 signed and unsigned extension (PR #121298)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 6 04:14:38 PST 2025
================
@@ -1233,6 +1233,117 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
+/// Lowers a signed i2 -> i8 extension to an explicit bitcast / shift /
+/// interleave sequence, so that the unpacking is spelled out up front rather
+/// than left for LLVM's peephole optimizations to rediscover.
+///
+/// The source vector is reinterpreted as bytes holding four i2 fields each.
+/// Every field is sign-extended by shifting it left until it occupies the two
+/// most significant bits of its byte and then arithmetically shifting right by
+/// 6. The four per-field vectors are finally interleaved back into the
+/// original element order.
+static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  auto packedType = cast<VectorType>(srcValue.getType());
+  assert(packedType.getElementType().isSignlessInteger(2) &&
+         "Expected i2 type");
+
+  // Bitcast vector<Xxi2> -> vector<X/4xi8>: four i2 fields share one byte, so
+  // the trailing dimension shrinks by a factor of 4.
+  constexpr int64_t fieldsPerByte = 4;
+  SmallVector<int64_t> byteShape = llvm::to_vector(packedType.getShape());
+  byteShape.back() /= fieldsPerByte;
+  auto byteVecType = VectorType::get(byteShape, rewriter.getI8Type());
+  Value bytes =
+      rewriter.create<vector::BitCastOp>(loc, byteVecType, srcValue);
+
+  // Builds a constant of the byte-vector type with `amount` in every lane.
+  auto splatI8 = [&](int8_t amount) {
+    auto attr = DenseElementsAttr::get(byteVecType, amount);
+    return rewriter.create<arith::ConstantOp>(loc, attr);
+  };
+
+  // Shifting right arithmetically by 6 sign-extends whatever field currently
+  // sits in bits 6-7 of each byte.
+  auto shiftBy6 = splatI8(6);
+  auto signExtendTopField = [&](Value positioned) -> Value {
+    return rewriter.create<arith::ShRSIOp>(loc, positioned, shiftBy6);
+  };
+
+  // Field 0 lives in bits 0-1: move it up by 6.
+  Value raised0 = rewriter.create<arith::ShLIOp>(loc, bytes, shiftBy6);
+  Value field0 = signExtendTopField(raised0);
+
+  // Field 1 lives in bits 2-3: move it up by 4.
+  auto shiftBy4 = splatI8(4);
+  Value raised1 = rewriter.create<arith::ShLIOp>(loc, bytes, shiftBy4);
+  Value field1 = signExtendTopField(raised1);
+
+  // Field 2 lives in bits 4-5: move it up by 2.
+  auto shiftBy2 = splatI8(2);
+  Value raised2 = rewriter.create<arith::ShLIOp>(loc, bytes, shiftBy2);
+  Value field2 = signExtendTopField(raised2);
+
+  // Field 3 already occupies bits 6-7, so no left shift is needed.
+  Value field3 = signExtendTopField(bytes);
+
+  // Restore element order with a two-level interleave:
+  //   field0 = [0,0,0,0]    field2 = [2,2,2,2]  ->  even = [0,2,0,2,...]
+  //   field1 = [1,1,1,1]    field3 = [3,3,3,3]  ->  odd  = [1,3,1,3,...]
+  //   interleave(even, odd)                     ->  [0,1,2,3,0,1,2,3,...]
+  Value evenFields =
+      rewriter.create<vector::InterleaveOp>(loc, field0, field2);
+  Value oddFields =
+      rewriter.create<vector::InterleaveOp>(loc, field1, field3);
+  return rewriter.create<vector::InterleaveOp>(loc, evenFields, oddFields);
+}
+
+/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(2) &&
+ "Expected i2 type");
+
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+ SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+ constexpr int64_t i2Toi8BitwidthFactor = 4;
+ i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+ auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+ Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+ // 2. Extract each i2 element using shifts and masks
+ constexpr uint8_t mask = 3; // Mask for 2 bits: [0000 0011]
+ auto maskAttr = DenseElementsAttr::get(i8VecType, mask);
+ auto maskValues = rewriter.create<arith::ConstantOp>(loc, maskAttr);
+
+ // Element 0 (bits 0-1)
+ Value elem0 = rewriter.create<arith::AndIOp>(loc, i8Vector, maskValues);
+
+ // Element 1 (bits 2-3)
+ constexpr int8_t shift1 = 2;
+ auto shiftAttr1 = DenseElementsAttr::get(i8VecType, shift1);
+ auto shiftValues1 = rewriter.create<arith::ConstantOp>(loc, shiftAttr1);
+ Value shifted1 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues1);
+ Value elem1 = rewriter.create<arith::AndIOp>(loc, shifted1, maskValues);
----------------
ziereis wrote:
Very good idea, thank you. I created the function with ```int bitIdx, int numBits``` instead of a start and end index because this was easier for me to reason about.
https://github.com/llvm/llvm-project/pull/121298
More information about the Mlir-commits
mailing list