[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision bitcast… (PR #66387)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 15 08:10:13 PDT 2023


================
@@ -155,6 +164,221 @@ struct ConvertVectorTransferRead final
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// RewriteBitCastOfTruncI
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Helper struct to keep track of the provenance of a contiguous set of bits
+/// in a source vector.
+struct SourceElementRange {
+  int64_t sourceElement;  // Presumably the element index in the source vector.
+  int64_t sourceBitBegin; // Start of the bit range (printed as inclusive).
+  int64_t sourceBitEnd;   // End of the bit range (printed as exclusive).
+};
+
+struct SourceElementRangeList : public SmallVector<SourceElementRange> {
+  /// Given the index of a SourceElementRange in the SourceElementRangeList,
+  /// compute the amount of bits that need to be shifted to the left to get the
+  /// bits in their final location. This shift amount is simply the sum of the
+  /// widths of the ranges *before* `shuffleIdx` (i.e. the bits of
+  /// `shuffleIdx = 0` are always the LSBs, the bits of `shuffleIdx = 1` come
+  /// next, etc).
+  /// `shuffleIdx` must lie in `[0, size()]`; every entry strictly before it is
+  /// read.
+  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
+    // Guard the indexing below: entries [0, shuffleIdx) are all dereferenced.
+    assert(shuffleIdx >= 0 && shuffleIdx <= static_cast<int64_t>(size()) &&
+           "shuffleIdx out of bounds");
+    int64_t res = 0;
+    for (int64_t i = 0; i < shuffleIdx; ++i)
+      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
+    return res;
+  }
+};
+
+/// Helper struct to enumerate the source elements and bit ranges that are
+/// involved in a bitcast operation.
+/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
+/// any 1-D vector shape and any source/target bitwidths.
+struct BitCastBitsEnumerator {
+  BitCastBitsEnumerator(VectorType sourceVectorType,
+                        VectorType targetVectorType);
+
+  /// Return the length of the longest list stored in `sourceElementRanges`,
+  /// or 0 when `sourceElementRanges` is empty.
+  int64_t getMaxNumberOfEntries() const {
+    int64_t maxEntries = 0;
+    for (const SourceElementRangeList &ranges : sourceElementRanges)
+      maxEntries = std::max(maxEntries, static_cast<int64_t>(ranges.size()));
+    return maxEntries;
+  }
+
+  VectorType sourceVectorType;
+  VectorType targetVectorType;
+  /// Lists of source bit ranges; presumably one list per element of the
+  /// target vector (TODO: confirm against the constructor).
+  SmallVector<SourceElementRangeList> sourceElementRanges;
+};
+
+} // namespace
+
+/// Print one line per SourceElementRangeList in `vec`. Each range is rendered
+/// as `{ <sourceElement>: b@[<begin>..<end>) lshl: <shift> } `, where
+/// `<shift>` is the left-shift amount computed for that range's position in
+/// its list.
+static raw_ostream &operator<<(raw_ostream &os,
+                               const SmallVector<SourceElementRangeList> &vec) {
+  for (const SourceElementRangeList &ranges : vec) {
+    for (int64_t idx = 0, end = ranges.size(); idx < end; ++idx) {
+      const SourceElementRange &range = ranges[idx];
+      os << "{ " << range.sourceElement << ": b@[";
+      os << range.sourceBitBegin << ".." << range.sourceBitEnd;
+      os << ") lshl: " << ranges.computeLeftShiftAmount(idx) << " } ";
+    }
+    os << "\n";
+  }
+  return os;
+}
+
+BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
+                                             VectorType targetVectorType)
+    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
+
+  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
+         "requires -D non-scalable vector type");
----------------
qcolombet wrote:

I believe we want the same assert for `sourceVectorType`.

The way this helper is currently called works, but if something else uses this API in the future, the developer may miss that both input types must have the same rank. (We said in the struct comments that this was only for 1-D vectors, but that doesn't prevent people from missing it!)

https://github.com/llvm/llvm-project/pull/66387


More information about the Mlir-commits mailing list