[Mlir-commits] [mlir] 929eb50 - [mlir] Rewrites for I2 to I8 signed and unsigned extension (#121298)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 15 00:12:42 PST 2025
Author: ziereis
Date: 2025-01-15T08:12:39Z
New Revision: 929eb500d4c9b3fff0693c49fd55c8093dc1ad62
URL: https://github.com/llvm/llvm-project/commit/929eb500d4c9b3fff0693c49fd55c8093dc1ad62
DIFF: https://github.com/llvm/llvm-project/commit/929eb500d4c9b3fff0693c49fd55c8093dc1ad62.diff
LOG: [mlir] Rewrites for I2 to I8 signed and unsigned extension (#121298)
Adds rewrites for i2 to i8 signed and unsigned extension, similar to the
ones that already exist for i4 to i8 conversion.
I use this for i6 quantized models, and this gives me roughly a 2x
speedup for an i6 4096x4096 dequantization-matmul on an AMD 5950x.
I didn't add the rewrite for i8 to i2 truncation because I currently
don't use it, but if this is needed, I can add it as well.
---------
Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index d04f302200519e..a674a590091815 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1090,15 +1090,20 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
- // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
- if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
- (dstElemBitwidth % srcElemBitwidth) != 0)
- return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+ if (dstElemBitwidth < 8)
+ return rewriter.notifyMatchFailure(
+ op, "the bitwidth of dstType must be greater than or equal to 8");
+ if (dstElemBitwidth % srcElemBitwidth != 0)
+ return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
+ if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+ return rewriter.notifyMatchFailure(
+ op, "only src bitwidth of 2 or 4 is supported at this moment");
- const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
- if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
+ const int numSrcElemsPerByte = 8 / srcElemBitwidth;
+ if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
return rewriter.notifyMatchFailure(
- op, "Not an even number of i4 elements in trailing dim");
+ op, "the trailing dimension of the input vector of sub-bytes must be a "
+ "multiple of 8 / <sub-byte-width>");
return success();
}
@@ -1179,70 +1184,166 @@ Value BitCastRewriter::genericRewriteStep(
return runningResult;
}
-/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
- Value srcValue) {
- VectorType srcVecType = cast<VectorType>(srcValue.getType());
- assert(srcVecType.getElementType().isSignlessInteger(4) &&
- "Expected i4 type");
+/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
+/// Here, "aligned" means the vector satisfies alignedConversionPrecondition.
+///
+/// Example:
+/// vector<16x16xi2> -> vector<16x4xi8>
+/// vector<16x16xi4> -> vector<16x8xi8>
+static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
+ Value subByteVec) {
+ auto srcVecType = cast<VectorType>(subByteVec.getType());
+ int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
+ assert(8 % srcBitwidth == 0 &&
+ "Unsupported sub-byte type (not a divisor of i8)");
+ int64_t numSrcElemsPerByte = 8 / srcBitwidth;
+ SmallVector<int64_t> vecShape(srcVecType.getShape());
+ // Adjust last dimension of the vector, so the total size remains the same.
+ vecShape.back() = vecShape.back() / numSrcElemsPerByte;
+ auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
+ return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
+}
- // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
- SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
- constexpr int64_t i4Toi8BitwidthFactor = 2;
- i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
- auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
- Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+/// Extracts a signed N-bit sequence from each element of a vector of bytes,
+/// starting at the specified bit index.
+/// The `bitIdx` starts at 0 from the LSB and moves to the left.
+///
+/// Example for a single element:
+/// Extract numBits=2 starting at bitIdx=2
+/// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
+/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
+/// target = [. . . . ^ ^ . .]
+///
+/// The target sequence is [11](decimal=-1) as signed 2-bit integer.
+/// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
+///
+/// src = [01 01 11 10]
+/// shl = arith.shl(src, 4) -> [11 10 00 00]
+/// result = arith.shrsi(shl, 6) -> [11 11 11 11]
+static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
+ Location loc, Value src,
+ int bitIdx, int numBits) {
+ auto srcType = cast<VectorType>(src.getType());
+ Value shl = src;
+ int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
+ assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
+ "Invalid bitIdx range");
+ if (bitsToShiftLeft != 0) {
+ Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
+ shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
+ }
- // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
- // byte are place in one vector and the high i4 elements in another vector.
- constexpr int8_t bitsToShift = 4;
- auto shiftValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(i8VecType, bitsToShift));
- Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
- Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
- Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+ int8_t bitsToShiftRight = 8 - numBits;
+ Value shiftRightValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
+ Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
+ return shr;
+}
- // 3. Interleave low and high i8 elements.
- return rewriter.create<vector::InterleaveOp>(loc, low, high);
+/// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
+/// starting at the specified bit index.
+/// The `bitIdx` starts at 0 from the LSB and moves to the left.
+///
+/// Example for a single element:
+/// Extract numBits=2 starting at bitIdx=2
+/// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
+/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
+/// target = [. . . . ^ ^ . .]
+///
+/// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
+/// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
+///
+/// src = [01 01 10 10]
+/// mask = [00 00 00 11]
+/// shr = arith.shrui(src, 2) = [00 01 01 10]
+/// result = arith.andi(shr, mask) = [00 00 00 10]
+/// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
+/// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
+/// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
+/// left when the index is 0.
+static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
+ Location loc, Value src,
+ int bitIdx, int numBits) {
+ assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
+ "Invalid bitIdx range");
+ auto srcType = cast<VectorType>(src.getType());
+ int8_t bitsToShiftRight = bitIdx;
+ Value shr = src;
+ if (bitsToShiftRight != 0) {
+ Value shiftRightValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
+ shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
+ }
+ if (bitIdx + numBits == 8) {
+ return shr;
+ }
+ uint8_t lowBitsMask = (1 << numBits) - 1;
+ Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(srcType, lowBitsMask));
+ return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
}
-/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
- Value srcValue) {
- VectorType srcVecType = cast<VectorType>(srcValue.getType());
+using ExtractNBitsFn =
+ std::function<Value(PatternRewriter &, Location, Value, int, int)>;
+
+/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
+ Value srcValue, const ExtractNBitsFn &extFn) {
+ auto srcVecType = cast<VectorType>(srcValue.getType());
assert(srcVecType.getElementType().isSignlessInteger(4) &&
"Expected i4 type");
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
- SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
- constexpr int64_t i4Toi8BitwidthFactor = 2;
- i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
- auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
- Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
-
- // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
- // byte are placed in one vector and the high i4 elements in another vector.
- constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
- auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
- Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
- lowBitsMaskValues);
- constexpr int8_t highBitsToShift = 4;
- auto highShiftValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
- Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+ Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
+
+ // 2. Extend i4 elements to i8 elements. Low i4 elements of each
+ // byte are placed in one vector and the high i4 elements in another vector.
+ Value low = extFn(rewriter, loc, i8Vector, 0, 4);
+ Value high = extFn(rewriter, loc, i8Vector, 4, 4);
// 3. Interleave low and high i8 elements.
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
+/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
+ Value srcValue, const ExtractNBitsFn &extFn) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(2) &&
+ "Expected i2 type");
+
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+ Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
+
+ // 2. Extract each i2 element
+ // Position 0 (bits 0-1)
+ Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
+ // Position 1 (bits 2-3)
+ Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
+ // Position 2 (bits 4-5)
+ Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
+ // Position 3 (bits 6-7)
+ Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
+
+ // 3. Interleave all 4 elements by first interleaving
+ // even elements and then odd
+ // vec0 = [0,0,0,0],...
+ // vec1 = [1,1,1,1],...
+ // vec2 = [2,2,2,2],...
+ // vec3 = [3,3,3,3],...
+ // 02 = [0,2,0,2,0,2,0,2],...
+ // 13 = [1,3,1,3,1,3,1,3],...
+ // 0213 = [0,1,2,3,...],...
+ Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
+ Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
+ return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
-/// ops that take advantage of high-level information to avoid leaving LLVM to
-/// scramble with peephole optimizations.
+/// ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
@@ -1443,13 +1544,19 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
return failure();
// Perform the rewrite.
+ Location loc = conversionOp.getLoc();
+ const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
+ : extractNBitsPerByteAndExtendToI8;
Value subByteExt;
- if (isSigned) {
- subByteExt =
- rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
- } else {
- subByteExt =
- rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
+ case 2:
+ subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
+ break;
+ case 4:
+ subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
+ break;
+ default:
+ return failure();
}
// Finalize the rewrite.
@@ -1490,6 +1597,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
return failure();
+ // TODO: Add support for truncating to i2.
+ if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
+ return failure();
+
// Check general alignment preconditions. We invert the src/dst type order
// to reuse the existing precondition logic.
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 210025e30d7db5..8d28f248e392d2 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -193,6 +193,25 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
return %1 : vector<8xi17>
}
+
+// Negative test - the trailing dim 1 is not a multiple of 2 (i.e. 8 / 4).
+// CHECK-LABEL: func.func @unaligned_extsi_i4_to_i8(
+func.func @unaligned_extsi_i4_to_i8(%a: vector<1xi4>) -> vector<1xi8> {
+ // CHECK-NOT: arith.bitcast
+ // CHECK: arith.extsi %[[IN:.*]] : vector<1xi4> to vector<1xi8>
+ %0 = arith.extsi %a : vector<1xi4> to vector<1xi8>
+ return %0 : vector<1xi8>
+}
+
+// Negative test - the trailing dim 2 is not a multiple of 4 (i.e. 8 / 2).
+// CHECK-LABEL: func.func @unaligned_extsi_i2_to_i8(
+func.func @unaligned_extsi_i2_to_i8(%a: vector<2xi2>) -> vector<2xi8> {
+ // CHECK-NOT: arith.bitcast
+ // CHECK: arith.extsi %[[IN:.*]] : vector<2xi2> to vector<2xi8>
+ %0 = arith.extsi %a : vector<2xi2> to vector<2xi8>
+ return %0 : vector<2xi8>
+}
+
// CHECK-LABEL: func.func @aligned_extsi_i4_to_i8(
func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
@@ -206,6 +225,31 @@ func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
return %0 : vector<8xi8>
}
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
+func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+ %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32(
func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
@@ -220,8 +264,34 @@ func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
return %0 : vector<8xi32>
}
-// CHECK-LABEL: func.func @aligned_extsi_2d(
-func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
+func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32_2d(
+func.func @aligned_extsi_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
@@ -234,6 +304,32 @@ func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
return %0 : vector<8x32xi32>
}
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32_2d(
+func.func @aligned_extsi_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// Extract bits 0-1
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8>
+// Extract bits 2-3
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8>
+// Extract bits 4-5
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
@@ -292,6 +388,13 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
return %0 : vector<3x8x32xi4>
}
+func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> {
+ // CHECK-NOT: arith.bitcast
+ // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2>
+ %0 = arith.trunci %a : vector<8xi8> to vector<8xi2>
+ return %0 : vector<8xi2>
+}
+
// CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
@@ -305,6 +408,31 @@ func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
return %0 : vector<8xi8>
}
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
+func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+ %0 = arith.extui %a : vector<8xi2> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
// CHECK-LABEL: func.func @aligned_extui_i4_to_i32(
func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
@@ -319,8 +447,34 @@ func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
return %0 : vector<8xi32>
}
-// CHECK-LABEL: func.func @aligned_extui_2d(
-func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
+func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// Extract bits 0-1
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 2-3
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 4-5
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extui %a : vector<8xi2> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i32_2d(
+func.func @aligned_extui_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
@@ -333,6 +487,32 @@ func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
return %0 : vector<8x32xi32>
}
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32_2d(
+func.func @aligned_extui_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// Extract bits 0-1
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// Extract bits 2-3
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// Extract bits 4-5
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// Extract bits 6-7
+// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extui %a : vector<8x32xi2> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
// CHECK-LABEL: func.func @aligned_sitofp(
func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
More information about the Mlir-commits
mailing list