[Mlir-commits] [mlir] [mlir][Vector] Add support for trunci to narrow type emulation (PR #82565)
Diego Caballero
llvmlistbot at llvm.org
Wed Feb 21 23:09:37 PST 2024
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/82565
>From de8a47ae91437a53b483ba72b04c9660df06fe37 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Wed, 21 Feb 2024 23:03:21 +0000
Subject: [PATCH 1/8] [mlir][Vector] Replace `vector.shuffle` with
`vector.interleave` in vector narrow type emulation
This PR replaces the generation of `vector.shuffle` with
`vector.interleave` in the i4 conversions in vector narrow type
emulation. The multi dimensional semantics of `vector.interleave` allow
us to enable these conversion emulations also for multi dimensional
vectors.
---
.../Transforms/VectorEmulateNarrowType.cpp | 17 ++--
.../Vector/vector-rewrite-narrow-types.mlir | 82 +++++++++++++------
2 files changed, 67 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 36fb66708407b4..9ebe36cd3861e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -724,9 +724,8 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
VectorType preconditionType,
Operation *op) {
- if (!preconditionType || preconditionType.getRank() != 1 ||
- preconditionType.isScalable())
- return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+ if (!preconditionType || preconditionType.isScalable())
+ return rewriter.notifyMatchFailure(op, "scalable vector");
// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
@@ -743,6 +742,9 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
return rewriter.notifyMatchFailure(op, "types are not vector");
+ if (!preconditionType || preconditionType.getRank() != 1)
+ return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
+
return commonConversionPrecondition(rewriter, preconditionType, op);
}
@@ -879,8 +881,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
interleaveMaskValues.push_back(i + (vecDimSize / 2));
}
- return rewriter.create<vector::ShuffleOp>(
- loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+ return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
namespace {
@@ -1008,8 +1009,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-/// : vector<4xi8>, vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1018,8 +1018,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-/// : vector<4xi8>, vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
template <typename ConversionOpType>
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 02063a81664b81..94e78ce40a3c19 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -195,53 +195,89 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
// CHECK-LABEL: func.func @aligned_extsi(
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: vector.shuffle
- // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
%0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
return %0 : vector<8xi32>
}
+// CHECK-LABEL: func.func @aligned_extsi_2d(
+func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
// CHECK-LABEL: func.func @aligned_extsi_base_case(
func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: vector.shuffle
- // CHECK-NOT: arith.extsi
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
%0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
return %0 : vector<8xi8>
}
// CHECK-LABEL: func.func @aligned_sitofp(
func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: shuffle
- // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
%0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func.func @aligned_sitofp_2d(
+func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+ %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
+ return %0 : vector<8x32xf32>
+}
+
// CHECK-LABEL: func.func @i4_transpose(
-// CHECK-SAME: %[[A:[0-9a-z]*]]
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
- // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
- // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
- // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
+// CHECK: %[[EXT:.*]] = vector.interleave
+// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
%0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
return %0 : vector<16x8xi4>
}
// CHECK-LABEL: func.func @i7_transpose(
-// CHECK-SAME: %[[A:[0-9a-z]*]]
func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
- // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
- // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
- // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+// CHECK-SAME: %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
+// CHECK: %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
+// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
%0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
return %0 : vector<16x8xi7>
}
>From 87f772ff7352a47018849c39ac05175309a190e5 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 00:10:27 +0000
Subject: [PATCH 2/8] Feedback
---
.../Vector/Transforms/VectorEmulateNarrowType.cpp | 9 +--------
1 file changed, 1 insertion(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 9ebe36cd3861e0..41a778b8496ef0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -873,14 +873,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
- // 3. Interleave low and high i8 elements using a shuffle.
- SmallVector<int64_t> interleaveMaskValues;
- interleaveMaskValues.reserve(vecDimSize);
- for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
- interleaveMaskValues.push_back(i);
- interleaveMaskValues.push_back(i + (vecDimSize / 2));
- }
-
+ // 3. Interleave low and high i8 elements.
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
>From 4a2fa749f3cf746baeffa81b3a15554b6bd88448 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 00:55:54 +0000
Subject: [PATCH 3/8] Remove unused var
---
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 41a778b8496ef0..fc11ae63e718a5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -857,7 +857,6 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
"Expected i4 type");
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
- int64_t vecDimSize = srcVecType.getShape().back();
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
constexpr int64_t i4Toi8BitwidthFactor = 2;
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
>From f2021d70a7f5da5d29f5513c5c6b6df6090261f8 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 01:13:06 +0000
Subject: [PATCH 4/8] [mlir][Vector] Add support for trunci to narrow type
emulation
WIP
---
.../Transforms/VectorEmulateNarrowType.cpp | 125 +++++++++++++++++-
.../Vector/vector-rewrite-narrow-types.mlir | 15 +++
2 files changed, 134 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index fc11ae63e718a5..394041bd2b2b20 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -729,8 +729,8 @@ static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
- unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
- if (resultBitwidth % 8 != 0)
+ unsigned bitwidth = preconditionType.getElementTypeBitWidth();
+ if (bitwidth % 8 != 0)
return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
return success();
@@ -876,6 +876,54 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
+/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(8) &&
+ "Expected i8 type");
+
+ // 1. Zero out the upper side of each i8 element.
+ constexpr int8_t i8BitMask = 0x0F;
+ Value zeroOutMask = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(srcVecType, i8BitMask));
+ Value zeroOutSrc = rewriter.create<arith::AndIOp>(loc, srcValue, zeroOutMask);
+
+ // 2. De-interleave low and high i8 elements.
+ int64_t vecDimSize = srcVecType.getShape().back();
+ SmallVector<int64_t> deinterleaveLowMaskValues;
+ SmallVector<int64_t> deinterleaveHighMaskValues;
+ deinterleaveLowMaskValues.reserve(vecDimSize / 2);
+ deinterleaveHighMaskValues.reserve(vecDimSize / 2);
+ for (int i = 0, end = vecDimSize; i < end; i += 2) {
+ deinterleaveLowMaskValues.push_back(i);
+ deinterleaveHighMaskValues.push_back(i + 1);
+ }
+
+ auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
+ loc, zeroOutSrc, zeroOutSrc,
+ rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
+ auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
+ loc, zeroOutSrc, zeroOutSrc,
+ rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
+
+ // 3. Move high i4 values to upper side of the byte.
+ constexpr int8_t bitsToShift = 4;
+ VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
+ auto shiftValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
+ auto shlHighOp = rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+
+ // 4. Merge high and low i4 values.
+ auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, shlHighOp, lowShuffleOp);
+
+ // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
+ auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
+ return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
+}
+
namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -1019,7 +1067,7 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
LogicalResult matchAndRewrite(ConversionOpType conversionOp,
PatternRewriter &rewriter) const override {
- // Set up the BitCastRewriter and verify the preconditions.
+ // Verify the preconditions.
Value srcValue = conversionOp.getIn();
auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
@@ -1027,8 +1075,9 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
return failure();
- // Check general alignment preconditions.
- if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+ // Check general alignment preconditions. We invert the src/dst type order
+ // to resue the extension preconditions.
+ if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
conversionOp)))
return failure();
@@ -1043,6 +1092,69 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
}
};
+/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+/// arith.extsi %in : vector<8xi4> to vector<8xi32>
+/// is rewriten as
+/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+/// %1 = arith.shli %0, 4 : vector<4xi8>
+/// %2 = arith.shrsi %1, 4 : vector<4xi8>
+/// %3 = arith.shrsi %0, 4 : vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
+/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+///
+/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
+/// is rewriten as
+/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+/// %1 = arith.shli %0, 4 : vector<4xi8>
+/// %2 = arith.shrsi %1, 4 : vector<4xi8>
+/// %3 = arith.shrsi %0, 4 : vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
+/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+///
+struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
+ using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
+
+ /// Splits a 1-D vector `arith.trunci` into an iX -> i8 truncation followed
+ /// by the i8 -> i4 shuffle/bitwise sequence built by `rewriteI8ToI4Trunc`.
+ /// Fails to match when any precondition below does not hold.
+ LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
+ PatternRewriter &rewriter) const override {
+ // Verify the preconditions.
+ Value srcValue = truncOp.getIn();
+ auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+ auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
+ // Scalar truncations are not handled: `dyn_cast` yields a null type for
+ // them, and the rank query below would dereference it.
+ if (!srcVecType || !dstVecType)
+ return failure();
+
+ // Only single dim vectors are supported until we have
+ // `vector.deinterleave`.
+ if (srcVecType.getRank() != 1)
+ return failure();
+
+ if (failed(
+ commonConversionPrecondition(rewriter, srcVecType, truncOp)))
+ return failure();
+
+ // Check general alignment preconditions. We invert the src/dst type order
+ // to reuse the existing precondition logic.
+ if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
+ truncOp)))
+ return failure();
+
+ // Create a new iX -> i8 truncation op.
+ Location loc = truncOp.getLoc();
+ auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+ Value i8TruncVal = rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
+
+ // Rewrite the i8 -> i4 truncation part.
+ Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
+
+ // Finalize the rewrite.
+ rewriter.replaceOp(truncOp, subByteTrunc);
+ return success();
+ }
+};
+
+
/// Rewrite a sub-byte vector transpose into a sequence of instructions that
/// perform the transpose on wider (byte) element types.
/// For example:
@@ -1115,7 +1227,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
// Patterns for aligned cases. We set higher priority as they are expected to
// generate better performance for aligned cases.
patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
- RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+ RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+ RewriteAlignedSubByteIntTrunc>(
patterns.getContext(), benefit.getBenefit() + 1);
}
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 94e78ce40a3c19..497ca7e876d208 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -262,6 +262,21 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
return %0 : vector<8x32xf32>
}
+// CHECK-LABEL: func.func @aligned_trunci_base_case(
+func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
+// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi8>) -> vector<8xi4> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<15> : vector<8xi8>
+// CHECK: %[[VAL_3:.*]] = arith.andi %[[VAL_0]], %[[VAL_2]] : vector<8xi8>
+// CHECK: %[[VAL_4:.*]] = vector.shuffle %[[VAL_3]], %[[VAL_3]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[VAL_5:.*]] = vector.shuffle %[[VAL_3]], %[[VAL_3]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[VAL_6:.*]] = arith.shli %[[VAL_5]], %[[VAL_1]] : vector<4xi8>
+// CHECK: %[[VAL_7:.*]] = arith.ori %[[VAL_6]], %[[VAL_4]] : vector<4xi8>
+// CHECK: %[[VAL_8:.*]] = vector.bitcast %[[VAL_7]] : vector<4xi8> to vector<8xi4>
+ %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
+ return %0 : vector<8xi4>
+}
+
// CHECK-LABEL: func.func @i4_transpose(
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
>From 8bb89f797c55f46b60d9905323c6e63f32518595 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 06:32:17 +0000
Subject: [PATCH 5/8] Improve sequence
---
.../Transforms/VectorEmulateNarrowType.cpp | 38 ++++++++++---------
1 file changed, 20 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 394041bd2b2b20..1582a067db99fc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -885,13 +885,7 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
assert(srcVecType.getElementType().isSignlessInteger(8) &&
"Expected i8 type");
- // 1. Zero out the upper side of each i8 element.
- constexpr int8_t i8BitMask = 0x0F;
- Value zeroOutMask = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(srcVecType, i8BitMask));
- Value zeroOutSrc = rewriter.create<arith::AndIOp>(loc, srcValue, zeroOutMask);
-
- // 2. De-interleave low and high i8 elements.
+ // 1. De-interleave low and high i8 elements.
int64_t vecDimSize = srcVecType.getShape().back();
SmallVector<int64_t> deinterleaveLowMaskValues;
SmallVector<int64_t> deinterleaveHighMaskValues;
@@ -903,21 +897,30 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
}
auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
- loc, zeroOutSrc, zeroOutSrc,
+ loc, srcValue, srcValue,
rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
- loc, zeroOutSrc, zeroOutSrc,
+ loc, srcValue, srcValue,
rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
- // 3. Move high i4 values to upper side of the byte.
+ // 2. Move high i4 values to upper side of the byte.
constexpr int8_t bitsToShift = 4;
VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
auto shiftValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
- auto shlHighOp = rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+ Value shlHigh =
+ rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+
+ // 3. Zero out the upper side of each low i8 element.
+ constexpr int8_t i8LowBitMask = 0x0F;
+ Value zeroOutMask = rewriter.create<arith::ConstantOp>(
+ loc,
+ DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
+ Value zeroOutLow =
+ rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
// 4. Merge high and low i4 values.
- auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, shlHighOp, lowShuffleOp);
+ auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, shlHigh, zeroOutLow);
// 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
@@ -1130,8 +1133,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
if (srcVecType.getRank() != 1)
return failure();
- if (failed(
- commonConversionPrecondition(rewriter, srcVecType, truncOp)))
+ if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
return failure();
// Check general alignment preconditions. We invert the src/dst type order
@@ -1143,7 +1145,8 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
// Create a new iX -> i8 truncation op.
Location loc = truncOp.getLoc();
auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
- Value i8TruncVal = rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
+ Value i8TruncVal =
+ rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
// Rewrite the i8 -> i4 truncation part.
Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
@@ -1154,7 +1157,6 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
}
};
-
/// Rewrite a sub-byte vector transpose into a sequence of instructions that
/// perform the transpose on wider (byte) element types.
/// For example:
@@ -1228,8 +1230,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
// generate better performance for aligned cases.
patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
- RewriteAlignedSubByteIntTrunc>(
- patterns.getContext(), benefit.getBenefit() + 1);
+ RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
+ benefit.getBenefit() + 1);
}
void vector::populateVectorTransposeNarrowTypeRewritePatterns(
>From ab7880028448a62f31631d8ce529506dc5e9bcc3 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 06:48:06 +0000
Subject: [PATCH 6/8] Add more tests
---
.../Transforms/VectorEmulateNarrowType.cpp | 43 ++++++++-----------
.../Vector/vector-rewrite-narrow-types.mlir | 34 +++++++++++----
2 files changed, 44 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 1582a067db99fc..190189bb83f2e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -903,15 +903,7 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
loc, srcValue, srcValue,
rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
- // 2. Move high i4 values to upper side of the byte.
- constexpr int8_t bitsToShift = 4;
- VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
- auto shiftValues = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
- Value shlHigh =
- rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
-
- // 3. Zero out the upper side of each low i8 element.
+ // 2. Zero out the upper side of each low i8 element.
constexpr int8_t i8LowBitMask = 0x0F;
Value zeroOutMask = rewriter.create<arith::ConstantOp>(
loc,
@@ -919,6 +911,14 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
Value zeroOutLow =
rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
+ // 3. Move high i4 values to upper side of the byte.
+ constexpr int8_t bitsToShift = 4;
+ VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
+ auto shiftValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
+ Value shlHigh =
+ rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+
// 4. Merge high and low i4 values.
auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, shlHigh, zeroOutLow);
@@ -1100,23 +1100,18 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
/// LLVM to scramble with peephole optimizations.
///
/// For example:
-/// arith.extsi %in : vector<8xi4> to vector<8xi32>
+/// arith.trunci %in : vector<8xi32> to vector<8xi4>
/// is rewriten as
-/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-/// %1 = arith.shli %0, 4 : vector<4xi8>
-/// %2 = arith.shrsi %1, 4 : vector<4xi8>
-/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.interleave %2, %3 : vector<4xi8>
-/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
-/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
-/// is rewriten as
-/// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-/// %1 = arith.shli %0, 4 : vector<4xi8>
-/// %2 = arith.shrsi %1, 4 : vector<4xi8>
-/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.interleave %2, %3 : vector<4xi8>
-/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+/// %cst = arith.constant dense<15> : vector<4xi8>
+/// %cst_0 = arith.constant dense<4> : vector<4xi8>
+/// %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
+/// %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+/// %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+/// %3 = arith.andi %1, %cst : vector<4xi8>
+/// %4 = arith.shli %2, %cst_0 : vector<4xi8>
+/// %5 = arith.ori %3, %4 : vector<4xi8>
+/// %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
///
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 497ca7e876d208..c553596c577df6 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -262,17 +262,33 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
return %0 : vector<8x32xf32>
}
+// CHECK-LABEL: func.func @aligned_trunci(
+func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
+// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
+// CHECK: %[[LOW:.*]] = vector.shuffle %[[I8]], %[[I8]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[HIGH:.*]] = vector.shuffle %[[I8]], %[[I8]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+ %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
+ return %0 : vector<8xi4>
+}
+
// CHECK-LABEL: func.func @aligned_trunci_base_case(
func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
-// CHECK-SAME: %[[VAL_0:.*]]: vector<8xi8>) -> vector<8xi4> {
-// CHECK: %[[VAL_1:.*]] = arith.constant dense<4> : vector<4xi8>
-// CHECK: %[[VAL_2:.*]] = arith.constant dense<15> : vector<8xi8>
-// CHECK: %[[VAL_3:.*]] = arith.andi %[[VAL_0]], %[[VAL_2]] : vector<8xi8>
-// CHECK: %[[VAL_4:.*]] = vector.shuffle %[[VAL_3]], %[[VAL_3]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
-// CHECK: %[[VAL_5:.*]] = vector.shuffle %[[VAL_3]], %[[VAL_3]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
-// CHECK: %[[VAL_6:.*]] = arith.shli %[[VAL_5]], %[[VAL_1]] : vector<4xi8>
-// CHECK: %[[VAL_7:.*]] = arith.ori %[[VAL_6]], %[[VAL_4]] : vector<4xi8>
-// CHECK: %[[VAL_8:.*]] = vector.bitcast %[[VAL_7]] : vector<4xi8> to vector<8xi4>
+// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
+// CHECK: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOW:.*]] = vector.shuffle %[[IN]], %[[IN]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[HIGH:.*]] = vector.shuffle %[[IN]], %[[IN]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
%0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
return %0 : vector<8xi4>
}
>From e74cf3c3e3458373f688d0b12102fb1faffa8162 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 06:59:56 +0000
Subject: [PATCH 7/8] 2d test
---
.../Vector/Transforms/VectorEmulateNarrowType.cpp | 5 ++---
.../Vector/vector-rewrite-narrow-types.mlir | 15 +++++++++++++--
2 files changed, 15 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 190189bb83f2e0..c1ee2fbeb7f034 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1078,9 +1078,8 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
return failure();
- // Check general alignment preconditions. We invert the src/dst type order
- // to resue the extension preconditions.
- if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
+ // Check general alignment preconditions.
+ if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
conversionOp)))
return failure();
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index c553596c577df6..8f0148119806c9 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -281,8 +281,8 @@ func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
// CHECK-LABEL: func.func @aligned_trunci_base_case(
func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
-// CHECK: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
-// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
// CHECK: %[[LOW:.*]] = vector.shuffle %[[IN]], %[[IN]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
// CHECK: %[[HIGH:.*]] = vector.shuffle %[[IN]], %[[IN]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
@@ -293,6 +293,17 @@ func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
return %0 : vector<8xi4>
}
+// CHECK-LABEL: func.func @aligned_trunci_2d(
+func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> { // >1-D trunci is expected to be left unrewritten.
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: arith.andi
+// CHECK-NOT: arith.shli
+// CHECK-NOT: arith.ori
+// CHECK: arith.trunci %{{.*}} : vector<8x32xi32> to vector<8x32xi4>
+ %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
+ return %0 : vector<8x32xi4>
+}
+
// CHECK-LABEL: func.func @i4_transpose(
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
>From 6bf7417b7f09f7588924a18b862bb406565b9ea6 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 22 Feb 2024 07:09:14 +0000
Subject: [PATCH 8/8] Fix ori order
---
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index c1ee2fbeb7f034..82c08cc5a54936 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -920,7 +920,7 @@ static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
// 4. Merge high and low i4 values.
- auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, shlHigh, zeroOutLow);
+ auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
// 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
More information about the Mlir-commits
mailing list