[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… (PR #65774)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Wed Sep 13 01:39:53 PDT 2023
================
@@ -155,6 +170,256 @@ struct ConvertVectorTransferRead final
};
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+/// Create a vector of bit masks covering bits `idx .. idx + step - 1` for each
+/// `idx` in `0, step, 2 * step, ...`, and replicate it `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x3, 0xc, 0x3, 0xc, 0x3, 0xc].
+static SmallVector<Attribute> computeExtOfBitCastMasks(MLIRContext *ctx,
+ int64_t bitwidth,
+ int64_t step,
+ int64_t numOccurrences) {
+ assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+ IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+ SmallVector<Attribute> tmpMasks;
+ tmpMasks.reserve(bitwidth / step);
+ // Create a vector of bit masks: `idx .. idx + step - 1`.
+ for (int64_t idx = 0; idx < bitwidth; idx += step) {
+ LDBG("Mask bits " << idx << " .. " << idx + step - 1 << " out of "
+ << bitwidth);
+ IntegerAttr mask = IntegerAttr::get(
+ interimIntType, llvm::APInt::getBitsSet(bitwidth, idx, idx + step));
+ tmpMasks.push_back(mask);
+ }
+ // Replicate the vector of bit masks to the desired size.
+ SmallVector<Attribute> masks;
+ masks.reserve(numOccurrences * tmpMasks.size());
+ for (int64_t idx = 0; idx < numOccurrences; ++idx)
+ llvm::append_range(masks, tmpMasks);
+ return masks;
+}
+
+/// Create a vector of bit shift amounts `0, step, 2 * step, ..., bitwidth - step`
+/// and replicate it `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x2, 0x0, 0x2, 0x0, 0x2].
+static SmallVector<Attribute>
+computeExtOfBitCastShifts(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+ int64_t numOccurrences) {
+ assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+ IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+ SmallVector<Attribute> tmpShifts;
+ for (int64_t idx = 0; idx < bitwidth; idx += step) {
+ IntegerAttr shift = IntegerAttr::get(interimIntType, idx);
+ tmpShifts.push_back(shift);
+ }
+ SmallVector<Attribute> shifts;
+ for (int64_t idx = 0; idx < numOccurrences; ++idx)
+ llvm::append_range(shifts, tmpShifts);
+ return shifts;
+}
+
+/// Create a vector of shuffle indices in which each index `idx` in
+/// `0 .. numOccurrences - 1` is replicated `bitwidth / step` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0, 0, 1, 1, 2, 2].
+static SmallVector<int64_t>
+computeExtOfBitCastShuffles(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+ int64_t numOccurrences) {
+ assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+ SmallVector<int64_t> shuffles;
+ int64_t n = floorDiv(bitwidth, step);
+ for (int64_t idx = 0; idx < numOccurrences; ++idx)
+ llvm::append_range(shuffles, SmallVector<int64_t>(n, idx));
+ return shuffles;
+}
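
To make the interplay of the three helpers above concrete, here is a minimal
standalone sketch (plain C++ over ordinary integers rather than MLIR attributes;
the running (bitwidth = 4, step = 2, numOccurrences = 3) example is reused for
illustration). It rebuilds the masks [0x3, 0xc, ...], the shifts [0, 2, ...]
and the shuffle indices [0, 0, 1, 1, 2, 2], then uses one mask/shift pair to
pull both 2-bit sub-elements out of a 4-bit value:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t bitwidth = 4, step = 2, numOccurrences = 3;

  // Masks and shifts: the per-word pattern [0x3, 0xc] / [0, 2], replicated
  // numOccurrences times.
  std::vector<uint64_t> masks, shifts;
  for (int64_t rep = 0; rep < numOccurrences; ++rep) {
    for (int64_t idx = 0; idx < bitwidth; idx += step) {
      masks.push_back(((1ull << step) - 1) << idx); // bits idx .. idx + step - 1
      shifts.push_back(idx);
    }
  }

  // Shuffle indices: each interim element index repeated bitwidth / step
  // times, i.e. [0, 0, 1, 1, 2, 2].
  std::vector<int64_t> shuffles;
  for (int64_t idx = 0; idx < numOccurrences; ++idx)
    for (int64_t i = 0; i < bitwidth / step; ++i)
      shuffles.push_back(idx);

  // Extract both 2-bit sub-elements of the 4-bit value 0b1110:
  // (0xE & 0x3) >> 0 == 2 and (0xE & 0xC) >> 2 == 3.
  uint64_t interim = 0xE;
  for (size_t i = 0; i < 2; ++i)
    printf("sub-element %zu = %llu\n", i,
           (unsigned long long)((interim & masks[i]) >> shifts[i]));
  return 0;
}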
+
+/// Compute the bitwidth of the interim integer element type. It must:
+///   1. be no larger than 64 (TODO: in the future we may want target-specific
+///      control);
+///   2. evenly divide sourceBitWidth * mostMinorSourceDim.
+/// Returns 0 if no such bitwidth exists.
+static int64_t computeExtOfBitCastBitWidth(int64_t sourceBitWidth,
+ int64_t mostMinorSourceDim,
+ int64_t targetBitWidth) {
+ for (int64_t mult : {32, 16, 8, 4, 2, 1}) {
+ int64_t interimBitWidth =
+ std::lcm(mult, std::lcm(sourceBitWidth, targetBitWidth));
+ if (interimBitWidth > 64)
+ continue;
+ if ((sourceBitWidth * mostMinorSourceDim) % interimBitWidth != 0)
+ continue;
+ return interimBitWidth;
+ }
+ return 0;
+}
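
As a hedged worked example of this heuristic (the concrete shapes are chosen
here for illustration, not taken from the patch): for a vector<8xi8> that is
bitcast to vector<16xi4> and then extended, sourceBitWidth = 8,
mostMinorSourceDim = 8 and targetBitWidth = 4. The first candidate mult = 32
gives lcm(32, lcm(8, 4)) = 32, which is no larger than 64 and divides
8 * 8 = 64, so the interim bitwidth is 32 (and, further below,
interimMostMinorDim = 64 / 32 = 2).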
+
+FailureOr<Value>
+mlir::vector::rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+ vector::BitCastOp bitCastOp,
+ vector::BroadcastOp maybeBroadcastOp) {
+ assert(
+ (llvm::isa<arith::ExtSIOp>(extOp) || llvm::isa<arith::ExtUIOp>(extOp)) &&
+ "unsupported op");
+
+  // The bitcast op is the load-bearing part: capture the source and bitcast
+  // vector types as well as their element bitwidths and the most minor source
+  // dimension.
+ VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+ int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+ int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+ LDBG("sourceVectorType: " << sourceVectorType);
+
+ VectorType bitCastVectorType = bitCastOp.getResultVectorType();
+ int64_t targetBitWidth = bitCastVectorType.getElementTypeBitWidth();
+ LDBG("bitCastVectorType: " << bitCastVectorType);
+
+ int64_t interimBitWidth = computeExtOfBitCastBitWidth(
+ sourceBitWidth, mostMinorSourceDim, targetBitWidth);
+ LDBG("interimBitWidth: " << interimBitWidth);
+ if (!interimBitWidth) {
+ return rewriter.notifyMatchFailure(
+ extOp, "heuristic could not find a reasonable interim bitwidth");
+ }
+ if (sourceBitWidth == interimBitWidth || targetBitWidth == interimBitWidth) {
+ return rewriter.notifyMatchFailure(
+ extOp, "interim bitwidth is equal to source or target, nothing to do");
+ }
+
+ int64_t interimMostMinorDim =
+ sourceBitWidth * mostMinorSourceDim / interimBitWidth;
+ LDBG("interimMostMinorDim: " << interimMostMinorDim);
+
+ Location loc = extOp->getLoc();
+ MLIRContext *ctx = extOp->getContext();
+
+ VectorType interimVectorType =
+ VectorType::Builder(sourceVectorType)
+ .setDim(sourceVectorType.getRank() - 1, interimMostMinorDim)
+ .setElementType(IntegerType::get(ctx, interimBitWidth));
+ LDBG("interimVectorType: " << interimVectorType);
+
+ IntegerType interimIntType = IntegerType::get(ctx, interimBitWidth);
+ VectorType vt =
+ VectorType::Builder(bitCastVectorType).setElementType(interimIntType);
+
+ // Rewrite the original bitcast to the interim vector type and shuffle to
+ // broadcast to the desired size.
+ auto newBitCastOp = rewriter.create<vector::BitCastOp>(loc, interimVectorType,
+ bitCastOp.getSource());
+ SmallVector<int64_t> shuffles = computeExtOfBitCastShuffles(
+ ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+ auto shuffleOp = rewriter.create<vector::ShuffleOp>(loc, newBitCastOp,
+ newBitCastOp, shuffles);
+ LDBG("shuffle: " << shuffleOp);
+
+ // Compute the constants for masking.
+ SmallVector<Attribute> masks = computeExtOfBitCastMasks(
+ ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+ auto maskConstantOp = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(vt, masks));
+ LDBG("maskConstant: " << maskConstantOp);
+ auto andOp = rewriter.create<arith::AndIOp>(loc, shuffleOp, maskConstantOp);
+ LDBG("andOp: " << andOp);
+
+  // The choice of intermediate type can have serious consequences on the
+  // backend's ability to generate efficient vector operations. For instance,
+  // on x86, converting to f16 without going through i32 has severe performance
+  // implications. As a consequence, this pattern must preserve the original
+  // extension semantics and result type.
+ VectorType resultType = cast<VectorType>(extOp->getResultTypes().front());
+ Type resultElementType = getElementTypeOrSelf(resultType);
+ SmallVector<Attribute> shifts = computeExtOfBitCastShifts(
+ ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+ auto shiftConstantOp = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(vt, shifts));
+ LDBG("shiftConstant: " << shiftConstantOp);
+ Value newResult =
+ TypeSwitch<Operation *, Value>(extOp)
+ .template Case<arith::ExtSIOp>([&](arith::ExtSIOp op) {
+ Value shifted =
+ rewriter.create<arith::ShRSIOp>(loc, andOp, shiftConstantOp);
+ auto vt = shifted.getType().cast<VectorType>();
+ VectorType extVt =
+ VectorType::Builder(vt).setElementType(resultElementType);
+ Operation *res =
+ (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+ ? rewriter.create<arith::ExtSIOp>(loc, extVt, shifted)
+ : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+ return res->getResult(0);
+ })
+ .template Case<arith::ExtUIOp>([&](arith::ExtUIOp op) {
+ Value shifted =
+ rewriter.create<arith::ShRUIOp>(loc, andOp, shiftConstantOp);
+ auto vt = shifted.getType().cast<VectorType>();
+ VectorType extVt =
+ VectorType::Builder(vt).setElementType(resultElementType);
+ Operation *res =
+ (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+ ? rewriter.create<arith::ExtUIOp>(loc, extVt, shifted)
+ : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+ return res->getResult(0);
+ })
+ .Default([&](Operation *op) {
+ llvm_unreachable("unexpected op type");
+ return nullptr;
+ });
+
+ if (maybeBroadcastOp) {
----------------
ftynse wrote:
Nit: return early instead.
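
For readers following along, a minimal sketch of the early-return shape being
suggested here, assuming the tail of the function (not visible in this hunk)
wraps newResult in a vector.broadcast when maybeBroadcastOp is present:

// Hypothetical restructuring, not the patch as written: bail out in the
// common case first so the broadcast path needs no extra nesting.
if (!maybeBroadcastOp)
  return newResult;
return rewriter
    .create<vector::BroadcastOp>(loc, maybeBroadcastOp.getResultVectorType(),
                                 newResult)
    .getResult();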
https://github.com/llvm/llvm-project/pull/65774