[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision bitcast… (PR #66387)
llvmlistbot at llvm.org
Mon Sep 18 05:38:46 PDT 2023
================
@@ -155,6 +164,280 @@ struct ConvertVectorTransferRead final
};
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// RewriteBitCastOfTruncI
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Helper struct to keep track of the provenance of a contiguous set of bits
+/// in a source vector.
+struct SourceElementRange {
+ /// The index of the source vector element that contributes bits to *this.
+ int64_t sourceElementIdx;
+ /// The range of bits in the source vector element that contribute to *this.
+ int64_t sourceBitBegin;
+ int64_t sourceBitEnd;
+};
+
+struct SourceElementRangeList : public SmallVector<SourceElementRange> {
+ /// Given the index of a SourceElementRange in the SourceElementRangeList,
+ /// compute the amount of bits that need to be shifted to the left to get the
+ /// bits in their final location. This shift amount is simply the sum of the
+ /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
+  /// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
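+  /// For example, for the row `{0, [10-15)}, {1, [0-5)}` of the
+  /// `vector<2xi15> to vector<3xi10>` decomposition documented below, the
+  /// entry at `shuffleIdx = 1` is shifted left by 5, i.e. by the 5 bits
+  /// contributed by the entry at `shuffleIdx = 0`.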
+ int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
+ int64_t res = 0;
+ for (int64_t i = 0; i < shuffleIdx; ++i)
+ res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
+ return res;
+ }
+};
+
+/// Helper struct to enumerate the source elements and bit ranges that are
+/// involved in a bitcast operation.
+/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
+/// any 1-D vector shape and any source/target bitwidths.
+/// This creates and holds a mapping of the form:
+/// [dstVectorElementJ] ==
+/// [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
+/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
+/// [0] = {0, [0-8)}
+/// [1] = {0, [8-16)}
+/// [2] = {0, [16-24)}
+/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
+///   [0] = {0, [0-10)}
+///   [1] = {0, [10-15)}, {1, [0-5)}
+///   [2] = {1, [5-15)}
+struct BitCastBitsEnumerator {
+ BitCastBitsEnumerator(VectorType sourceVectorType,
+ VectorType targetVectorType);
+
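+  /// Return the maximum number of source entries that contribute bits to any
+  /// single target element; this is also the number of vector.shuffle ops the
+  /// rewrite needs to emit (i.e. the max number of columns in the
+  /// enumeration).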
+ int64_t getMaxNumberOfEntries() {
+ int64_t numVectors = 0;
+ for (const auto &l : sourceElementRanges)
+ numVectors = std::max(numVectors, (int64_t)l.size());
+ return numVectors;
+ }
+
+ VectorType sourceVectorType;
+ VectorType targetVectorType;
+ SmallVector<SourceElementRangeList> sourceElementRanges;
+};
+
+} // namespace
+
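+/// Debug printer: one row per target vector element, each contribution
+/// rendered as `{ sourceElementIdx: b@[sourceBitBegin..sourceBitEnd) lshl: n }`
+/// where `lshl` is the left-shift amount computed by computeLeftShiftAmount.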
+static raw_ostream &operator<<(raw_ostream &os,
+ const SmallVector<SourceElementRangeList> &vec) {
+ for (const auto &l : vec) {
+ for (auto it : llvm::enumerate(l)) {
+ os << "{ " << it.value().sourceElementIdx << ": b@["
+ << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
+ << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
+ }
+ os << "\n";
+ }
+ return os;
+}
+
+BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
+ VectorType targetVectorType)
+ : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
+
+  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
+         "requires 1-D non-scalable vector type");
+  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
+         "requires 1-D non-scalable vector type");
+ int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+ int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+ LDBG("sourceVectorType: " << sourceVectorType);
+
+ int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
+ int64_t mostMinorTargetDim = targetVectorType.getShape().back();
+ LDBG("targetVectorType: " << targetVectorType);
+
+ int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
+ assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
+ "source and target bitwidths must match");
+
+ // Prepopulate one source element range per target element.
+ sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
+ for (int64_t resultBit = 0; resultBit < bitwidth;) {
+ int64_t resultElement = resultBit / targetBitWidth;
+ int64_t resultBitInElement = resultBit % targetBitWidth;
+ int64_t sourceElementIdx = resultBit / sourceBitWidth;
+ int64_t sourceBitInElement = resultBit % sourceBitWidth;
+ int64_t step = std::min(sourceBitWidth - sourceBitInElement,
+ targetBitWidth - resultBitInElement);
+ sourceElementRanges[resultElement].push_back(
+ {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
+ resultBit += step;
+ }
+}
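+
+// A worked trace of the loop above for the `vector<2xi15> to vector<3xi10>`
+// example documented on BitCastBitsEnumerator (30 bits total):
+//   resultBit = 0:  result elt 0 @ bit 0, source elt 0 @ bit 0,  step = min(15, 10) = 10
+//   resultBit = 10: result elt 1 @ bit 0, source elt 0 @ bit 10, step = min(5, 10)  = 5
+//   resultBit = 15: result elt 1 @ bit 5, source elt 1 @ bit 0,  step = min(15, 5)  = 5
+//   resultBit = 20: result elt 2 @ bit 0, source elt 1 @ bit 5,  step = min(10, 10) = 10
+// which reproduces the decomposition [0] = {0, [0-10)},
+// [1] = {0, [10-15)}, {1, [0-5)}, [2] = {1, [5-15)}.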
+
+namespace {
+/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information to avoid leaving LLVM to scramble with
+/// peephole optimizations.
+
+// BitCastBitsEnumerator encodes for each element of the target vector the
+// provenance of the bits in the source vector. We can "transpose" this
+// information to build a sequence of shuffles and bitwise ops that will
+// produce the desired result.
+//
+// Let's take the following motivating example to explain the algorithm:
+// ```
+// %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
+// %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
+// ```
+//
+// BitCastBitsEnumerator contains the following information:
+// ```
+//   { 0: b@[0..5) lshl: 0 } { 1: b@[0..3) lshl: 5 }
+//   { 1: b@[3..5) lshl: 0 } { 2: b@[0..5) lshl: 2 } { 3: b@[0..1) lshl: 7 }
+//   { 3: b@[1..5) lshl: 0 } { 4: b@[0..4) lshl: 4 }
+//   { 4: b@[4..5) lshl: 0 } { 5: b@[0..5) lshl: 1 } { 6: b@[0..2) lshl: 6 }
+//   { 6: b@[2..5) lshl: 0 } { 7: b@[0..5) lshl: 3 }
+//   { 8: b@[0..5) lshl: 0 } { 9: b@[0..3) lshl: 5 }
+//   { 9: b@[3..5) lshl: 0 } { 10: b@[0..5) lshl: 2 } { 11: b@[0..1) lshl: 7 }
+//   { 11: b@[1..5) lshl: 0 } { 12: b@[0..4) lshl: 4 }
+//   { 12: b@[4..5) lshl: 0 } { 13: b@[0..5) lshl: 1 } { 14: b@[0..2) lshl: 6 }
+//   { 14: b@[2..5) lshl: 0 } { 15: b@[0..5) lshl: 3 }
+//   { 16: b@[0..5) lshl: 0 } { 17: b@[0..3) lshl: 5 }
+//   { 17: b@[3..5) lshl: 0 } { 18: b@[0..5) lshl: 2 } { 19: b@[0..1) lshl: 7 }
+//   { 19: b@[1..5) lshl: 0 } { 20: b@[0..4) lshl: 4 }
+//   { 20: b@[4..5) lshl: 0 } { 21: b@[0..5) lshl: 1 } { 22: b@[0..2) lshl: 6 }
+//   { 22: b@[2..5) lshl: 0 } { 23: b@[0..5) lshl: 3 }
+//   { 24: b@[0..5) lshl: 0 } { 25: b@[0..3) lshl: 5 }
+//   { 25: b@[3..5) lshl: 0 } { 26: b@[0..5) lshl: 2 } { 27: b@[0..1) lshl: 7 }
+//   { 27: b@[1..5) lshl: 0 } { 28: b@[0..4) lshl: 4 }
+//   { 28: b@[4..5) lshl: 0 } { 29: b@[0..5) lshl: 1 } { 30: b@[0..2) lshl: 6 }
+//   { 30: b@[2..5) lshl: 0 } { 31: b@[0..5) lshl: 3 }
+// ```
+//
+// In the above, each row represents one target vector element and each
+// column represents one bit contribution from a source vector element.
+// The algorithm creates vector.shuffle operations (in this case there are 3
+// shuffles, i.e. the max number of columns in BitCastBitsEnumerator), as
+// follows:
+// 1. for each vector.shuffle, collect the source vectors that participate in
+// this shuffle. One source vector per target element of the resulting
+// vector.shuffle. If there is no source element contributing bits for the
+// current vector.shuffle, take 0 (i.e. row 0 in the above example has only
+// 2 columns).
+// 2. represent the bitrange in the source vector as a mask. If there is no
+// source element contributing bits for the current vector.shuffle, take 0.
+// 3. shift right by the proper amount to align the source bitrange at
+// position 0. This is exactly the low end of the bitrange. For instance,
+// the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
+// shift right by 3 to get the bits contributed by the source element #1
+// into position 0.
+//   4. shift left by the proper amount to align to the desired position in
+// the result element vector. For instance, the contribution of the second
+// source element for the first row needs to be shifted by `5` to form the
+// first i8 result element.
+// Eventually, we end up building the sequence
+// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update the
+// result vector (i.e. the `shiftright -> shiftleft -> or` part) with the bits
+// extracted from the source vector (i.e. the `shuffle -> and` part).
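+//
+// As a hedged sketch (not the verbatim pattern output), one such round for a
+// hypothetical `arith.trunci ... to vector<4xi4>` followed by a
+// `vector.bitcast` to `vector<2xi8>` could look like (shuffling the
+// untruncated `%src : vector<4xi8>`):
+// ```
+//    %s = vector.shuffle %src, %src [1, 3] : vector<4xi8>, vector<4xi8>
+//    %m = arith.andi %s, %mask : vector<2xi8>  // keep the b@[0..4) bitrange
+//    %r = arith.shrui %m, %shr : vector<2xi8>  // align the bitrange at bit 0
+//    %l = arith.shli %r, %shl : vector<2xi8>   // lshl to the final position
+//    %res = arith.ori %acc, %l : vector<2xi8>  // merge into the result
+// ```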
+struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
+ PatternRewriter &rewriter) const override {
+ // The source must be a trunc op.
+ auto truncOp =
+ bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
+ if (!truncOp)
+ return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
+
+ VectorType targetVectorType = bitCastOp.getResultVectorType();
+ if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
+ return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
+ // TODO: consider relaxing this restriction in the future if we find ways
+ // to really work with subbyte elements across the MLIR/LLVM boundary.
+ int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
+ if (resultBitwidth % 8 != 0)
+ return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
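+    // E.g. a vector<20xi8> result is accepted, while the vector<3xi10> result
+    // from the BitCastBitsEnumerator example above would be rejected here.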
+
+ VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+ BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
+ LDBG("\n" << be.sourceElementRanges);
+
+ Value initialValue = truncOp.getIn();
+    auto initialVectorType = initialValue.getType().cast<VectorType>();
+    auto initialElementType = initialVectorType.getElementType();
+    auto initialElementBitWidth = initialElementType.getIntOrFloatBitWidth();
+
+ Value res;
+ for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e;
+ ++shuffleIdx) {
+ SmallVector<int64_t> shuffles;
+ SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+
+ // Create the attribute quantities for the shuffle / mask / shift ops.
+ for (auto &l : be.sourceElementRanges) {
+ int64_t sourceElementIdx = (shuffleIdx < (int64_t)l.size())
----------------
qcolombet wrote:
Nit: Put this `shuffleIdx < (int64_t)l.size()` in a variable and reuse it.
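
A minimal sketch of what the suggested hoisting could look like (the variable
name is hypothetical; the `: 0` fallback follows step 1 of the algorithm
description in the patch):

```cpp
bool contributes = shuffleIdx < (int64_t)l.size();
int64_t sourceElementIdx = contributes ? l[shuffleIdx].sourceElementIdx : 0;
```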
https://github.com/llvm/llvm-project/pull/66387