[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision bitcast… (PR #66387)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 18 05:38:47 PDT 2023


================
@@ -155,6 +164,280 @@ struct ConvertVectorTransferRead final
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// RewriteBitCastOfTruncI
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Helper struct to keep track of the provenance of a contiguous set of bits
+/// in a source vector.
+struct SourceElementRange {
+  /// The index of the source vector element that contributes bits to *this.
+  int64_t sourceElementIdx = 0;
+  /// The half-open range [sourceBitBegin, sourceBitEnd) of bits in the source
+  /// vector element that contribute to *this. Members default to 0 so that a
+  /// default-constructed range is empty instead of holding indeterminate
+  /// values; aggregate initialization `{idx, begin, end}` is unaffected.
+  int64_t sourceBitBegin = 0;
+  int64_t sourceBitEnd = 0;
+};
+
+struct SourceElementRangeList : public SmallVector<SourceElementRange> {
+  /// Given the index of a SourceElementRange in the SourceElementRangeList,
+  /// compute the number of bits that need to be shifted to the left to get the
+  /// bits in their final location. This shift amount is simply the sum of the
+  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
+  /// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
+  ///
+  /// `shuffleIdx` must lie in `[0, size()]`; passing `size()` yields the total
+  /// number of bits covered by the whole list.
+  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
+    assert(shuffleIdx >= 0 && shuffleIdx <= static_cast<int64_t>(size()) &&
+           "shuffleIdx out of range");
+    int64_t res = 0;
+    for (int64_t i = 0; i < shuffleIdx; ++i)
+      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
+    return res;
+  }
+};
+
+/// Helper struct to enumerate the source elements and bit ranges that are
+/// involved in a bitcast operation.
+/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
+/// any 1-D vector shape and any source/target bitwidths.
+/// This creates and holds a mapping of the form:
+/// [dstVectorElementJ] ==
+///    [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
+/// where each bitRange is a half-open interval of bits within that source
+/// element.
+/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
+///   [0] = {0, [0, 8)}
+///   [1] = {0, [8, 16)}
+///   [2] = {0, [16, 24)}
+/// and `vector.bitcast ... : vector<3xi10> to vector<2xi15>` is decomposed as:
+///   [0] = {0, [0, 10)}, {1, [0, 5)}
+///   [1] = {1, [5, 10)}, {2, [0, 10)}
+struct BitCastBitsEnumerator {
+  BitCastBitsEnumerator(VectorType sourceVectorType,
+                        VectorType targetVectorType);
+
+  /// Return the largest number of source ranges that contribute to any single
+  /// target element, i.e. the size of the longest list in
+  /// `sourceElementRanges` (0 if there are no target elements).
+  int64_t getMaxNumberOfEntries() const {
+    int64_t numVectors = 0;
+    for (const auto &l : sourceElementRanges)
+      numVectors = std::max(numVectors, static_cast<int64_t>(l.size()));
+    return numVectors;
+  }
+
+  VectorType sourceVectorType;
+  VectorType targetVectorType;
+  /// One list per target vector element, holding the source bit ranges that
+  /// make up that element, ordered from its LSBs upwards.
+  SmallVector<SourceElementRangeList> sourceElementRanges;
+};
+
+} // namespace
+
+/// Print the bit-provenance table: one line per target element, where each
+/// entry shows the contributing source element, its bit range, and the
+/// left-shift amount that places those bits in the target element.
+static raw_ostream &operator<<(raw_ostream &os,
+                               const SmallVector<SourceElementRangeList> &vec) {
+  for (const SourceElementRangeList &rangeList : vec) {
+    for (size_t idx = 0, e = rangeList.size(); idx != e; ++idx) {
+      const SourceElementRange &range = rangeList[idx];
+      const int64_t shift = rangeList.computeLeftShiftAmount(idx);
+      os << "{ " << range.sourceElementIdx << ": b@[" << range.sourceBitBegin
+         << ".." << range.sourceBitEnd << ") lshl: " << shift << " } ";
+    }
+    os << "\n";
+  }
+  return os;
+}
+
+/// Build, for every element of `targetVectorType`, the ordered list of source
+/// element bit ranges that make it up. Both vector types must be 1-D,
+/// non-scalable, and have the same total bitwidth.
+BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
+                                             VectorType targetVectorType)
+    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
+
+  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
+         "requires 1-D non-scalable vector type");
+  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
+         "requires 1-D non-scalable vector type");
+  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+  LDBG("sourceVectorType: " << sourceVectorType);
+
+  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
+  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
+  LDBG("targetVectorType: " << targetVectorType);
+
+  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
+  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
+         "source and target bitwidths must match");
+
+  // Prepopulate one source element range per target element.
+  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
+  // Walk the flattened bit sequence. Each step stops at whichever boundary
+  // (end of the current source element or of the current target element)
+  // comes first, so every emitted range lies within a single source element
+  // and a single target element.
+  for (int64_t resultBit = 0; resultBit < bitwidth;) {
+    int64_t resultElement = resultBit / targetBitWidth;
+    int64_t resultBitInElement = resultBit % targetBitWidth;
+    int64_t sourceElementIdx = resultBit / sourceBitWidth;
+    int64_t sourceBitInElement = resultBit % sourceBitWidth;
+    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
+                            targetBitWidth - resultBitInElement);
+    sourceElementRanges[resultElement].push_back(
+        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
+    resultBit += step;
+  }
+}
+
+namespace {
+/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information, rather than leaving it to LLVM to
+/// recover the same result through peephole optimizations.
+
+// BitCastBitsEnumerator encodes for each element of the target vector the
+// provenance of the bits in the source vector. We can "transpose" this
+// information to build a sequence of shuffles and bitwise ops that will
+// produce the desired result.
+//
+// Let's take the following motivating example to explain the algorithm:
+// ```
+//   %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
+//   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
+// ```
+//
+// BitCastBitsEnumerator contains the following information:
+// ```
+//   { 0: b@[0..5) lshl: 0}{1: b@[0..3) lshl: 5 }
+//   { 1: b@[3..5) lshl: 0}{2: b@[0..5) lshl: 2}{3: b@[0..1) lshl: 7 }
+//   { 3: b@[1..5) lshl: 0}{4: b@[0..4) lshl: 4 }
+//   { 4: b@[4..5) lshl: 0}{5: b@[0..5) lshl: 1}{6: b@[0..2) lshl: 6 }
+//   { 6: b@[2..5) lshl: 0}{7: b@[0..5) lshl: 3 }
+//   { 8: b@[0..5) lshl: 0}{9: b@[0..3) lshl: 5 }
+//   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7 }
+//   { 11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4 }
+//   { 12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6 }
+//   { 14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
+//   { 16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
+//   { 17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
+//   { 19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
+//   { 20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1 }{22: b@[0..2) lshl: 6}
+//   { 22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3 }
+//   { 24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5 }
+//   { 25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7 }
+//   { 27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
+//   { 28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
+//   { 30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3 }
+// ```
+//
+// In the above, each row represents one target vector element and each
+// column represents one bit contribution from a source vector element.
+// The algorithm creates vector.shuffle operations (in this case there are 3
+// shuffles, i.e. the max number of columns in BitCastBitsEnumerator) as
+// follows:
+//   1. for each vector.shuffle, collect the source vectors that participate in
+//     this shuffle. One source vector per target element of the resulting
+//     vector.shuffle. If there is no source element contributing bits for the
+//     current vector.shuffle, take 0 (e.g. row 0 in the above example has only
+//     2 columns, so it contributes 0 to the 3rd shuffle).
+//   2. represent the bitrange in the source vector as a mask. If there is no
+//     source element contributing bits for the current vector.shuffle, take 0.
+//   3. shift right by the proper amount to align the source bitrange at
+//     position 0. This is exactly the low end of the bitrange. For instance,
+//     the first element of row 1 is `{ 1: b@[3..5) lshl: 0}` and one needs to
+//     shift right by 3 to get the bits contributed by the source element #1
+//     into position 0.
+//   4. shift left by the proper amount to align the bits to their desired
+//     position in the result element. For instance, the contribution of the
+//     second source element in the first row needs to be shifted left by `5`
+//     to form the first i8 result element.
----------------
qcolombet wrote:

I think a summary of what the whole explanation is describing would be good. At first I didn't get why we needed 3 shuffles.
E.g.,
```
// In other words, this algorithm populates the bits like so:
// ```
//     src bits 0 ... 
// 1st shuffle |xxxxx   |xx      |...
// 2nd shuffle |     xxx|  xxxxx |...
// 3rd shuffle |        |       x|...
```

https://github.com/llvm/llvm-project/pull/66387


More information about the Mlir-commits mailing list