[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision bitcast… (PR #66387)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 15 08:10:12 PDT 2023


================
@@ -155,6 +164,221 @@ struct ConvertVectorTransferRead final
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// RewriteBitCastOfTruncI
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Helper struct to keep track of the provenance of a contiguous set of bits
+/// in a source vector: bits [sourceBitBegin, sourceBitEnd) of the source
+/// element at index `sourceElement`.
+struct SourceElementRange {
+  /// Index of the source vector element the bits are taken from.
+  int64_t sourceElement;
+  /// First bit (inclusive) of the range within that source element.
+  int64_t sourceBitBegin;
+  /// One past the last bit (exclusive) of the range within that element.
+  int64_t sourceBitEnd;
+};
+
+/// Ordered list of source bit ranges. BitCastBitsEnumerator builds one such
+/// list per target element; entry 0 provides the least significant bits.
+struct SourceElementRangeList : public SmallVector<SourceElementRange> {
+  /// Given the index of a SourceElementRange in the SourceElementRangeList,
+  /// compute the amount of bits that need to be shifted to the left to get the
+  /// bits in their final location. This shift amount is simply the sum of the
+  /// bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are always
+  /// the LSBs, the bits of `shuffleIdx = 1` come next, etc).
+  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
+    // Guard the `(*this)[i]` accesses below; `size()` itself is a valid input
+    // (it yields the total number of bits in the list).
+    assert(shuffleIdx >= 0 && shuffleIdx <= static_cast<int64_t>(size()) &&
+           "shuffleIdx out of bounds");
+    int64_t res = 0;
+    for (int64_t i = 0; i < shuffleIdx; ++i)
+      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
+    return res;
+  }
+};
+
+/// Helper struct to enumerate the source elements and bit ranges that are
+/// involved in a bitcast operation.
+/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
+/// any 1-D vector shape and any source/target bitwidths.
+struct BitCastBitsEnumerator {
+  BitCastBitsEnumerator(VectorType sourceVectorType,
+                        VectorType targetVectorType);
+
+  /// Return the largest number of source ranges that contribute to any single
+  /// target element (0 when there are no target elements).
+  int64_t getMaxNumberOfEntries() const {
+    int64_t numVectors = 0;
+    for (const auto &l : sourceElementRanges)
+      numVectors = std::max(numVectors, static_cast<int64_t>(l.size()));
+    return numVectors;
+  }
+
+  VectorType sourceVectorType;
+  VectorType targetVectorType;
+  /// For each target element, the source bit ranges composing it, ordered
+  /// from least to most significant bits.
+  SmallVector<SourceElementRangeList> sourceElementRanges;
+};
+
+} // namespace
+
+/// Debug printer: emits one line per target element, listing each contributing
+/// source range as `{ elem: b@[begin..end) lshl: shift }`.
+static raw_ostream &operator<<(raw_ostream &os,
+                               const SmallVector<SourceElementRangeList> &vec) {
+  for (const SourceElementRangeList &rangeList : vec) {
+    for (int64_t idx = 0, e = rangeList.size(); idx < e; ++idx) {
+      const SourceElementRange &range = rangeList[idx];
+      int64_t shift = rangeList.computeLeftShiftAmount(idx);
+      os << "{ " << range.sourceElement << ": b@[" << range.sourceBitBegin
+         << ".." << range.sourceBitEnd << ") lshl: " << shift << " } ";
+    }
+    os << "\n";
+  }
+  return os;
+}
+
+/// Build, for every element of the 1-D `targetVectorType`, the ordered list of
+/// source bit ranges (from `sourceVectorType`) that compose it. The total bit
+/// count of source and target must be equal.
+BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
+                                             VectorType targetVectorType)
+    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
+
+  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
+         "requires 1-D non-scalable vector type");
+  // The innermost source dim is treated as the full element count below, which
+  // is only valid for a 1-D source.
+  assert(sourceVectorType.getRank() == 1 && "requires 1-D source vector type");
+  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+  LDBG("sourceVectorType: " << sourceVectorType);
+
+  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
+  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
+  LDBG("targetVectorType: " << targetVectorType);
+
+  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
+  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
+         "source and target bitwidths must match");
+
+  // Prepopulate one (initially empty) SourceElementRangeList per target
+  // element.
+  sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
+  // Walk the bits from LSB to MSB. Each step covers the longest run of bits
+  // that stays within both the current source element and the current target
+  // element, so every recorded range maps to exactly one of each.
+  for (int64_t resultBit = 0; resultBit < bitwidth;) {
+    int64_t resultElement = resultBit / targetBitWidth;
+    int64_t resultBitInElement = resultBit % targetBitWidth;
+    int64_t sourceElement = resultBit / sourceBitWidth;
+    int64_t sourceBitInElement = resultBit % sourceBitWidth;
+    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
+                            targetBitWidth - resultBitInElement);
+    sourceElementRanges[resultElement].push_back(
+        {sourceElement, sourceBitInElement, sourceBitInElement + step});
+    resultBit += step;
+  }
+}
+
+namespace {
+/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information to avoid leaving LLVM to scramble with
+/// peephole optimizations.
+struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
+                                PatternRewriter &rewriter) const override {
+    // The source must be a trunc op.
+    auto truncOp =
+        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
+    if (!truncOp)
+      return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
+
+    VectorType targetVectorType = bitCastOp.getResultVectorType();
+    if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
+      return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
+    // TODO: consider relaxing this restriction in the future if we find ways to
+    // really work with subbyte elements across the MLIR/LLVM boundary.
+    int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
+    if (resultBitwidth % 8 != 0)
+      return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
+
+    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+    BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
+    LDBG("\n" << be.sourceElementRanges);
+
+    Value initialValue = truncOp.getIn();
+    auto initalVectorType = initialValue.getType().cast<VectorType>();
+    auto initalElementType = initalVectorType.getElementType();
+    auto initalElementBitWidth = initalElementType.getIntOrFloatBitWidth();
+
+    // BitCastBitsEnumerator encodes for each element of the target vector the
+    // provenance of the bits in the source vector. We can "transpose" this
+    // information to build a sequence of shuffles and bitwise ops that will
+    // produce the desired result.
+    // The algorithm proceeds as follows:
+    //   1. there are as many shuffles as max entries in BitCastBitsEnumerator
+    //   2. for each shuffle:
+    //     a. collect the source vectors that participate in this shuffle. One
+    //     source vector per target element of the shuffle. If overflow, take 0.
----------------
qcolombet wrote:

What does "if overflow" mean?

https://github.com/llvm/llvm-project/pull/66387


More information about the Mlir-commits mailing list