[Mlir-commits] [mlir] [mlir] Add narrow type emulation conversions (PR #72181)
Han-Chung Wang
llvmlistbot at llvm.org
Wed Nov 15 14:54:57 PST 2023
================
@@ -211,6 +257,150 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Currently there is very limited support for memref::ReinterpretCastOp
+/// conversion. Only the 0 dimensional case is supported.
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Convert the result memref type to its wide-element form; failure means
+    // the type converter does not handle this element type.
+    MemRefType newTy =
+        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    // The wide (converted) element width must be an exact multiple of the
+    // narrow (source) width, e.g. i4 packed into i8.
+    auto convertedElementType = newTy.getElementType();
+    auto oldElementType = op.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    // Only the 0-D case (a pure offset, no sizes/strides) is supported.
+    if (op.getType().getRank() != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "reinterpret_cast with rank > 0 is not supported");
+    }
+
+    // Only support static offsets.
+    int64_t offset = op.getStaticOffset(0);
+    if (offset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          "reinterpret_cast with dynamic offset is not supported");
+    }
+
+    // Rescale the offset from narrow-element units to wide-element units. A
+    // misaligned offset would require sub-element addressing, which is not
+    // expressible on the emulated memref.
+    int elementsPerByte = dstBits / srcBits;
+    if (offset % elementsPerByte != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          "reinterpret_cast with offset not multiple of elementsPerByte is "
+          "not supported");
+    }
+
+    offset = offset / elementsPerByte;
+
+    // Use the named adaptor accessor for the (already converted) source
+    // memref instead of the raw ODS operand range.
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, newTy, adaptor.getSource(), offset, SmallVector<int64_t>{},
+        op.getStaticStrides());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+ auto convertedElementType = convertedType.getElementType();
+ auto oldElementType = op.getMemRefType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ auto dstIntegerType = rewriter.getIntegerType(dstBits);
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ Location loc = op.getLoc();
+ Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+ adaptor.getValue());
+
+ // Special case 0-rank memref stores. We can compute the mask at compile
+ // time.
+ if (convertedType.getRank() == 0) {
+ // Shift extended value to be left aligned
+ auto shiftValAttr =
+ rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+ Value shiftVal =
+ rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr)
+ .getResult();
+ Value alignedVal =
+ rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal)
+ .getResult();
+ // Create mask to clear destination bits
+ auto writeMaskValAttr = rewriter.getIntegerAttr(
+ dstIntegerType, (1 << (dstBits - srcBits)) - 1);
+ Value writeMask =
+ rewriter
+ .create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr)
+ .getResult();
+
+ // Clear destination bits
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+ writeMask, adaptor.getMemref(),
+ ValueRange{});
+ // Write srcs bits to destination
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+ alignedVal, adaptor.getMemref(),
+ ValueRange{});
+ rewriter.eraseOp(op);
----------------
hanhanW wrote:
Can we use `replaceOp` instead? That's more common in pattern-rewrite.
https://github.com/llvm/llvm-project/pull/72181
More information about the Mlir-commits
mailing list