[flang-commits] [llvm] [openmp] [flang] [mlir] [lldb] [mlir][Vector] Add patterns for efficient i4 -> i8 conversion emulation (PR #79494)
Diego Caballero via flang-commits
flang-commits at lists.llvm.org
Thu Jan 25 16:57:38 PST 2024
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/79494
From b8fb65dd1e65c36cfb2104e5f35179faa6011552 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 25 Jan 2024 02:39:14 +0000
Subject: [PATCH] [mlir][Vector] Add patterns for efficient i4 -> i8 conversion
emulation
This PR adds new patterns to improve the generated vector code for the
emulation of any conversion that has to go through an i4 -> i8 type
extension (only signed extensions are supported for now). This will
impact any i4 -> i8/i16/i32/i64 signed extensions as well as sitofp i4
-> f8/f16/f32/f64.
The asm code generated for the supported cases is significantly better
after this PR for both x86 and aarch64.
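For reference, this is the shape of the rewrite these patterns produce,
reproduced from the examples documented in the patch below:
  arith.extsi %in : vector<8xi4> to vector<8xi32>
is rewritten as
  %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
  %1 = arith.shli %0, 4 : vector<4xi8>
  %2 = arith.shrsi %1, 4 : vector<4xi8>
  %3 = arith.shrsi %0, 4 : vector<4xi8>
  %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
       : vector<4xi8>, vector<4xi8>
  %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>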
---
.../Transforms/VectorEmulateNarrowType.cpp | 176 ++++++++++++++++--
.../Vector/vector-rewrite-narrow-types.mlir | 33 ++++
2 files changed, 189 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index a4a72754ccc250..8abd34fd246224 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -642,9 +642,9 @@ struct BitCastRewriter {
BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
- /// Verify that the preconditions for the rewrite are met.
- LogicalResult precondition(PatternRewriter &rewriter,
- VectorType preconditionVectorType, Operation *op);
+ /// Verify that general preconditions for the rewrite are met.
+ LogicalResult commonPrecondition(PatternRewriter &rewriter,
+ VectorType preconditionType, Operation *op);
/// Precompute the metadata for the rewrite.
SmallVector<BitCastRewriter::Metadata>
@@ -652,9 +652,9 @@ struct BitCastRewriter {
/// Rewrite one step of the sequence:
/// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
- Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
- Value runningResult,
- const BitCastRewriter::Metadata &metadata);
+ Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
+ Value initialValue, Value runningResult,
+ const BitCastRewriter::Metadata &metadata);
private:
/// Underlying enumerator that encodes the provenance of the bits in the each
@@ -719,21 +719,57 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
LDBG("\n" << enumerator.sourceElementRanges);
}
-LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
- VectorType precondition,
- Operation *op) {
- if (precondition.getRank() != 1 || precondition.isScalable())
+/// Verify that the precondition type meets the common preconditions for any
+/// conversion.
+static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
+ VectorType preconditionType,
+ Operation *op) {
+ if (!preconditionType || preconditionType.getRank() != 1 ||
+ preconditionType.isScalable())
return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
- int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+ unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
if (resultBitwidth % 8 != 0)
return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
return success();
}
+LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
+ VectorType preconditionType,
+ Operation *op) {
+ if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
+ return rewriter.notifyMatchFailure(op, "types are not vector");
+
+ return commonConversionPrecondition(rewriter, preconditionType, op);
+}
+
+/// Verify that source and destination element types meet the precondition for
+/// the supported aligned conversion cases. Alignment means that either the
+/// source element type is a multiple of the destination element type or the
+/// other way around.
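+/// For example, i4 and i8 are aligned (8 is a multiple of 4), whereas i3 and
+/// i8 are not.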
+///
+/// NOTE: This function assumes that common conversion preconditions are met.
+static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
+ VectorType srcType,
+ VectorType dstType,
+ Operation *op) {
+ if (!srcType || !dstType)
+ return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+ unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+ unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+ unsigned byteBitwidth = 8;
+
+  // Only i4 -> (i/f with bitwidth >= 8) conversions are supported for now.
+ if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+ (dstElemBitwidth % srcElemBitwidth) != 0)
+ return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+
+ return success();
+}
+
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
SmallVector<BitCastRewriter::Metadata> result;
@@ -775,9 +811,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
return result;
}
-Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
- Value initialValue, Value runningResult,
- const BitCastRewriter::Metadata &metadata) {
+Value BitCastRewriter::genericRewriteStep(
+ PatternRewriter &rewriter, Location loc, Value initialValue,
+ Value runningResult, const BitCastRewriter::Metadata &metadata) {
// Create vector.shuffle from the metadata.
auto shuffleOp = rewriter.create<vector::ShuffleOp>(
loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +846,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
return runningResult;
}
+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(4) &&
+ "Expected i4 type");
+
+ // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+ int64_t vecDimSize = srcVecType.getShape().back();
+ SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+ constexpr int64_t i4Toi8BitwidthFactor = 2;
+ i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+ auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+ Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of
+  // each byte are placed in one vector and the high i4 elements in another
+  // vector.
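+  // For example, for the byte 0xA7 (low i4 = 0x7 = 7, high i4 = 0xA = -6):
+  // (0xA7 << 4) = 0x70 and 0x70 >>s 4 = 0x07 (7); 0xA7 >>s 4 = 0xFA (-6).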
+ constexpr int8_t bitsToShift = 4;
+ auto shiftValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(i8VecType, bitsToShift));
+ Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
+ Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
+ Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+
+ // 3. Interleave low and high i8 elements using a shuffle.
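+  // E.g., for vecDimSize = 8, the interleave mask is [0, 4, 1, 5, 2, 6, 3, 7].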
+ SmallVector<int64_t> interleaveMaskValues;
+ interleaveMaskValues.reserve(vecDimSize);
+ for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
+ interleaveMaskValues.push_back(i);
+ interleaveMaskValues.push_back(i + (vecDimSize / 2));
+ }
+
+ return rewriter.create<vector::ShuffleOp>(
+ loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+}
+
namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -829,7 +903,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
VectorType sourceVectorType = bitCastOp.getSourceVectorType();
VectorType targetVectorType = bitCastOp.getResultVectorType();
BitCastRewriter bcr(sourceVectorType, targetVectorType);
- if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+ if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
return failure();
// Perform the rewrite.
@@ -839,8 +913,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
Value runningResult;
for (const BitCastRewriter ::Metadata &metadata :
bcr.precomputeMetadata(shuffledElementType)) {
- runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
- runningResult, metadata);
+ runningResult = bcr.genericRewriteStep(
+ rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
}
// Finalize the rewrite.
@@ -893,7 +967,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
VectorType sourceVectorType = bitCastOp.getSourceVectorType();
VectorType targetVectorType = bitCastOp.getResultVectorType();
BitCastRewriter bcr(sourceVectorType, targetVectorType);
- if (failed(bcr.precondition(
+ if (failed(bcr.commonPrecondition(
rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
return failure();
@@ -904,8 +978,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
for (const BitCastRewriter::Metadata &metadata :
bcr.precomputeMetadata(shuffledElementType)) {
- runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
- sourceValue, runningResult, metadata);
+ runningResult = bcr.genericRewriteStep(
+ rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
}
// Finalize the rewrite.
@@ -923,6 +997,62 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
return success();
}
};
+
+/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+/// arith.extsi %in : vector<8xi4> to vector<8xi32>
+/// is rewritten as
+/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+/// %1 = arith.shli %0, 4 : vector<4xi8>
+/// %2 = arith.shrsi %1, 4 : vector<4xi8>
+/// %3 = arith.shrsi %0, 4 : vector<4xi8>
+/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+/// : vector<4xi8>, vector<4xi8>
+/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+///
+/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
+/// is rewritten as
+/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+/// %1 = arith.shli %0, 4 : vector<4xi8>
+/// %2 = arith.shrsi %1, 4 : vector<4xi8>
+/// %3 = arith.shrsi %0, 4 : vector<4xi8>
+/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+/// : vector<4xi8>, vector<4xi8>
+/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+ using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+ PatternRewriter &rewriter) const override {
+    // Verify the conversion preconditions.
+ Value srcValue = conversionOp.getIn();
+ auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+ auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+ if (failed(
+ commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+ return failure();
+
+ // Check general alignment preconditions.
+ if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+ conversionOp)))
+ return failure();
+
+ // Perform the rewrite.
+ Value subByteExt =
+ rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+ // Finalize the rewrite.
+ rewriter.replaceOpWithNewOp<ConversionOpType>(
+ conversionOp, conversionOp.getType(), subByteExt);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -944,4 +1074,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
benefit);
+
+  // Patterns for aligned cases. We set a higher benefit for these patterns as
+  // they are expected to generate better code.
+ patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+ RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+ patterns.getContext(), benefit.getBenefit() + 1);
}
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index a600fa955b1700..c4fbb4c219b917 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -193,6 +193,39 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
return %1 : vector<8xi17>
}
+// CHECK-LABEL: func.func @aligned_extsi(
+func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
+ // CHECK: arith.shli
+ // CHECK: arith.shrsi
+ // CHECK: arith.shrsi
+ // CHECK: vector.shuffle
+ // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+ %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_base_case(
+func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+ // CHECK: arith.shli
+ // CHECK: arith.shrsi
+ // CHECK: arith.shrsi
+ // CHECK: vector.shuffle
+ // CHECK-NOT: arith.extsi
+ %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_sitofp(
+func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
+ // CHECK: arith.shli
+ // CHECK: arith.shrsi
+ // CHECK: arith.shrsi
+  // CHECK: vector.shuffle
+ // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+ %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op