[Mlir-commits] [mlir] 9d0acb8 - [mlir][Vector] Add support for trunci to narrow type emulation (#82565)
llvmlistbot at llvm.org
Tue Feb 27 15:27:35 PST 2024
Author: Diego Caballero
Date: 2024-02-27T15:27:31-08:00
New Revision: 9d0acb872a5063f570366cd0e94b069d286cc71f
URL: https://github.com/llvm/llvm-project/commit/9d0acb872a5063f570366cd0e94b069d286cc71f
DIFF: https://github.com/llvm/llvm-project/commit/9d0acb872a5063f570366cd0e94b069d286cc71f.diff
LOG: [mlir][Vector] Add support for trunci to narrow type emulation (#82565)
This PR adds support for `arith.trunci` to vector narrow type emulation, covering iX -> i4 truncations for X >= 8. For now, the pattern only works on 1-D vectors and is based on `vector.shuffle` ops. Adding n-D vector support would require `vector.deinterleave`.
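As a worked example of the byte packing the pattern performs (values chosen for illustration, not taken from the patch): truncating the i8 pair [0x0A, 0x05] to i4 deinterleaves it into low = [0x0A] and high = [0x05], then computes

    (0x0A & 0x0F) | (0x05 << 4) = 0x0A | 0x50 = 0x5A

and finally bitcasts the packed byte back to two i4 values, [0xA, 0x5], since `vector.bitcast` places lower-indexed elements in the lower-order bits.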
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index fc11ae63e718a5..dc6f126aae4c87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -729,8 +729,8 @@ static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
- unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
- if (resultBitwidth % 8 != 0)
+ unsigned bitwidth = preconditionType.getElementTypeBitWidth();
+ if (bitwidth % 8 != 0)
return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
return success();
@@ -768,6 +768,10 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
(dstElemBitwidth % srcElemBitwidth) != 0)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+ if ((srcType.getShape().back() % 2) != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Not an even number of i4 elements in trailing dim");
+
return success();
}
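(Context for the new check: the rewrite packs pairs of i4 values into single i8 bytes, so the trailing dimension must be even. For instance, vector<8xi4> maps to vector<4xi8> (8 * 4 = 32 bits), while vector<7xi4> is 28 bits and has no whole-byte vector<Nxi8> counterpart.)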
@@ -876,6 +880,58 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
+/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
+/// that take advantage of high-level information to avoid leaving LLVM to
+/// scramble with peephole optimizations.
+static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(8) &&
+ "Expected i8 type");
+
+ // 1. De-interleave low and high i8 elements.
+ int64_t vecDimSize = srcVecType.getShape().back();
+ SmallVector<int64_t> deinterleaveLowMaskValues;
+ SmallVector<int64_t> deinterleaveHighMaskValues;
+ assert((vecDimSize % 2) == 0 && "Odd number of i4 elements");
+ deinterleaveLowMaskValues.reserve(vecDimSize / 2);
+ deinterleaveHighMaskValues.reserve(vecDimSize / 2);
+ for (int i = 0, end = vecDimSize; i < end; i += 2) {
+ deinterleaveLowMaskValues.push_back(i);
+ deinterleaveHighMaskValues.push_back(i + 1);
+ }
+
+ auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
+ loc, srcValue, srcValue,
+ rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
+ auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
+ loc, srcValue, srcValue,
+ rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
+
+ // 2. Zero out the upper side of each low i8 element.
+ constexpr int8_t i8LowBitMask = 0x0F;
+ Value zeroOutMask = rewriter.create<arith::ConstantOp>(
+ loc,
+ DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
+ Value zeroOutLow =
+ rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
+
+ // 3. Move high i4 values to upper side of the byte.
+ constexpr int8_t bitsToShift = 4;
+ VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
+ auto shiftValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
+ Value shlHigh =
+ rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+
+ // 4. Merge high and low i4 values.
+ auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
+
+ // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
+ auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
+ return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
+}
+
namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -1019,7 +1075,7 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
LogicalResult matchAndRewrite(ConversionOpType conversionOp,
PatternRewriter &rewriter) const override {
- // Set up the BitCastRewriter and verify the preconditions.
+ // Verify the preconditions.
Value srcValue = conversionOp.getIn();
auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
@@ -1043,6 +1099,65 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
}
};
+/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+/// arith.trunci %in : vector<8xi32> to vector<8xi4>
+/// is rewritten as:
+///
+/// %cst = arith.constant dense<15> : vector<4xi8>
+/// %cst_0 = arith.constant dense<4> : vector<4xi8>
+/// %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
+/// %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+/// %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+/// %3 = arith.andi %1, %cst : vector<4xi8>
+/// %4 = arith.shli %2, %cst_0 : vector<4xi8>
+/// %5 = arith.ori %3, %4 : vector<4xi8>
+/// %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
+///
+struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
+ using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
+ PatternRewriter &rewriter) const override {
+ // Verify the preconditions.
+ Value srcValue = truncOp.getIn();
+ auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+ auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
+ if (!srcVecType || !dstVecType)
+ return failure();
+
+ // Only single dim vectors are supported until we have
+ // `vector.deinterleave`.
+ if (srcVecType.getRank() != 1)
+ return failure();
+
+ if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
+ return failure();
+
+ // Check general alignment preconditions. We invert the src/dst type order
+ // to reuse the existing precondition logic.
+ if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
+ truncOp)))
+ return failure();
+
+ // Create a new iX -> i8 truncation op.
+ Location loc = truncOp.getLoc();
+ auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+ Value i8TruncVal =
+ rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
+
+ // Rewrite the i8 -> i4 truncation part.
+ Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
+
+ // Finalize the rewrite.
+ rewriter.replaceOp(truncOp, subByteTrunc);
+ return success();
+ }
+};
+
/// Rewrite a sub-byte vector transpose into a sequence of instructions that
/// perform the transpose on wider (byte) element types.
/// For example:
@@ -1115,8 +1230,9 @@ void vector::populateVectorNarrowTypeRewritePatterns(
// Patterns for aligned cases. We set higher priority as they are expected to
// generate better performance for aligned cases.
patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
- RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
- patterns.getContext(), benefit.getBenefit() + 1);
+ RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+ RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
+ benefit.getBenefit() + 1);
}
void vector::populateVectorTransposeNarrowTypeRewritePatterns(
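For context, a minimal sketch of how a client might drive these rewrites with the greedy driver (the helper name and wiring are assumptions for illustration; they are not part of this commit):

    #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Collect the narrow-type rewrites (including the new trunci pattern,
    // registered with benefit + 1 above) and apply them to a fixpoint.
    static void runNarrowTypeRewrites(mlir::Operation *root) {
      mlir::RewritePatternSet patterns(root->getContext());
      mlir::vector::populateVectorNarrowTypeRewritePatterns(patterns);
      (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
    }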
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 94e78ce40a3c19..8f0148119806c9 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -262,6 +262,48 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
return %0 : vector<8x32xf32>
}
+// CHECK-LABEL: func.func @aligned_trunci(
+func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
+// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
+// CHECK: %[[LOW:.*]] = vector.shuffle %[[I8]], %[[I8]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[HIGH:.*]] = vector.shuffle %[[I8]], %[[I8]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+ %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
+ return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_base_case(
+func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
+// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOW:.*]] = vector.shuffle %[[IN]], %[[IN]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[HIGH:.*]] = vector.shuffle %[[IN]], %[[IN]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+ %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
+ return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_2d(
+func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.andi
+// CHECK-NOT: vector.shli
+// CHECK-NOT: vector.ori
+// CHECK: arith.trunci
+ %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
+ return %0 : vector<8x32xi4>
+}
+
// CHECK-LABEL: func.func @i4_transpose(
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
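One way to exercise the new pattern from IR is via the transform dialect; a hedged sketch, assuming the existing `transform.apply_patterns.vector.rewrite_narrow_types` op (this may differ from how the test file actually drives the rewrites):

    transform.sequence failures(propagate) {
    ^bb1(%module_op: !transform.any_op):
      %f = transform.structured.match ops{["func.func"]} in %module_op
        : (!transform.any_op) -> !transform.any_op
      transform.apply_patterns to %f {
        transform.apply_patterns.vector.rewrite_narrow_types
      } : !transform.any_op
    }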