[Mlir-commits] [mlir] [WIP] using splat shifts (PR #87121)
Finn Plummer
llvmlistbot at llvm.org
Fri Apr 12 11:36:59 PDT 2024
https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/87121
>From e29a6dafa80938e78a5e7d37bfa33d50a829a266 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Thu, 28 Mar 2024 09:11:04 -0700
Subject: [PATCH 1/2] [WIP] using splat shifts
---
.../Vector/TransformOps/VectorTransformOps.td | 2 +
.../Vector/Transforms/VectorRewritePatterns.h | 3 +-
.../TransformOps/VectorTransformOps.cpp | 3 +-
.../Transforms/VectorEmulateNarrowType.cpp | 184 +++++++++++++++++-
.../Vector/vector-rewrite-narrow-types.mlir | 76 +++++++-
.../Vector/CPU/test-rewrite-narrow-types.mlir | 70 ++++++-
6 files changed, 325 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f6371f39c39444..81d66e7b6ab179 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -318,6 +318,8 @@ def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
Warning: these patterns currently only work for little endian targets.
}];
+ let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$max_cycle_len);
+
let assemblyFormat = "attr-dict";
}
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 453fa73429dd1a..9cd3d80c441d95 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -376,7 +376,8 @@ FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
/// ops over wider types.
/// Warning: these patterns currently only work for little endian targets.
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 1,
+ unsigned shiftDepth = 0);
/// Appends patterns for emulating a sub-byte vector transpose.
void populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 885644864c0f71..e8c9033da72686 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -166,7 +166,8 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- populateVectorNarrowTypeRewritePatterns(patterns);
+ populateVectorNarrowTypeRewritePatterns(patterns, /*default=*/1,
+ getMaxCycleLen());
populateVectorTransposeNarrowTypeRewritePatterns(patterns);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index dc6f126aae4c87..ecd3879e5f1f9e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -546,6 +546,14 @@ struct SourceElementRangeList : public SmallVector<SourceElementRange> {
/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
/// [0] = {0, [0, 10)}, {1, [0, 5)}
/// [1] = {1, [5, 10)}, {2, [0, 10)}
+/// and `vector.bitcast ... : vector<4xi4> to vector<2xi8>` is decomposed as:
+/// [0] = {0, [0, 4)}, {1, [0, 4)}
+/// [1] = {2, [0, 4)}, {3, [0, 4)}
+/// and `vector.bitcast ... : vector<2xi8> to vector<4xi4>` is decomposed as:
+/// [0] = {0, [0, 4)}
+/// [1] = {0, [4, 8)}
+/// [2] = {1, [0, 4)}
+/// [3] = {1, [4, 8)}
struct BitCastBitsEnumerator {
BitCastBitsEnumerator(VectorType sourceVectorType,
VectorType targetVectorType);
@@ -633,6 +641,35 @@ struct BitCastBitsEnumerator {
/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
+///
+///
+/// When we consider the above algorithm to rewrite our vector.bitcast, we rely
+/// on using dynamic shift amounts for the left and right shifts. This can be
+/// inefficient on certain targets (RDNA GPUs) in contrast to a splat constant
+/// value. So when possible we can rewrite this as a combination of shifts with
+/// a constant splat value and then regroup the selected terms.
+///
+/// Eg. Instead of:
+/// res = arith.shrui x [0, 4, 8, 0, 4, 8]
+/// use:
+/// y = arith.shrui x [0, 0, 0, 0, 0, 0] (can be folded away)
+/// y1 = arith.shrui x [4, 4, 4, 4, 4, 4]
+/// y2 = arith.shrui x [8, 8, 8, 8, 8, 8]
+/// y3 = vector.shuffle y y1 [0, 7, 3, 10]
+/// res = vector.shuffle y3 y2 [0, 1, 6, 2, 3, 9]
+///
+/// This is possible when the precomputed shift amounts follow a cyclic
+/// pattern of [x, y, z, ..., x, y, z, ...] such that the cycle length,
+/// cycleLen, satisfies 1 < cycleLen < size(shiftAmounts). And the shuffles are
+/// of the form [0, 0, 0, ..., 1, 1, 1, ...]. A common pattern in
+/// (de)quantization, i24 -> 3xi8 or 3xi8 -> i24. The modified algorithm follows
+/// the same 2 steps as above, then it proceeds as follows:
+///
+/// 2. for each element in the cycle, x, of the rightShiftAmounts create a
+/// shrui with a splat constant of x.
+/// 3. repeat 2. with the respective leftShiftAmounts
+/// 4. construct a chain of vector.shuffles that will reconstruct the result
+/// from the chained shifts
struct BitCastRewriter {
/// Helper metadata struct to hold the static quantities for the rewrite.
struct Metadata {
@@ -656,10 +693,25 @@ struct BitCastRewriter {
Value initialValue, Value runningResult,
const BitCastRewriter::Metadata &metadata);
+ /// Rewrite one step of the sequence when able to use a splat constant for the
+ /// shiftright and shiftleft.
+ Value splatRewriteStep(PatternRewriter &rewriter, Location loc,
+ Value initialValue, Value runningResult,
+ const BitCastRewriter::Metadata &metadata);
+
+ bool useSplatStep(unsigned maxCycleLen) {
+ return 1 < cycleLen && cycleLen <= maxCycleLen;
+ }
+
private:
/// Underlying enumerator that encodes the provenance of the bits in the each
/// element of the result vector.
BitCastBitsEnumerator enumerator;
+
+ // Underlying cycleLen computed during precomputeMetadata. A cycleLen > 1
+ // denotes that there is a cycle in the precomputed shift amounts and we are
+ // able to use the splatRewriteStep.
+ int64_t cycleLen = 0;
};
} // namespace
@@ -775,8 +827,40 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
return success();
}
+// Check if the vector is a cycle of the first cycleLen elements.
+template <class T>
+static bool isCyclic(SmallVector<T> xs, int64_t cycleLen) {
+ for (int64_t idx = cycleLen, n = xs.size(); idx < n; idx++) {
+ if (xs[idx] != xs[idx % cycleLen])
+ return false;
+ }
+ return true;
+}
+
+static SmallVector<int64_t> constructShuffles(int64_t inputSize,
+ int64_t numCycles,
+ int64_t cycleLen, int64_t idx) {
+ // If idx == 1, then the first operand of the shuffle will be the mask which
+ // will have the original size. So we need to step through the mask with a
+ // stride of cycleSize.
+ // When idx > 1, then the first operand will be the size of (idx * cycleSize)
+ // and so we take the first idx elements of the input and then append the
+ // strided mask value.
+ int64_t inputStride = idx == 1 ? cycleLen : idx;
+
+ SmallVector<int64_t> shuffles;
+ for (int64_t cycle = 0; cycle < numCycles; cycle++) {
+ for (int64_t inputIdx = 0; inputIdx < idx; inputIdx++) {
+ shuffles.push_back(cycle * inputStride + inputIdx);
+ }
+ shuffles.push_back(inputSize + cycle * cycleLen + idx);
+ }
+ return shuffles;
+}
+
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
+ bool cyclicShifts = true;
SmallVector<BitCastRewriter::Metadata> result;
for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
shuffleIdx < e; ++shuffleIdx) {
@@ -811,8 +895,71 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
IntegerAttr::get(shuffledElementType, shiftLeft));
}
+ // Compute a potential cycle size by detecting the number of sourceElements
+ // at the start of shuffle that are the same
+ cycleLen = 1;
+ for (int64_t n = shuffles.size(); cycleLen < n; cycleLen++)
+ if (shuffles[cycleLen] != shuffles[0])
+ break;
+
+ cyclicShifts = cyclicShifts && (cycleLen < (int64_t)shuffles.size()) &&
+ isCyclic(shiftRightAmounts, cycleLen) &&
+ isCyclic(shiftLeftAmounts, cycleLen);
+
result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
}
+
+ cycleLen = cyclicShifts ? cycleLen : 0;
+ return result;
+}
+
+Value BitCastRewriter::splatRewriteStep(
+ PatternRewriter &rewriter, Location loc, Value initialValue,
+ Value runningResult, const BitCastRewriter::Metadata &metadata) {
+
+ // Initial result will be the Shifted Mask which will have the shuffles size.
+ int64_t inputSize = metadata.shuffles.size();
+ int64_t numCycles = inputSize / cycleLen;
+
+ auto shuffleOp = rewriter.create<vector::ShuffleOp>(
+ loc, initialValue, initialValue, metadata.shuffles);
+
+ // Intersect with the mask.
+ VectorType shuffledVectorType = shuffleOp.getResultVectorType();
+ auto constOp = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
+ Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
+
+ Value result;
+ for (int64_t idx = 0; idx < cycleLen; idx++) {
+ auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
+ loc, SplatElementsAttr::get(shuffledVectorType,
+ metadata.shiftRightAmounts[idx]));
+ Value shiftedRight =
+ rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
+
+ auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
+ loc, SplatElementsAttr::get(shuffledVectorType,
+ metadata.shiftLeftAmounts[idx]));
+ Value shiftedLeft =
+ rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
+
+ if (result) {
+ SmallVector<int64_t> shuffles =
+ constructShuffles(inputSize, numCycles, cycleLen, idx);
+ result = rewriter.create<vector::ShuffleOp>(loc, result, shiftedLeft,
+ shuffles);
+
+ // After the first shuffle in the chain, the size of the input result will
+ // grow as we append more shuffles together to reconstruct the
+ // shuffledVectorType size. Each iteration they will retain numCycles more
+ // elements.
+ inputSize = numCycles * (idx + 1);
+ } else {
+ result = shiftedLeft;
+ }
+ }
+
return result;
}
@@ -939,6 +1086,11 @@ namespace {
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
using OpRewritePattern::OpRewritePattern;
+ RewriteBitCastOfTruncI(MLIRContext *context, PatternBenefit benefit,
+ unsigned maxCycleLen)
+ : OpRewritePattern<vector::BitCastOp>(context, benefit),
+ maxCycleLen{maxCycleLen} {}
+
LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
PatternRewriter &rewriter) const override {
// The source must be a trunc op.
@@ -961,8 +1113,12 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
Value runningResult;
for (const BitCastRewriter ::Metadata &metadata :
bcr.precomputeMetadata(shuffledElementType)) {
- runningResult = bcr.genericRewriteStep(
- rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
+ runningResult =
+ bcr.useSplatStep(maxCycleLen)
+ ? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
+ runningResult, metadata)
+ : bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
+ truncValue, runningResult, metadata);
}
// Finalize the rewrite.
@@ -986,6 +1142,9 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
return success();
}
+
+private:
+ unsigned maxCycleLen;
};
} // namespace
@@ -1001,8 +1160,10 @@ template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
using OpRewritePattern<ExtOpType>::OpRewritePattern;
- RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
- : OpRewritePattern<ExtOpType>(context, benefit) {}
+ RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit,
+ unsigned maxCycleLen)
+ : OpRewritePattern<ExtOpType>(context, benefit),
+ maxCycleLen{maxCycleLen} {}
LogicalResult matchAndRewrite(ExtOpType extOp,
PatternRewriter &rewriter) const override {
@@ -1026,8 +1187,12 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
for (const BitCastRewriter::Metadata &metadata :
bcr.precomputeMetadata(shuffledElementType)) {
- runningResult = bcr.genericRewriteStep(
- rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
+ runningResult =
+ bcr.useSplatStep(maxCycleLen)
+ ? bcr.splatRewriteStep(rewriter, bitCastOp->getLoc(), sourceValue,
+ runningResult, metadata)
+ : bcr.genericRewriteStep(rewriter, bitCastOp->getLoc(),
+ sourceValue, runningResult, metadata);
}
// Finalize the rewrite.
@@ -1044,6 +1209,9 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
return success();
}
+
+private:
+ unsigned maxCycleLen;
};
/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
@@ -1222,10 +1390,10 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
}
void vector::populateVectorNarrowTypeRewritePatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
+ RewritePatternSet &patterns, PatternBenefit benefit, unsigned maxCycleLen) {
patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
- benefit);
+ benefit, maxCycleLen);
// Patterns for aligned cases. We set higher priority as they are expected to
// generate better performance for aligned cases.
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8f0148119806c9..8dd361b8a3573e 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -146,6 +146,42 @@ func.func @f4(%a: vector<16xi16>) -> vector<8xi6> {
return %1 : vector<8xi6>
}
+// CHECK-LABEL: func.func @ftrunc_splat1(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<2xi16>) -> vector<1xi8> {
+func.func @ftrunc_splat1(%a: vector<2xi16>) -> vector<1xi8> {
+ // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<15> : vector<1xi16>
+ // CHECK-DAG: %[[SHL_CST:.*]] = arith.constant dense<4> : vector<1xi16>
+ // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0] : vector<2xi16>, vector<2xi16>
+ // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<1xi16>
+ // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1] : vector<2xi16>, vector<2xi16>
+ // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK]] : vector<1xi16>
+ // CHECK: %[[SHL0:.*]] = arith.shli %[[A1]], %[[SHL_CST]] : vector<1xi16>
+ // CHECK: %[[O1:.*]] = arith.ori %[[A0]], %[[SHL0]] : vector<1xi16>
+ // CHECK: %[[RES:.*]] = arith.trunci %[[O1]] : vector<1xi16> to vector<1xi8>
+ // return %[[RES]] : vector<1xi8>
+ %0 = arith.trunci %a : vector<2xi16> to vector<2xi4>
+ %1 = vector.bitcast %0 : vector<2xi4> to vector<1xi8>
+ return %1 : vector<1xi8>
+}
+
+// CHECK-LABEL: func.func @ftrunc_splat2(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<4xi16>) -> vector<2xi8> {
+func.func @ftrunc_splat2(%a: vector<4xi16>) -> vector<2xi8> {
+ // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<15> : vector<2xi16>
+ // CHECK-DAG: %[[SHL_CST:.*]] = arith.constant dense<4> : vector<2xi16>
+ // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 2] : vector<4xi16>, vector<4xi16>
+ // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<2xi16>
+ // CHECK: %[[V1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 3] : vector<4xi16>, vector<4xi16>
+ // CHECK: %[[A1:.*]] = arith.andi %[[V1]], %[[MASK]] : vector<2xi16>
+ // CHECK: %[[SHL0:.*]] = arith.shli %[[A1]], %[[SHL_CST]] : vector<2xi16>
+ // CHECK: %[[O1:.*]] = arith.ori %[[A0]], %[[SHL0]] : vector<2xi16>
+ // CHECK: %[[RES:.*]] = arith.trunci %[[O1]] : vector<2xi16> to vector<2xi8>
+ // return %[[RES]] : vector<2xi8>
+ %0 = arith.trunci %a : vector<4xi16> to vector<4xi4>
+ %1 = vector.bitcast %0 : vector<4xi4> to vector<2xi8>
+ return %1 : vector<2xi8>
+}
+
// CHECK-LABEL: func.func @f1ext(
// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<5xi8>) -> vector<8xi16> {
func.func @f1ext(%a: vector<5xi8>) -> vector<8xi16> {
@@ -193,6 +229,44 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
return %1 : vector<8xi17>
}
+// CHECK-LABEL: func.func @fext_splat1(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<2xi8>) -> vector<4xi16> {
+func.func @fext_splat1(%a: vector<2xi8>) -> vector<4xi16> {
+ // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, -16, 15, -16]> : vector<4xi8>
+ // CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<4> : vector<4xi8>
+ // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1] : vector<2xi8>, vector<2xi8>
+ // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<4xi8>
+ // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST]] : vector<4xi8>
+ // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 5, 2, 7] : vector<4xi8>, vector<4xi8>
+ // CHECK: %[[RES:.*]] = arith.extui %[[V1]] : vector<4xi8> to vector<4xi16>
+ // return %[[RES]] : vector<4xi16>
+ %0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
+ %1 = arith.extui %0 : vector<4xi4> to vector<4xi16>
+ return %1 : vector<4xi16>
+}
+
+// CHECK-LABEL: func.func @fext_splat2(
+// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<3xi16>) -> vector<12xi32> {
+func.func @fext_splat2(%a: vector<3xi16>) -> vector<12xi32> {
+ // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, 240, 3840, -4096, 15, 240, 3840, -4096, 15, 240, 3840, -4096]> : vector<12xi16>
+ // CHECK-DAG: %[[SHR_CST0:.*]] = arith.constant dense<4> : vector<12xi16>
+ // CHECK-DAG: %[[SHR_CST1:.*]] = arith.constant dense<8> : vector<12xi16>
+ // CHECK-DAG: %[[SHR_CST2:.*]] = arith.constant dense<12> : vector<12xi16>
+ // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] : vector<3xi16>, vector<3xi16>
+ // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<12xi16>
+ // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST0]] : vector<12xi16>
+ // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 13, 4, 17, 8, 21] : vector<12xi16>, vector<12xi16>
+ // CHECK: %[[SHR1:.*]] = arith.shrui %[[A0]], %[[SHR_CST1]] : vector<12xi16>
+ // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[SHR1]] [0, 1, 8, 2, 3, 12, 4, 5, 16] : vector<6xi16>, vector<12xi16>
+ // CHECK: %[[SHR2:.*]] = arith.shrui %[[A0]], %[[SHR_CST2]] : vector<12xi16>
+ // CHECK: %[[V3:.*]] = vector.shuffle %[[V2]], %[[SHR2]] [0, 1, 2, 12, 3, 4, 5, 16, 6, 7, 8, 20] : vector<9xi16>, vector<12xi16>
+ // CHECK: %[[RES:.*]] = arith.extui %[[V3]] : vector<12xi16> to vector<12xi32>
+ // CHEKC: return %[[RES]] : vector<12xi32>
+ %0 = vector.bitcast %a : vector<3xi16> to vector<12xi4>
+ %1 = arith.extui %0 : vector<12xi4> to vector<12xi32>
+ return %1 : vector<12xi32>
+}
+
// CHECK-LABEL: func.func @aligned_extsi(
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
@@ -330,7 +404,7 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %f {
- transform.apply_patterns.vector.rewrite_narrow_types
+ transform.apply_patterns.vector.rewrite_narrow_types { max_cycle_len = 4 }
} : !transform.any_op
transform.yield
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
index a0b39a2b68f438..a7e13ea1a79c4f 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir
@@ -124,6 +124,36 @@ func.func @f3(%v: vector<2xi48>) {
return
}
+func.func @print_as_i1_2xi8(%v : vector<2xi8>) {
+ %bitsi16 = vector.bitcast %v : vector<2xi8> to vector<16xi1>
+ vector.print %bitsi16 : vector<16xi1>
+ return
+}
+
+func.func @print_as_i1_4xi4(%v : vector<4xi4>) {
+ %bitsi16 = vector.bitcast %v : vector<4xi4> to vector<16xi1>
+ vector.print %bitsi16 : vector<16xi1>
+ return
+}
+
+func.func @ftrunc_splat(%v: vector<2xi24>) {
+ %trunc = arith.trunci %v : vector<2xi24> to vector<2xi8>
+ func.call @print_as_i1_2xi8(%trunc) : (vector<2xi8>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 0, 1, 1, 1, 1, 1, 1, 1,
+ // CHECK-SAME: 1, 1, 0, 0, 0, 0, 1, 1 )
+
+ %bitcast = vector.bitcast %trunc : vector<2xi8> to vector<4xi4>
+ func.call @print_as_i1_4xi4(%bitcast) : (vector<4xi4>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 0, 1, 1, 1,
+ // CHECK-SAME: 1, 1, 1, 1,
+ // CHECK-SAME: 1, 1, 0, 0,
+ // CHECK-SAME: 0, 0, 1, 1 )
+
+ return
+}
+
func.func @print_as_i1_8xi5(%v : vector<8xi5>) {
%bitsi40 = vector.bitcast %v : vector<8xi5> to vector<40xi1>
vector.print %bitsi40 : vector<40xi1>
@@ -164,6 +194,32 @@ func.func @fext(%a: vector<5xi8>) {
return
}
+func.func @print_as_i1_4xi8(%v : vector<4xi8>) {
+ %bitsi32 = vector.bitcast %v : vector<4xi8> to vector<32xi1>
+ vector.print %bitsi32 : vector<32xi1>
+ return
+}
+
+func.func @fext_splat(%a: vector<2xi8>) {
+ %0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
+ func.call @print_as_i1_4xi4(%0) : (vector<4xi4>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 0, 1, 1, 1,
+ // CHECK-SAME: 1, 1, 1, 1,
+ // CHECK-SAME: 1, 1, 0, 0,
+ // CHECK-SAME: 0, 0, 1, 1 )
+
+ %1 = arith.extui %0 : vector<4xi4> to vector<4xi8>
+ func.call @print_as_i1_4xi8(%1) : (vector<4xi8>) -> ()
+ // CHECK: (
+ // CHECK-SAME: 0, 1, 1, 1, 0, 0, 0, 0,
+ // CHECK-SAME: 1, 1, 1, 1, 0, 0, 0, 0,
+ // CHECK-SAME: 1, 1, 0, 0, 0, 0, 0, 0,
+ // CHECK-SAME: 0, 0, 1, 1, 0, 0, 0, 0 )
+
+ return
+}
+
func.func @fcst_maskedload(%A: memref<?xi4>, %passthru: vector<6xi4>) -> vector<6xi4> {
%c0 = arith.constant 0: index
%mask = vector.constant_mask [3] : vector<6xi1>
@@ -190,9 +246,19 @@ func.func @entry() {
func.call @f3(%v3) : (vector<2xi48>) -> ()
%v4 = arith.constant dense<[
+ 0xafe, 0xbc3
+ ]> : vector<2xi24>
+ func.call @ftrunc_splat(%v4) : (vector<2xi24>) -> ()
+
+ %v5 = arith.constant dense<[
0xef, 0xee, 0xed, 0xec, 0xeb
]> : vector<5xi8>
- func.call @fext(%v4) : (vector<5xi8>) -> ()
+ func.call @fext(%v5) : (vector<5xi8>) -> ()
+
+ %v6 = arith.constant dense<[
+ 0xfe, 0xc3
+ ]> : vector<2xi8>
+ func.call @fext_splat(%v6) : (vector<2xi8>) -> ()
// Set up memory.
%c0 = arith.constant 0: index
@@ -218,7 +284,7 @@ module attributes {transform.with_named_sequence} {
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %f {
- transform.apply_patterns.vector.rewrite_narrow_types
+ transform.apply_patterns.vector.rewrite_narrow_types { max_cycle_len = 4 }
} : !transform.any_op
transform.yield
}
>From 461b96911b198c05afe675441d7e7a94fd39b876 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 12 Apr 2024 11:19:38 -0700
Subject: [PATCH 2/2] Split up the and operation over the cycles
- allows for the shifts to only compute the shift on the required
cycle and not the entire input
- removes the initial vector.shuffle as we can operate directly on the
input and merge them after
---
.../Transforms/VectorEmulateNarrowType.cpp | 61 ++++++-------------
.../Vector/vector-rewrite-narrow-types.mlir | 46 ++++++++------
2 files changed, 44 insertions(+), 63 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ecd3879e5f1f9e..425431b7fa3439 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -837,23 +837,14 @@ static bool isCyclic(SmallVector<T> xs, int64_t cycleLen) {
return true;
}
-static SmallVector<int64_t> constructShuffles(int64_t inputSize,
- int64_t numCycles,
+static SmallVector<int64_t> constructShuffles(int64_t numCycles,
int64_t cycleLen, int64_t idx) {
- // If idx == 1, then the first operand of the shuffle will be the mask which
- // will have the original size. So we need to step through the mask with a
- // stride of cycleSize.
- // When idx > 1, then the first operand will be the size of (idx * cycleSize)
- // and so we take the first idx elements of the input and then append the
- // strided mask value.
- int64_t inputStride = idx == 1 ? cycleLen : idx;
-
SmallVector<int64_t> shuffles;
for (int64_t cycle = 0; cycle < numCycles; cycle++) {
for (int64_t inputIdx = 0; inputIdx < idx; inputIdx++) {
- shuffles.push_back(cycle * inputStride + inputIdx);
+ shuffles.push_back(cycle * idx + inputIdx);
}
- shuffles.push_back(inputSize + cycle * cycleLen + idx);
+ shuffles.push_back(numCycles * idx + cycle);
}
return shuffles;
}
@@ -917,47 +908,31 @@ Value BitCastRewriter::splatRewriteStep(
PatternRewriter &rewriter, Location loc, Value initialValue,
Value runningResult, const BitCastRewriter::Metadata &metadata) {
- // Initial result will be the Shifted Mask which will have the shuffles size.
- int64_t inputSize = metadata.shuffles.size();
- int64_t numCycles = inputSize / cycleLen;
-
- auto shuffleOp = rewriter.create<vector::ShuffleOp>(
- loc, initialValue, initialValue, metadata.shuffles);
-
- // Intersect with the mask.
- VectorType shuffledVectorType = shuffleOp.getResultVectorType();
- auto constOp = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
- Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
-
+ int64_t numCycles = metadata.shuffles.size() / cycleLen;
+ ShapedType vectorType = dyn_cast<ShapedType>(initialValue.getType());
Value result;
for (int64_t idx = 0; idx < cycleLen; idx++) {
+ // Intersect with the mask.
+ auto constOp = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(vectorType, metadata.masks[idx]));
+ Value andValue = rewriter.create<arith::AndIOp>(loc, initialValue, constOp);
+
auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
- loc, SplatElementsAttr::get(shuffledVectorType,
- metadata.shiftRightAmounts[idx]));
+ loc,
+ SplatElementsAttr::get(vectorType, metadata.shiftRightAmounts[idx]));
Value shiftedRight =
rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
- loc, SplatElementsAttr::get(shuffledVectorType,
- metadata.shiftLeftAmounts[idx]));
+ loc,
+ SplatElementsAttr::get(vectorType, metadata.shiftLeftAmounts[idx]));
Value shiftedLeft =
rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
- if (result) {
- SmallVector<int64_t> shuffles =
- constructShuffles(inputSize, numCycles, cycleLen, idx);
- result = rewriter.create<vector::ShuffleOp>(loc, result, shiftedLeft,
- shuffles);
-
- // After the first shuffle in the chain, the size of the input result will
- // grow as we append more shuffles together to reconstruct the
- // shuffledVectorType size. Each iteration they will retain numCycles more
- // elements.
- inputSize = numCycles * (idx + 1);
- } else {
- result = shiftedLeft;
- }
+ SmallVector<int64_t> shuffles = constructShuffles(numCycles, cycleLen, idx);
+ result = result ? rewriter.create<vector::ShuffleOp>(loc, result,
+ shiftedLeft, shuffles)
+ : shiftedLeft;
}
return result;
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 8dd361b8a3573e..396a9e9ee2cb5b 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -230,14 +230,15 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
}
// CHECK-LABEL: func.func @fext_splat1(
-// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<2xi8>) -> vector<4xi16> {
+// CHECK-SAME: %[[ARG:[0-9a-z]*]]: vector<2xi8>) -> vector<4xi16> {
func.func @fext_splat1(%a: vector<2xi8>) -> vector<4xi16> {
- // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, -16, 15, -16]> : vector<4xi8>
- // CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<4> : vector<4xi8>
- // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 1, 1] : vector<2xi8>, vector<2xi8>
- // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<4xi8>
- // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST]] : vector<4xi8>
- // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 5, 2, 7] : vector<4xi8>, vector<4xi8>
+ // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<15> : vector<2xi8>
+ // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<-16> : vector<2xi8>
+ // CHECK-DAG: %[[SHR_CST:.*]] = arith.constant dense<4> : vector<2xi8>
+ // CHECK-DAG: %[[A0:.*]] = arith.andi %[[ARG]], %[[MASK0]] : vector<2xi8>
+ // CHECK-DAG: %[[A1:.*]] = arith.andi %[[ARG]], %[[MASK1]] : vector<2xi8>
+ // CHECK: %[[SHR0:.*]] = arith.shrui %[[A1]], %[[SHR_CST]] : vector<2xi8>
+ // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 2, 1, 3] : vector<2xi8>, vector<2xi8>
// CHECK: %[[RES:.*]] = arith.extui %[[V1]] : vector<4xi8> to vector<4xi16>
// return %[[RES]] : vector<4xi16>
%0 = vector.bitcast %a : vector<2xi8> to vector<4xi4>
@@ -246,20 +247,25 @@ func.func @fext_splat1(%a: vector<2xi8>) -> vector<4xi16> {
}
// CHECK-LABEL: func.func @fext_splat2(
-// CHECK-SAME: %[[A:[0-9a-z]*]]: vector<3xi16>) -> vector<12xi32> {
+// CHECK-SAME: %[[ARG:[0-9a-z]*]]: vector<3xi16>) -> vector<12xi32> {
func.func @fext_splat2(%a: vector<3xi16>) -> vector<12xi32> {
- // CHECK-DAG: %[[MASK:.*]] = arith.constant dense<[15, 240, 3840, -4096, 15, 240, 3840, -4096, 15, 240, 3840, -4096]> : vector<12xi16>
- // CHECK-DAG: %[[SHR_CST0:.*]] = arith.constant dense<4> : vector<12xi16>
- // CHECK-DAG: %[[SHR_CST1:.*]] = arith.constant dense<8> : vector<12xi16>
- // CHECK-DAG: %[[SHR_CST2:.*]] = arith.constant dense<12> : vector<12xi16>
- // CHECK: %[[V0:.*]] = vector.shuffle %[[A]], %[[A]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] : vector<3xi16>, vector<3xi16>
- // CHECK: %[[A0:.*]] = arith.andi %[[V0]], %[[MASK]] : vector<12xi16>
- // CHECK: %[[SHR0:.*]] = arith.shrui %[[A0]], %[[SHR_CST0]] : vector<12xi16>
- // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 13, 4, 17, 8, 21] : vector<12xi16>, vector<12xi16>
- // CHECK: %[[SHR1:.*]] = arith.shrui %[[A0]], %[[SHR_CST1]] : vector<12xi16>
- // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[SHR1]] [0, 1, 8, 2, 3, 12, 4, 5, 16] : vector<6xi16>, vector<12xi16>
- // CHECK: %[[SHR2:.*]] = arith.shrui %[[A0]], %[[SHR_CST2]] : vector<12xi16>
- // CHECK: %[[V3:.*]] = vector.shuffle %[[V2]], %[[SHR2]] [0, 1, 2, 12, 3, 4, 5, 16, 6, 7, 8, 20] : vector<9xi16>, vector<12xi16>
+ // CHECK-DAG: %[[MASK0:.*]] = arith.constant dense<15> : vector<3xi16>
+ // CHECK-DAG: %[[MASK1:.*]] = arith.constant dense<240> : vector<3xi16>
+ // CHECK-DAG: %[[MASK2:.*]] = arith.constant dense<3840> : vector<3xi16>
+ // CHECK-DAG: %[[MASK3:.*]] = arith.constant dense<-4096> : vector<3xi16>
+ // CHECK-DAG: %[[SHR_CST0:.*]] = arith.constant dense<4> : vector<3xi16>
+ // CHECK-DAG: %[[SHR_CST1:.*]] = arith.constant dense<8> : vector<3xi16>
+ // CHECK-DAG: %[[SHR_CST2:.*]] = arith.constant dense<12> : vector<3xi16>
+ // CHECK: %[[A0:.*]] = arith.andi %[[ARG]], %[[MASK0]] : vector<3xi16>
+ // CHECK: %[[A1:.*]] = arith.andi %[[ARG]], %[[MASK1]] : vector<3xi16>
+ // CHECK: %[[SHR0:.*]] = arith.shrui %[[A1]], %[[SHR_CST0]] : vector<3xi16>
+ // CHECK: %[[V1:.*]] = vector.shuffle %[[A0]], %[[SHR0]] [0, 3, 1, 4, 2, 5] : vector<3xi16>, vector<3xi16>
+ // CHECK: %[[A2:.*]] = arith.andi %[[ARG]], %[[MASK2]] : vector<3xi16>
+ // CHECK: %[[SHR1:.*]] = arith.shrui %[[A2]], %[[SHR_CST1]] : vector<3xi16>
+ // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[SHR1]] [0, 1, 6, 2, 3, 7, 4, 5, 8] : vector<6xi16>, vector<3xi16>
+ // CHECK: %[[A3:.*]] = arith.andi %[[ARG]], %[[MASK3]] : vector<3xi16>
+ // CHECK: %[[SHR2:.*]] = arith.shrui %[[A3]], %[[SHR_CST2]] : vector<3xi16>
+ // CHECK: %[[V3:.*]] = vector.shuffle %[[V2]], %[[SHR2]] [0, 1, 2, 9, 3, 4, 5, 10, 6, 7, 8, 11] : vector<9xi16>, vector<3xi16>
// CHECK: %[[RES:.*]] = arith.extui %[[V3]] : vector<12xi16> to vector<12xi32>
// CHEKC: return %[[RES]] : vector<12xi32>
%0 = vector.bitcast %a : vector<3xi16> to vector<12xi4>
More information about the Mlir-commits
mailing list