[Mlir-commits] [mlir] Rewrites for I2 to I8 signed and unsigned extension (PR #121298)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 7 10:24:40 PST 2025
https://github.com/ziereis updated https://github.com/llvm/llvm-project/pull/121298
>From 25333f82f3fd4517f34aea65e57121610c8c6634 Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Sun, 29 Dec 2024 16:35:54 +0100
Subject: [PATCH 1/5] [mlir][vector] Add rewrites for i2 -> i8 signed and unsigned extension
---
.../Transforms/VectorEmulateNarrowType.cpp | 137 +++++++++++++++-
.../Vector/vector-rewrite-narrow-types.mlir | 149 +++++++++++++++++-
2 files changed, 276 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 181c394edc1d20..519e6e68bc1b0a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1084,8 +1084,8 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
- // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
- if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+ // Only {s}i4/i2 -> (size_of({{s}i/f}) >= 8) are supported for now.
+ if ((srcElemBitwidth != 4 && srcElemBitwidth != 2) || dstElemBitwidth < 8 ||
(dstElemBitwidth % srcElemBitwidth) != 0)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
@@ -1233,6 +1233,117 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
+/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(2) &&
+ "Expected i2 type");
+
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+ SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+ constexpr int64_t i2Toi8BitwidthFactor = 4;
+ i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+ auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+ Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+ // Position 0 (bits 0-1)
+ constexpr int8_t shiftConst6 = 6;
+ auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
+ auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
+ Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
+ Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
+
+ // Position 1 (bits 2-3)
+ constexpr int8_t shiftConst4 = 4;
+ auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
+ auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
+ Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
+ Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
+
+ // Position 1 (bits 4-5)
+ constexpr int8_t shiftConst2 = 2;
+ auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
+ auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+ Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
+ Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
+
+ // Position 3 (bits 6-7)
+ Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
+
+ // interleave all 4 elements by first interleaving even elements and then odd
+ // elem0 = [0,0,0,0]
+ // elem1 = [1,1,1,1]
+ // elem2 = [2,2,2,2]
+ // elem3 = [3,3,3,3]
+ // 02 = [0,2,0,2]
+ // 13 = [1,3,1,3]
+ // 0213 = [0,1,2,3]
+ Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+ Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+ return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
+/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(2) &&
+ "Expected i2 type");
+
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+ SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+ constexpr int64_t i2Toi8BitwidthFactor = 4;
+ i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+ auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+ Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+ // 2. Extract each i2 element using shifts and masks
+ constexpr uint8_t mask = 3; // Mask for 2 bits: [0000 0011]
+ auto maskAttr = DenseElementsAttr::get(i8VecType, mask);
+ auto maskValues = rewriter.create<arith::ConstantOp>(loc, maskAttr);
+
+ // Element 0 (bits 0-1)
+ Value elem0 = rewriter.create<arith::AndIOp>(loc, i8Vector, maskValues);
+
+ // Element 1 (bits 2-3)
+ constexpr int8_t shift1 = 2;
+ auto shiftAttr1 = DenseElementsAttr::get(i8VecType, shift1);
+ auto shiftValues1 = rewriter.create<arith::ConstantOp>(loc, shiftAttr1);
+ Value shifted1 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues1);
+ Value elem1 = rewriter.create<arith::AndIOp>(loc, shifted1, maskValues);
+
+ // Element 2 (bits 4-5)
+ constexpr int8_t shift2 = 4;
+ auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shift2);
+ auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+ Value shifted2 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues2);
+ Value elem2 = rewriter.create<arith::AndIOp>(loc, shifted2, maskValues);
+
+ // Element 3 (bits 6-7)
+ constexpr int8_t shift3 = 6;
+ auto shiftAttr3 = DenseElementsAttr::get(i8VecType, shift3);
+ auto shiftValues3 = rewriter.create<arith::ConstantOp>(loc, shiftAttr3);
+ Value shifted3 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues3);
+ Value elem3 = rewriter.create<arith::AndIOp>(loc, shifted3, maskValues);
+
+ // interleave all 4 elements by first interleaving even elements and then odd
+ // elem0 = [0,0,0,0]
+ // elem1 = [1,1,1,1]
+ // elem2 = [2,2,2,2]
+ // elem3 = [3,3,3,3]
+ // 02 = [0,2,0,2]
+ // 13 = [1,3,1,3]
+ // 0213 = [0,1,2,3]
+ Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+ Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+ return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops that take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
@@ -1438,11 +1549,21 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
// Perform the rewrite.
Value subByteExt;
if (isSigned) {
- subByteExt =
- rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2)
+ subByteExt =
+ rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ else {
+ subByteExt =
+ rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ }
} else {
- subByteExt =
- rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2) {
+ subByteExt =
+ rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ } else {
+ subByteExt =
+ rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ }
}
// Finalize the rewrite.
@@ -1489,6 +1610,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
truncOp)))
return failure();
+ // not supported currently.
+ if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
+ return failure();
+
// Create a new iX -> i8 truncation op.
Location loc = truncOp.getLoc();
auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 210025e30d7db5..0b469066f290c2 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -220,8 +220,8 @@ func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
return %0 : vector<8xi32>
}
-// CHECK-LABEL: func.func @aligned_extsi_2d(
-func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extsi_i4_2d(
+func.func @aligned_extsi_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
@@ -234,6 +234,72 @@ func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
return %0 : vector<8x32xi32>
}
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
+func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+ %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
+func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_2d(
+func.func @aligned_extsi_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
@@ -292,6 +358,13 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
return %0 : vector<3x8x32xi4>
}
+func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> {
+ // CHECK-NOT: arith.bitcast
+ // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2>
+ %0 = arith.trunci %a : vector<8xi8> to vector<8xi2>
+ return %0 : vector<8xi2>
+}
+
// CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
@@ -319,8 +392,8 @@ func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
return %0 : vector<8xi32>
}
-// CHECK-LABEL: func.func @aligned_extui_2d(
-func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extui_i4_2d(
+func.func @aligned_extui_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
@@ -333,6 +406,74 @@ func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
return %0 : vector<8x32xi32>
}
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
+func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+ %0 = arith.extui %a : vector<8xi2> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
+func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extui %a : vector<8xi2> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_2d(
+func.func @aligned_extui_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extui %a : vector<8x32xi2> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
// CHECK-LABEL: func.func @aligned_sitofp(
func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
>From 29bad6e225de28a4d050c243abfce0f8de5a715d Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Sun, 29 Dec 2024 16:47:18 +0100
Subject: [PATCH 2/5] fix typos in comments
---
.../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 519e6e68bc1b0a..11cce3540ac436 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1242,7 +1242,7 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
assert(srcVecType.getElementType().isSignlessInteger(2) &&
"Expected i2 type");
- // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
constexpr int64_t i2Toi8BitwidthFactor = 4;
i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
@@ -1295,7 +1295,7 @@ static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
assert(srcVecType.getElementType().isSignlessInteger(2) &&
"Expected i2 type");
- // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
constexpr int64_t i2Toi8BitwidthFactor = 4;
i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
>From ec04f6028f735fb339e741701218d147018931ac Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Sun, 29 Dec 2024 16:48:25 +0100
Subject: [PATCH 3/5] fix element numbering in comments ("Position" -> "Element")
---
.../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 11cce3540ac436..323da627de7bc2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1249,28 +1249,28 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
- // Position 0 (bits 0-1)
+ // Element 0 (bits 0-1)
constexpr int8_t shiftConst6 = 6;
auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
- // Position 1 (bits 2-3)
+ // Element 1 (bits 2-3)
constexpr int8_t shiftConst4 = 4;
auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
- // Position 1 (bits 4-5)
+ // Element 2 (bits 4-5)
constexpr int8_t shiftConst2 = 2;
auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
- // Position 3 (bits 6-7)
+ // Element 3 (bits 6-7)
Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
// interleave all 4 elements by first interleaving even elements and then odd
>From ba2e565123d4a2469a50fb70a916abbc5c96db12 Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Mon, 6 Jan 2025 12:44:34 +0100
Subject: [PATCH 4/5] refactoring: extract repeated code into helper functions,
 reorder tests, improve naming
---
.../Transforms/VectorEmulateNarrowType.cpp | 236 +++++++++---------
.../Vector/vector-rewrite-narrow-types.mlir | 128 +++++-----
2 files changed, 178 insertions(+), 186 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 323da627de7bc2..a3b99f9098dad0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1172,70 +1172,62 @@ Value BitCastRewriter::genericRewriteStep(
return runningResult;
}
-/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
- Value srcValue) {
+/// Bitcasts a vector of sub-byte integers (e.g. i2 or i4) to a vector of i8 by
+/// packing 8 / bitwidth source elements into each byte. Only the innermost
+/// dimension changes, e.g. vector<8x32xi2> -> vector<8x8xi8>.
+static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
- assert(srcVecType.getElementType().isSignlessInteger(4) &&
- "Expected i4 type");
-
- // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+ int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
+ // The source bitwidth must evenly divide 8. Otherwise the factor below is
+ // truncated (e.g. i3 -> 2) or zero (e.g. i12), which would produce a wrong
+ // shape or a division by zero.
+ assert(srcBitwidth > 0 && 8 % srcBitwidth == 0 &&
+ "Invalid source bitwidth (must divide 8)");
+ int64_t bitwidthFactor = 8 / srcBitwidth;
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
- constexpr int64_t i4Toi8BitwidthFactor = 2;
- i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+ // Shrink only the innermost dimension so the total bit count is preserved.
+ i8VecShape.back() = i8VecShape.back() / bitwidthFactor;
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
- Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+ return rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+}
- // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
- // byte are place in one vector and the high i4 elements in another vector.
- constexpr int8_t bitsToShift = 4;
- auto shiftValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(i8VecType, bitsToShift));
- Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
- Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
- Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
+/// starting at the specified bit index, and sign-extends it to i8. The field is
+/// first shifted left so that its most significant bit lands in bit 7 (skipped
+/// when it is already there), then arithmetically shifted right by
+/// 8 - numBits.
+static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
+ Location loc, Value src,
+ int bitIdx, int numBits) {
+ assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
+ "Invalid bitIdx range");
+ auto srcType = cast<VectorType>(src.getType());
+ Value shl = src;
+ int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
+ if (bitsToShiftLeft != 0) {
+ Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
+ shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
+ }
- // 3. Interleave low and high i8 elements.
- return rewriter.create<vector::InterleaveOp>(loc, low, high);
+ int8_t bitsToShiftRight = 8 - numBits;
+ Value shiftRightValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
+ Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
+ return shr;
}
-/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
- Value srcValue) {
+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
assert(srcVecType.getElementType().isSignlessInteger(4) &&
"Expected i4 type");
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
- SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
- constexpr int64_t i4Toi8BitwidthFactor = 2;
- i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
- auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
- Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+ Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
- // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
- // byte are placed in one vector and the high i4 elements in another vector.
- constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
- auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
- Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
- lowBitsMaskValues);
- constexpr int8_t highBitsToShift = 4;
- auto highShiftValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
- Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+ // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
+ // byte are placed in one vector and the high i4 elements in another vector.
+ Value low = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 4);
+ Value high = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 4);
// 3. Interleave low and high i8 elements.
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
@@ -1243,41 +1235,20 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
"Expected i2 type");
- // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
- SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
- constexpr int64_t i2Toi8BitwidthFactor = 4;
- i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
- auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
- Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+ Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
+ // 2. Extract each i2 element and sign-extend it to i8 using shifts.
// Element 0 (bits 0-1)
- constexpr int8_t shiftConst6 = 6;
- auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
- auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
- Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
- Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
-
+ Value elem0 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 2);
// Element 1 (bits 2-3)
- constexpr int8_t shiftConst4 = 4;
- auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
- auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
- Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
- Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
-
+ Value elem1 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 2, 2);
// Element 2 (bits 4-5)
- constexpr int8_t shiftConst2 = 2;
- auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
- auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
- Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
- Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
-
+ Value elem2 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 2);
// Element 3 (bits 6-7)
- Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
+ Value elem3 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 6, 2);
- // interleave all 4 elements by first interleaving even elements and then odd
- // elem0 = [0,0,0,0]
- // elem1 = [1,1,1,1]
- // elem2 = [2,2,2,2]
- // elem3 = [3,3,3,3]
+ // 3. Interleave all 4 elements by first interleaving the even (0, 2) and the
+ // odd (1, 3) elements, then interleaving those two results:
+ // elem0 = [0,0,0,0]
+ // elem1 = [1,1,1,1]
+ // elem2 = [2,2,2,2]
+ // elem3 = [3,3,3,3]
// 02 = [0,2,0,2]
// 13 = [1,3,1,3]
// 0213 = [0,1,2,3]
@@ -1286,9 +1257,51 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
}
+/// Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
+/// starting at the specified bit index.
+Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter, Location loc,
+ Value src, int bitIdx, int numBits) {
+ assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
+ "Invalid bitIdx range");
+ auto srcType = cast<VectorType>(src.getType());
+ int8_t bitsToShiftRight = bitIdx;
+ Value shr = src;
+ if (bitsToShiftRight != 0) {
+ Value shiftRightValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
+ shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
+ }
+ if (bitIdx + numBits == 8) {
+ return shr;
+ }
+ uint8_t lowBitsMask = (1 << numBits) - 1;
+ Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(srcType, lowBitsMask));
+ return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
+}
+
+/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(4) &&
+ "Expected i4 type");
+
+ // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+ Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
+
+ // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
+ // byte are placed in one vector and the high i4 elements in another vector.
+ Value low = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 4);
+ Value high = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 4);
+
+ // 3. Interleave low and high i8 elements.
+ return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
-/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
@@ -1296,46 +1309,20 @@ static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
"Expected i2 type");
// 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
- SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
- constexpr int64_t i2Toi8BitwidthFactor = 4;
- i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
- auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
- Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+ Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
// 2. Extract each i2 element using shifts and masks
- constexpr uint8_t mask = 3; // Mask for 2 bits: [0000 0011]
- auto maskAttr = DenseElementsAttr::get(i8VecType, mask);
- auto maskValues = rewriter.create<arith::ConstantOp>(loc, maskAttr);
-
// Element 0 (bits 0-1)
- Value elem0 = rewriter.create<arith::AndIOp>(loc, i8Vector, maskValues);
-
+ Value elem0 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 2);
// Element 1 (bits 2-3)
- constexpr int8_t shift1 = 2;
- auto shiftAttr1 = DenseElementsAttr::get(i8VecType, shift1);
- auto shiftValues1 = rewriter.create<arith::ConstantOp>(loc, shiftAttr1);
- Value shifted1 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues1);
- Value elem1 = rewriter.create<arith::AndIOp>(loc, shifted1, maskValues);
-
+ Value elem1 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 2, 2);
// Element 2 (bits 4-5)
- constexpr int8_t shift2 = 4;
- auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shift2);
- auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
- Value shifted2 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues2);
- Value elem2 = rewriter.create<arith::AndIOp>(loc, shifted2, maskValues);
-
+ Value elem2 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 2);
// Element 3 (bits 6-7)
- constexpr int8_t shift3 = 6;
- auto shiftAttr3 = DenseElementsAttr::get(i8VecType, shift3);
- auto shiftValues3 = rewriter.create<arith::ConstantOp>(loc, shiftAttr3);
- Value shifted3 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues3);
- Value elem3 = rewriter.create<arith::AndIOp>(loc, shifted3, maskValues);
-
- // interleave all 4 elements by first interleaving even elements and then odd
- // elem0 = [0,0,0,0]
- // elem1 = [1,1,1,1]
- // elem2 = [2,2,2,2]
- // elem3 = [3,3,3,3]
+ Value elem3 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 6, 2);
+
+ // 3. Interleave all 4 elements by first interleaving even elements and then
+ // odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
// 02 = [0,2,0,2]
// 13 = [1,3,1,3]
// 0213 = [0,1,2,3]
@@ -1345,8 +1332,7 @@ static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
}
/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
-/// ops that take advantage of high-level information to avoid leaving LLVM to
-/// scramble with peephole optimizations.
+/// ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
@@ -1549,20 +1535,30 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
// Perform the rewrite.
Value subByteExt;
if (isSigned) {
- if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2)
+ switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
+ case 2:
subByteExt =
rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
- else {
+ break;
+ case 4:
subByteExt =
rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ break;
+ default:
+ return failure();
}
} else {
- if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2) {
+ switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
+ case 2:
subByteExt =
rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
- } else {
+ break;
+ case 4:
subByteExt =
rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ break;
+ default:
+ return failure();
}
}
@@ -1604,16 +1600,16 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
return failure();
+ // TODO: Add support for truncating to i2.
+ if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
+ return failure();
+
// Check general alignment preconditions. We invert the src/dst type order
// to reuse the existing precondition logic.
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
truncOp)))
return failure();
- // not supported currently.
- if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
- return failure();
-
// Create a new iX -> i8 truncation op.
Location loc = truncOp.getLoc();
auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 0b469066f290c2..3d37fe4efa40c4 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -206,34 +206,6 @@ func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
return %0 : vector<8xi8>
}
-// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32(
-func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
- %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
- return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extsi_i4_2d(
-func.func @aligned_extsi_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
- %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
- return %0 : vector<8x32xi32>
-}
-
// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
@@ -255,6 +227,19 @@ func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
return %0 : vector<8xi8>
}
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32(
+func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
@@ -278,8 +263,22 @@ func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
return %0 : vector<8xi32>
}
-// CHECK-LABEL: func.func @aligned_extsi_i2_2d(
-func.func @aligned_extsi_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32_2d(
+func.func @aligned_extsi_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32_2d(
+func.func @aligned_extsi_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
@@ -378,34 +377,6 @@ func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
return %0 : vector<8xi8>
}
-// CHECK-LABEL: func.func @aligned_extui_i4_to_i32(
-func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
-// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
-// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
- %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
- return %0 : vector<8xi32>
-}
-
-// CHECK-LABEL: func.func @aligned_extui_i4_2d(
-func.func @aligned_extui_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
-// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
-// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
-// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
-// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
-// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
-// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
- %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
- return %0 : vector<8x32xi32>
-}
-
// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
@@ -419,8 +390,7 @@ func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
-// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
@@ -428,6 +398,20 @@ func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
return %0 : vector<8xi8>
}
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i32(
+func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extui %a : vector<8xi4> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
@@ -441,8 +425,7 @@ func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
-// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
-// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
@@ -451,8 +434,22 @@ func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
return %0 : vector<8xi32>
}
-// CHECK-LABEL: func.func @aligned_extui_i2_2d(
-func.func @aligned_extui_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extui_i4_to_i32_2d(
+func.func @aligned_extui_i4_to_i32_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32_2d(
+func.func @aligned_extui_i2_to_i32_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
@@ -464,8 +461,7 @@ func.func @aligned_extui_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
-// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
-// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
>From 7c6d76269e1093c804b8dbeea88a8484245846e6 Mon Sep 17 00:00:00 2001
From: Thomas Ziereis <ziereis at roofline.ai>
Date: Tue, 7 Jan 2025 19:24:25 +0100
Subject: [PATCH 5/5] refactoring
---
.../Transforms/VectorEmulateNarrowType.cpp | 189 ++++++++----------
1 file changed, 78 insertions(+), 111 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index a3b99f9098dad0..5cb833283245bc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1084,10 +1084,14 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
- // Only {s}i4/i2 -> (size_of({{s}i/f}) >= 8) are supported for now.
- if ((srcElemBitwidth != 4 && srcElemBitwidth != 2) || dstElemBitwidth < 8 ||
- (dstElemBitwidth % srcElemBitwidth) != 0)
- return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+ if (dstElemBitwidth < 8)
+ return rewriter.notifyMatchFailure(
+ op, "the bitwidth of dstType must be greater than or equal to 8");
+ if (dstElemBitwidth % srcElemBitwidth != 0)
+ return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
+ if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+ return rewriter.notifyMatchFailure(
+ op, "only src bitwidth of 2 or 4 is supported at this moment");
if ((srcType.getShape().back() % 2) != 0)
return rewriter.notifyMatchFailure(
@@ -1172,22 +1176,35 @@ Value BitCastRewriter::genericRewriteStep(
return runningResult;
}
-Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
- Value srcValue) {
- VectorType srcVecType = cast<VectorType>(srcValue.getType());
+/// Takes an aligned sub-byte vector and bitcasts it to a vector of i8. The
+/// innermost dimension shrinks by a factor of (8 / source element bitwidth).
+///
+/// Example:
+/// vector<16x16xi2> -> vector<16x4xi8>
+/// vector<16x16xi4> -> vector<16x8xi8>
+static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ auto srcVecType = cast<VectorType>(srcValue.getType());
int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
- assert(srcBitwidth % 8 != 0 && "Invalid source bitwidth");
+ assert(8 % srcBitwidth == 0 && "Invalid source bitwidth");
int64_t bitwidthFactor = 8 / srcBitwidth;
- SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
- i8VecShape.back() = i8VecShape.back() / bitwidthFactor;
- auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+ SmallVector<int64_t> vecShape(srcVecType.getShape());
+  // Shrink the innermost dimension so that the total bit size stays the same.
+ vecShape.back() = vecShape.back() / bitwidthFactor;
+ auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
return rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
}
/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
/// starting at the specified bit index.
-Value extractNBitsFromVectorSigned(PatternRewriter &rewriter, Location loc,
- Value src, int bitIdx, int numBits) {
+///
+/// Example:
+/// extract numBits=2 starting at bitIdx=2
+/// src = [0101|11|10]
+/// shl = src << 4 -> [11100000]
+/// result = shl >> 6 -> [11111111]
+static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
+ Location loc, Value src, int bitIdx,
+ int numBits) {
assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
"Invalid bitIdx range");
auto srcType = cast<VectorType>(src.getType());
@@ -1206,61 +1223,18 @@ Value extractNBitsFromVectorSigned(PatternRewriter &rewriter, Location loc,
return shr;
}
-/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
- Value srcValue) {
- VectorType srcVecType = cast<VectorType>(srcValue.getType());
- assert(srcVecType.getElementType().isSignlessInteger(4) &&
- "Expected i4 type");
-
- // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
- Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
-
- // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
- // byte are place in one vector and the high i4 elements in another vector.
- Value low = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 4);
- Value high = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 4);
-
- // 3. Interleave low and high i8 elements.
- return rewriter.create<vector::InterleaveOp>(loc, low, high);
-}
-
-/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
-/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
-static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
- Value srcValue) {
- VectorType srcVecType = cast<VectorType>(srcValue.getType());
- assert(srcVecType.getElementType().isSignlessInteger(2) &&
- "Expected i2 type");
-
- // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
- Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
-
- // 2. Extract each i2 element using shifts
- // Element 0 (bits 0-1)
- Value elem0 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 2);
- // Element 1 (bits 2-3)
- Value elem1 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 2, 2);
- // Element 2 (bits 4-5)
- Value elem2 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 2);
- // Element 3 (bits 6-7)
- Value elem3 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 6, 2);
-
- // 3. Interleave all 4 elements by first interleaving even elements and then
- // odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
- // 02 = [0,2,0,2]
- // 13 = [1,3,1,3]
- // 0213 = [0,1,2,3]
- Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
- Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
- return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
-}
-
/// Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
/// starting at the specified bit index.
-Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter, Location loc,
- Value src, int bitIdx, int numBits) {
+///
+/// Example:
+/// extract numBits=2 starting at bitIdx=2
+/// src = [0101|10|10]
+/// mask = [00000011]
+/// shr = src >> 2 = [00010110]
+/// result = shr & mask = [00000010]
+static Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter,
+ Location loc, Value src, int bitIdx,
+ int numBits) {
assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
"Invalid bitIdx range");
auto srcType = cast<VectorType>(src.getType());
@@ -1280,30 +1254,33 @@ Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter, Location loc,
return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
}
-/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
+using ExtractNBitsFn =
+ std::function<Value(PatternRewriter &, Location, Value, int, int)>;
+
+/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
-static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
- Value srcValue) {
- VectorType srcVecType = cast<VectorType>(srcValue.getType());
+static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
+ Value srcValue, const ExtractNBitsFn &extFn) {
+ auto srcVecType = cast<VectorType>(srcValue.getType());
assert(srcVecType.getElementType().isSignlessInteger(4) &&
"Expected i4 type");
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
- // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
- // byte are placed in one vector and the high i4 elements in another vector.
- Value low = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 4);
- Value high = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 4);
+  // 2. Extend i4 elements to i8 elements. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+ Value low = extFn(rewriter, loc, i8Vector, 0, 4);
+ Value high = extFn(rewriter, loc, i8Vector, 4, 4);
// 3. Interleave low and high i8 elements.
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
-/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
+/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
-static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
- Value srcValue) {
+static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
+ Value srcValue, const ExtractNBitsFn &extFn) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
assert(srcVecType.getElementType().isSignlessInteger(2) &&
"Expected i2 type");
@@ -1311,18 +1288,22 @@ static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
// 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
- // 2. Extract each i2 element using shifts and masks
+ // 2. Extract each i2 element
// Element 0 (bits 0-1)
- Value elem0 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 2);
+ Value elem0 = extFn(rewriter, loc, i8Vector, 0, 2);
// Element 1 (bits 2-3)
- Value elem1 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 2, 2);
+ Value elem1 = extFn(rewriter, loc, i8Vector, 2, 2);
// Element 2 (bits 4-5)
- Value elem2 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 2);
+ Value elem2 = extFn(rewriter, loc, i8Vector, 4, 2);
// Element 3 (bits 6-7)
- Value elem3 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 6, 2);
-
- // 3. Interleave all 4 elements by first interleaving even elements and then
- // odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
+ Value elem3 = extFn(rewriter, loc, i8Vector, 6, 2);
+
+  // 3. Interleave all 4 elements: first interleave the even (0, 2) and the
+  // odd (1, 3) element vectors separately, then interleave the two results.
+ // elem0 = [0,0,0,0]
+ // elem1 = [1,1,1,1]
+ // elem2 = [2,2,2,2]
+ // elem3 = [3,3,3,3]
// 02 = [0,2,0,2]
// 13 = [1,3,1,3]
// 0213 = [0,1,2,3]
@@ -1533,33 +1514,19 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
return failure();
// Perform the rewrite.
+ Location loc = conversionOp.getLoc();
+ const auto &extFn = isSigned ? extractNBitsFromVectorSigned
+ : extractNBitsFromVectorUnsinged;
Value subByteExt;
- if (isSigned) {
- switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
- case 2:
- subByteExt =
- rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
- break;
- case 4:
- subByteExt =
- rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
- break;
- default:
- return failure();
- }
- } else {
- switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
- case 2:
- subByteExt =
- rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
- break;
- case 4:
- subByteExt =
- rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
- break;
- default:
- return failure();
- }
+ switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
+ case 2:
+ subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
+ break;
+ case 4:
+ subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
+ break;
+ default:
+ return failure();
}
// Finalize the rewrite.
More information about the Mlir-commits
mailing list