[Mlir-commits] [mlir] [mlir] Add narrow type emulation conversions (PR #72181)
Han-Chung Wang
llvmlistbot at llvm.org
Wed Nov 15 14:54:56 PST 2023
================
@@ -211,6 +257,150 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Rewrites a `memref.reinterpret_cast` that produces a narrow-element memref
+/// into one that produces the memref type chosen by the type converter.
+///
+/// Support is currently very limited: only rank-0 results with a static
+/// offset are handled. The offset is rescaled from units of the original
+/// element type to units of the converted element type, so it must be a
+/// multiple of `dstBits / srcBits`.
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // `convertType` yields a null type when conversion fails, and a plain
+    // `dyn_cast` asserts on a null input; use the null-tolerant variant so the
+    // failure is reported as a match failure instead.
+    auto newTy = dyn_cast_or_null<MemRefType>(
+        getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    Type convertedElementType = newTy.getElementType();
+    Type oldElementType = op.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    // Only the 0-D reinterpret_cast is supported.
+    if (op.getType().getRank() != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "reinterpret_cast with rank > 0 is not supported");
+    }
+
+    // Only a static offset is supported.
+    int64_t offset = op.getStaticOffset(0);
+    if (offset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op, "reinterpret_cast with dynamic offset is not supported");
+    }
+
+    // Number of original elements packed into one converted element. The
+    // offset has to land on a converted-element boundary to be representable.
+    int elementsPerByte = dstBits / srcBits;
+    if (offset % elementsPerByte != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "reinterpret_cast with offset not multiple of elementsPerByte is "
+              "not supported");
+    }
+
+    // Express the offset in units of the converted element type.
+    offset = offset / elementsPerByte;
+
+    // The result is rank-0, so the size list is empty; the strides are
+    // forwarded unchanged from the original op.
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, newTy, adaptor.getSource(), offset, SmallVector<int64_t>{},
+        op.getStaticStrides());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+ auto convertedElementType = convertedType.getElementType();
+ auto oldElementType = op.getMemRefType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ auto dstIntegerType = rewriter.getIntegerType(dstBits);
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ Location loc = op.getLoc();
+ Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+ adaptor.getValue());
+
+ // Special case 0-rank memref stores. We can compute the mask at compile
+ // time.
+ if (convertedType.getRank() == 0) {
----------------
hanhanW wrote:
The code is becoming larger than I expected. Let's create two static functions: one for the 0-D case, and the other for the non-0-D case. What do you think?
https://github.com/llvm/llvm-project/pull/72181
More information about the Mlir-commits
mailing list