[Mlir-commits] [mlir] [mlir] Add narrow type emulation conversions (PR #72181)

Han-Chung Wang llvmlistbot at llvm.org
Wed Nov 15 14:54:57 PST 2023


================
@@ -211,6 +257,150 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Currently there is very limited support for memref::ReinterpretCastOp
+/// conversion: only the 0-dimensional case is supported.
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType newTy =
+        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    auto convertedElementType = newTy.getElementType();
+    auto oldElementType = op.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
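+    // The wider emulated type must hold a whole number of narrow elements,
+    // e.g. i4 into i8 is supported but i3 into i8 is not.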
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    // Only the offset is handled, so only the 0-D case is supported.
+    if (op.getType().getRank() != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "reinterpret_cast with rank > 0 is not supported");
+    }
+
+    int64_t offset = op.getStaticOffset(0);
+    // Only support static offsets.
+    if (offset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          "reinterpret_cast with dynamic offset is not supported");
+    }
+
+    int elementsPerByte = dstBits / srcBits;
+    if (offset % elementsPerByte != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          "subview with offset not multiple of elementsPerByte is not "
+          "supported");
+    }
+
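+    // Convert the offset from units of the narrow type to units of the wide
+    // type, e.g. with i4 in i8 (elementsPerByte == 2) offset 6 becomes 3.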
+    offset = offset / elementsPerByte;
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, newTy, adaptor.getSource(), offset, SmallVector<int64_t>{},
+        op.getStaticStrides());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
+    auto convertedElementType = convertedType.getElementType();
+    auto oldElementType = op.getMemRefType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. We can compute the mask at compile
+    // time.
+    if (convertedType.getRank() == 0) {
+      // Shift the extended value so it is left-aligned in the wider type.
+      auto shiftValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+      Value shiftVal =
+          rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr)
+              .getResult();
+      Value alignedVal =
+          rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal)
+              .getResult();
+      // Create mask to clear destination bits
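+      // E.g., for srcBits = 4, dstBits = 8 the mask is 0b00001111: ANDing
+      // keeps the low bits and clears the top srcBits that the left-aligned
+      // value will occupy.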
+      auto writeMaskValAttr = rewriter.getIntegerAttr(
+          dstIntegerType, (1ull << (dstBits - srcBits)) - 1);
+      Value writeMask =
+          rewriter
+              .create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr)
+              .getResult();
+
+      // Clear destination bits
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           ValueRange{});
+      // Write src bits to destination
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           alignedVal, adaptor.getMemref(),
+                                           ValueRange{});
+      rewriter.eraseOp(op);
----------------
hanhanW wrote:

Can we use `replaceOp` instead? That's more common in pattern rewrites.

https://github.com/llvm/llvm-project/pull/72181
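
For readers following along, the masking logic in the 0-D store path above can
be mirrored in plain C++. This is a minimal sketch, not part of the patch;
`storeNarrow` is a hypothetical stand-in for the op sequence the pattern emits
(arith.shli, then the andi/ori atomic read-modify-writes):

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the emitted IR: left-align the narrow value (arith.shli), clear the
// destination slot (atomic andi with writeMask), then write it (atomic ori).
uint32_t storeNarrow(uint32_t dst, uint32_t srcVal, int srcBits, int dstBits) {
  uint32_t aligned = srcVal << (dstBits - srcBits);      // left-align value
  uint32_t writeMask = (1u << (dstBits - srcBits)) - 1;  // keep low bits only
  return (dst & writeMask) | aligned;
}

int main() {
  // Storing the i4 value 0b1010 into an i8 byte of all ones:
  // 0b11111111 -> 0b10101111.
  assert(storeNarrow(0xFF, 0b1010, /*srcBits=*/4, /*dstBits=*/8) == 0b10101111);
  return 0;
}
```

The emitted IR uses atomic RMW ops because several narrow elements share one
wider word, so a plain load-modify-store could clobber a neighboring element
written concurrently.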

