[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision bitcast… (PR #66387)
llvmlistbot at llvm.org
Fri Sep 15 08:10:14 PDT 2023
================
@@ -155,6 +164,221 @@ struct ConvertVectorTransferRead final
};
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// RewriteBitCastOfTruncI
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Helper struct to keep track of the provenance of a contiguous set of bits
+/// in a source vector.
+struct SourceElementRange {
+ /// Index of the source vector element that contributes the bits.
+ int64_t sourceElement;
+ /// Bit range contributed within that element, as the half-open interval
+ /// [sourceBitBegin, sourceBitEnd).
+ int64_t sourceBitBegin;
+ int64_t sourceBitEnd;
+};
+
+struct SourceElementRangeList : public SmallVector<SourceElementRange> {
+ /// Given the index of a SourceElementRange in the SourceElementRangeList,
+ /// compute the number of bits that need to be shifted left to put the bits
+ /// in their final location. This shift amount is simply the sum of the
+ /// widths of the bit ranges *before* `shuffleIdx` (i.e. the bits of
+ /// `shuffleIdx = 0` are always the LSBs, the bits of `shuffleIdx = 1` come
+ /// next, etc).
+ int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
+ int64_t res = 0;
+ for (int64_t i = 0; i < shuffleIdx; ++i)
+ res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
+ return res;
+ }
+};
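+
+// Editorial sketch (not part of the patch): for a bitcast from vector<8xi5>
+// to vector<5xi8>, target element 0 takes bits [0..5) of source element 0
+// followed by bits [0..3) of source element 1, so the second range must be
+// shifted left past the 5 bits that precede it:
+//   SourceElementRangeList ranges;
+//   ranges.push_back({/*sourceElement=*/0, /*sourceBitBegin=*/0,
+//                     /*sourceBitEnd=*/5});
+//   ranges.push_back({/*sourceElement=*/1, /*sourceBitBegin=*/0,
+//                     /*sourceBitEnd=*/3});
+//   assert(ranges.computeLeftShiftAmount(0) == 0);
+//   assert(ranges.computeLeftShiftAmount(1) == 5);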
+
+/// Helper struct to enumerate the source elements and bit ranges that are
+/// involved in a bitcast operation.
+/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
+/// any 1-D vector shape and any source/target bitwidths.
+struct BitCastBitsEnumerator {
+ BitCastBitsEnumerator(VectorType sourceVectorType,
+ VectorType targetVectorType);
+
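+ /// Return the number of entries in the longest SourceElementRangeList,
+ /// i.e. the number of shuffles needed to assemble the result.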
+ int64_t getMaxNumberOfEntries() {
+ int64_t numVectors = 0;
+ for (const auto &l : sourceElementRanges)
+ numVectors = std::max(numVectors, (int64_t)l.size());
+ return numVectors;
+ }
+
+ VectorType sourceVectorType;
+ VectorType targetVectorType;
+ SmallVector<SourceElementRangeList> sourceElementRanges;
+};
+
+} // namespace
+
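+/// Debug printer: one line per target element, each entry showing the source
+/// element, its contributed bit range, and the left-shift amount.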
+static raw_ostream &operator<<(raw_ostream &os,
+ const SmallVector<SourceElementRangeList> &vec) {
+ for (const auto &l : vec) {
+ for (auto it : llvm::enumerate(l)) {
+ os << "{ " << it.value().sourceElement << ": b@["
+ << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
+ << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
+ }
+ os << "\n";
+ }
+ return os;
+}
+
+BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
+ VectorType targetVectorType)
+ : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {
+
+ assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
+ "requires -D non-scalable vector type");
+ int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+ int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+ LDBG("sourceVectorType: " << sourceVectorType);
+
+ int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
+ int64_t mostMinorTargetDim = targetVectorType.getShape().back();
+ LDBG("targetVectorType: " << targetVectorType);
+
+ int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
+ assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
+ "source and target bitwidths must match");
+
+ // Prepopulate one (initially empty) source element range list per target
+ // element.
+ sourceElementRanges = SmallVector<SourceElementRangeList>(mostMinorTargetDim);
+ for (int64_t resultBit = 0; resultBit < bitwidth;) {
+ int64_t resultElement = resultBit / targetBitWidth;
+ int64_t resultBitInElement = resultBit % targetBitWidth;
+ int64_t sourceElement = resultBit / sourceBitWidth;
+ int64_t sourceBitInElement = resultBit % sourceBitWidth;
+ int64_t step = std::min(sourceBitWidth - sourceBitInElement,
+ targetBitWidth - resultBitInElement);
+ sourceElementRanges[resultElement].push_back(
+ {sourceElement, sourceBitInElement, sourceBitInElement + step});
+ resultBit += step;
+ }
+}
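+
+// Editorial worked example (not part of the patch): for vector<8xi5> ->
+// vector<5xi8>, the enumeration prints as:
+//   { 0: b@[0..5) lshl: 0 } { 1: b@[0..3) lshl: 5 }
+//   { 1: b@[3..5) lshl: 0 } { 2: b@[0..5) lshl: 2 } { 3: b@[0..1) lshl: 7 }
+//   { 3: b@[1..5) lshl: 0 } { 4: b@[0..4) lshl: 4 }
+//   { 4: b@[4..5) lshl: 0 } { 5: b@[0..5) lshl: 1 } { 6: b@[0..2) lshl: 6 }
+//   { 6: b@[2..5) lshl: 0 } { 7: b@[0..5) lshl: 3 }
+// getMaxNumberOfEntries() is 3, so three shuffles suffice.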
+
+namespace {
+/// Rewrite bitcast(trunci) into a sequence of shuffles and bitwise ops that
+/// take advantage of high-level information, rather than leaving LLVM to
+/// recover the pattern through peephole optimizations.
+struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
+ PatternRewriter &rewriter) const override {
+ // The source must be a trunc op.
+ auto truncOp = bitCastOp.getSource().getDefiningOp<arith::TruncIOp>();
+ if (!truncOp)
+ return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
+
+ VectorType targetVectorType = bitCastOp.getResultVectorType();
+ if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
+ return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
+ // TODO: consider relaxing this restriction in the future if we find ways
+ // to really work with sub-byte elements across the MLIR/LLVM boundary.
+ int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
+ if (resultBitwidth % 8 != 0)
+ return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
+
+ VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+ BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
+ LDBG("\n" << be.sourceElementRanges);
+
+ Value initialValue = truncOp.getIn();
+ auto initialVectorType = initialValue.getType().cast<VectorType>();
+ auto initialElementType = initialVectorType.getElementType();
+ auto initialElementBitWidth = initialElementType.getIntOrFloatBitWidth();
+
+ // BitCastBitsEnumerator encodes for each element of the target vector the
+ // provenance of the bits in the source vector. We can "transpose" this
+ // information to build a sequence of shuffles and bitwise ops that will
+ // produce the desired result.
+ // The algorithm proceeds as follows (see the sketch below):
+ //   1. there are as many shuffles as there are entries in the longest
+ //      SourceElementRangeList (i.e. getMaxNumberOfEntries());
+ //   2. for each shuffle, gather:
+ //     a. the source vectors that participate in this shuffle, one source
+ //        vector per target element of the shuffle (on overflow, take 0);
+ //     b. the bitrange in each source vector, used as a mask (on overflow,
+ //        take 0);
+ //     c. the number of bits to shift right to align the source bitrange at
+ //        position 0, i.e. the low end of the bitrange;
+ //     d. the number of bits to shift left to align the bitrange to its
+ //        desired position in the result element vector.
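+
+ // Editorial sketch (illustrative MLIR, not the patch's actual output):
+ // one shuffle round might materialize as:
+ //   %s   = vector.shuffle %src, %src [...]   // a. gather sources
+ //   %m   = arith.andi %s, %mask              // b. mask the bitrange
+ //   %shr = arith.shrui %m, %rightShift       // c. align at bit 0
+ //   %shl = arith.shli %shr, %leftShift       // d. final position
+ //   %acc = arith.ori %acc, %shl              // combine the rounds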
----------------
qcolombet wrote:
The comment is difficult to read in that `a.` is an action and `b.`-`d.` are a list of what we need.
https://github.com/llvm/llvm-project/pull/66387