[Mlir-commits] [mlir] [mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… (PR #66648)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 18 07:31:49 PDT 2023
================
@@ -359,93 +487,93 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
if (!truncOp)
return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
+ // Set up the BitCastRewriter and verify the precondition.
+ VectorType sourceVectorType = bitCastOp.getSourceVectorType();
VectorType targetVectorType = bitCastOp.getResultVectorType();
- if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
- return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
- // TODO: consider relaxing this restriction in the future if we find ways
- // to really work with subbyte elements across the MLIR/LLVM boundary.
- int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
- if (resultBitwidth % 8 != 0)
- return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
+ BitCastRewriter bcr(sourceVectorType, targetVectorType);
+ if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+ return failure();
- VectorType sourceVectorType = bitCastOp.getSourceVectorType();
- BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
- LDBG("\n" << be.sourceElementRanges);
-
- Value initialValue = truncOp.getIn();
- auto initalVectorType = initialValue.getType().cast<VectorType>();
- auto initalElementType = initalVectorType.getElementType();
- auto initalElementBitWidth = initalElementType.getIntOrFloatBitWidth();
-
- Value res;
- for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e;
- ++shuffleIdx) {
- SmallVector<int64_t> shuffles;
- SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
-
- // Create the attribute quantities for the shuffle / mask / shift ops.
- for (auto &srcEltRangeList : be.sourceElementRanges) {
- bool idxContributesBits =
- (shuffleIdx < (int64_t)srcEltRangeList.size());
- int64_t sourceElementIdx =
- idxContributesBits ? srcEltRangeList[shuffleIdx].sourceElementIdx
- : 0;
- shuffles.push_back(sourceElementIdx);
-
- int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
- ? srcEltRangeList[shuffleIdx].sourceBitBegin
- : 0;
- int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
- ? srcEltRangeList[shuffleIdx].sourceBitEnd
- : 0;
- IntegerAttr mask = IntegerAttr::get(
- rewriter.getIntegerType(initalElementBitWidth),
- llvm::APInt::getBitsSet(initalElementBitWidth, bitLo, bitHi));
- masks.push_back(mask);
-
- int64_t shiftRight = bitLo;
- shiftRightAmounts.push_back(IntegerAttr::get(
- rewriter.getIntegerType(initalElementBitWidth), shiftRight));
-
- int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
- shiftLeftAmounts.push_back(IntegerAttr::get(
- rewriter.getIntegerType(initalElementBitWidth), shiftLeft));
- }
-
- // Create vector.shuffle #shuffleIdx.
- auto shuffleOp = rewriter.create<vector::ShuffleOp>(
- bitCastOp.getLoc(), initialValue, initialValue, shuffles);
- // And with the mask.
- VectorType vt = VectorType::Builder(initalVectorType)
- .setDim(initalVectorType.getRank() - 1, masks.size());
- auto constOp = rewriter.create<arith::ConstantOp>(
- bitCastOp.getLoc(), DenseElementsAttr::get(vt, masks));
- Value andValue = rewriter.create<arith::AndIOp>(bitCastOp.getLoc(),
- shuffleOp, constOp);
- // Align right on 0.
- auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
- bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftRightAmounts));
- Value shiftedRight = rewriter.create<arith::ShRUIOp>(
- bitCastOp.getLoc(), andValue, shiftRightConstantOp);
-
- auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
- bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftLeftAmounts));
- Value shiftedLeft = rewriter.create<arith::ShLIOp>(
- bitCastOp.getLoc(), shiftedRight, shiftLeftConstantOp);
-
- res = res ? rewriter.create<arith::OrIOp>(bitCastOp.getLoc(), res,
- shiftedLeft)
- : shiftedLeft;
+ // Perform the rewrite.
+ Value truncValue = truncOp.getIn();
+ auto shuffledElementType =
+ cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
+ Value runningResult;
+ for (const BitCastRewriter ::Metadata &metadata :
+ bcr.precomputeMetadata(shuffledElementType)) {
+ runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
+ runningResult, metadata);
}
- bool narrowing = resultBitwidth <= initalElementBitWidth;
+ // Finalize the rewrite.
+ bool narrowing = targetVectorType.getElementTypeBitWidth() <=
+ shuffledElementType.getIntOrFloatBitWidth();
if (narrowing) {
rewriter.replaceOpWithNewOp<arith::TruncIOp>(
- bitCastOp, bitCastOp.getResultVectorType(), res);
+ bitCastOp, bitCastOp.getResultVectorType(), runningResult);
} else {
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
- bitCastOp, bitCastOp.getResultVectorType(), res);
+ bitCastOp, bitCastOp.getResultVectorType(), runningResult);
}
+
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
----------------
qcolombet wrote:
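Nit: this doc comment was copied from the previous pattern — it should say "Rewrite ext(bitcast)" here, not "Rewrite bitcast(trunci)", since this section is RewriteExtOfBitCast.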
https://github.com/llvm/llvm-project/pull/66648
More information about the Mlir-commits mailing list