[Mlir-commits] [mlir] 6dfaecf - [mlir][Vector] Add patterns for efficient unsigned i4 -> i8 conversion emulation (#89131)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 1 09:32:25 PDT 2024
Author: Kojo Acquah
Date: 2024-05-01T12:32:20-04:00
New Revision: 6dfaecf077ade4bf003345501fdcfcebc8409ff7
URL: https://github.com/llvm/llvm-project/commit/6dfaecf077ade4bf003345501fdcfcebc8409ff7
DIFF: https://github.com/llvm/llvm-project/commit/6dfaecf077ade4bf003345501fdcfcebc8409ff7.diff
LOG: [mlir][Vector] Add patterns for efficient unsigned i4 -> i8 conversion emulation (#89131)
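This PR builds on https://github.com/llvm/llvm-project/pull/79494 by adding an efficient path for unsigned `i4 -> i8` type extension on 1-D/2-D vectors. It affects unsigned extensions (`arith.extui`) from i4 to i8/i16/i32/i64. Only `arith.extui` is registered with the new unsigned rewrite in this change; `arith.extsi` and `arith.sitofp` (i4 -> f8/f16/f32/f64) continue to use the signed path introduced by the earlier PR.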
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index d24721f3defa65..a301b919dc5232 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -880,6 +880,38 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
+/// Rewrite the unsigned i4 -> i8 extension of `srcValue` (which must be a vector
+/// of signless i4) into bitcast + andi/shrui + interleave, using high-level
+/// information so LLVM is not left to scramble with peephole optimizations.
+static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(4) &&
+ "Expected i4 type");
+
+ // 1. Bitcast vector<...xXxi4> -> vector<...xX/2xi8>; only the last dim is halved.
+ SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+ constexpr int64_t i4Toi8BitwidthFactor = 2; // Two i4 values are packed per i8.
+ i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor; // NOTE(review): integer division; assumes an even trailing dim - confirm the caller checks this.
+ auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+ Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+ // 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
+ // byte are placed in one vector and the high i4 elements in another vector.
+ constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
+ auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
+ Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
+ lowBitsMaskValues);
+ constexpr int8_t highBitsToShift = 4; // Unsigned shift right by 4 moves the high nibble down and zero-fills.
+ auto highShiftValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
+ Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
+
+ // 3. Interleave low and high i8 elements (low first) and return the result.
+ return rewriter.create<vector::InterleaveOp>(loc, low, high);
+}
+
/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
/// that take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
@@ -1048,9 +1080,10 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
-/// LLVM to scramble with peephole optimizations.
+/// LLVM to scramble with peephole optimizations. Templated to choose between
+/// signed and unsigned conversions.
///
-/// For example:
+/// For example (signed):
/// arith.extsi %in : vector<8xi4> to vector<8xi32>
/// is rewriten as
/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
@@ -1069,16 +1102,25 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %4 = vector.interleave %2, %3 : vector<4xi8>
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
-template <typename ConversionOpType>
-struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+/// Example (unsigned):
+/// arith.extui %in : vector<8xi4> to vector<8xi32>
+/// is rewritten as
+/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+/// %1 = arith.andi %0, 15 : vector<4xi8>
+/// %2 = arith.shrui %0, 4 : vector<4xi8>
+/// %3 = vector.interleave %1, %2 : vector<4xi8>
+/// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+///
+template <typename ConversionOpType, bool isSigned>
+struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
using OpRewritePattern<ConversionOpType>::OpRewritePattern;
LogicalResult matchAndRewrite(ConversionOpType conversionOp,
PatternRewriter &rewriter) const override {
// Verify the preconditions.
Value srcValue = conversionOp.getIn();
- auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
- auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+ auto srcVecType = cast<VectorType>(srcValue.getType());
+ auto dstVecType = cast<VectorType>(conversionOp.getType());
if (failed(
commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
return failure();
@@ -1089,8 +1131,14 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
return failure();
// Perform the rewrite.
- Value subByteExt =
- rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ Value subByteExt;
+ if (isSigned) {
+ subByteExt =
+ rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ } else {
+ subByteExt =
+ rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ }
// Finalize the rewrite.
rewriter.replaceOpWithNewOp<ConversionOpType>(
@@ -1229,10 +1277,12 @@ void vector::populateVectorNarrowTypeRewritePatterns(
// Patterns for aligned cases. We set higher priority as they are expected to
// generate better performance for aligned cases.
- patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
- RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+ patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
+ RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
benefit.getBenefit() + 1);
+ patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
+ patterns.getContext(), benefit.getBenefit() + 1);
}
void vector::populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8f0148119806c9..614b2d4945348b 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -324,6 +324,47 @@ func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
return %0 : vector<16x8xi7>
}
+// CHECK-LABEL: func.func @aligned_extui(
+func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extui %a : vector<8xi4> to vector<8xi32> // 1-D case: expects bitcast + andi/shrui + interleave, then extui from i8 to i32.
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_2d(
+func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[VAL_0]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extui %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extui %a : vector<8x32xi4> to vector<8x32xi32> // 2-D case: only the trailing dim (32 -> 16) is halved by the bitcast.
+ return %0 : vector<8x32xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_base_case(
+func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+ %0 = arith.extui %a : vector<8xi4> to vector<8xi8> // Base case (i4 -> i8): the checks end at the interleave; no trailing extui is matched.
+ return %0 : vector<8xi8>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
@@ -335,4 +376,3 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-
More information about the Mlir-commits
mailing list