[Mlir-commits] [mlir] [mlir][MemRef] Extend `memref.subview` sub-byte type emulation support (PR #89488)

Diego Caballero llvmlistbot at llvm.org
Mon Apr 22 10:04:20 PDT 2024


================
@@ -404,29 +387,68 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
 
 /// Emulating narrow ints on subview has limited support, supporting only
 /// static offset and size and stride of 1. Ideally, the subview should be
-/// folded away before running narrow type emulation, and this pattern would
-/// never run. This pattern is mostly used for testing purposes.
+/// folded away before running narrow type emulation, and this pattern should
+/// only run for cases that can't be folded.
 struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefType newTy =
-        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    MemRefType newTy = dyn_cast<MemRefType>(
+        getTypeConverter()->convertType(subViewOp.getType()));
     if (!newTy) {
       return rewriter.notifyMatchFailure(
-          op->getLoc(),
-          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+          subViewOp->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}",
+                        subViewOp.getType()));
     }
 
-    // Only support offset for 1-D subview.
-    if (op.getType().getRank() != 1) {
+    Location loc = subViewOp.getLoc();
+    Type convertedElementType = newTy.getElementType();
+    Type oldElementType = subViewOp.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0)
       return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with rank > 1 is not supported");
+          subViewOp, "only dstBits % srcBits == 0 supported");
+
+    // Only support stride of 1.
+    if (llvm::any_of(subViewOp.getStaticStrides(),
+                     [](int64_t stride) { return stride != 1; })) {
+      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
+                                         "stride != 1 is not supported");
+    }
+
+    auto sizes = subViewOp.getStaticSizes();
+    int64_t lastOffset = subViewOp.getStaticOffsets().back();
+    // Only support static sizes and offsets.
+    if (llvm::any_of(
+            sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+        lastOffset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          subViewOp->getLoc(), "dynamic size or offset is not supported");
     }
 
-    return convertCastingOp(rewriter, adaptor, op, newTy);
+    // Transform the offsets, sizes and strides according to the emulation.
+    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, subViewOp.getViewSource());
+
+    OpFoldResult linearizedIndices;
+    auto strides = stridedMetadata.getConstifiedMixedStrides();
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            subViewOp.getMixedSizes(), strides,
+            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
+                           rewriter));
+
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
----------------
dcaballe wrote:

Let me add a test to better understand what happens... 

https://github.com/llvm/llvm-project/pull/89488


More information about the Mlir-commits mailing list