[Mlir-commits] [mlir] [mlir][MemRef] Extend `memref.subview` sub-byte type emulation support (PR #89488)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 22 09:09:05 PDT 2024

@@ -404,29 +387,68 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
 /// Emulating narrow ints on subview have limited support, supporting only
 /// static offset and size and stride of 1. Ideally, the subview should be
-/// folded away before running narrow type emulation, and this pattern would
-/// never run. This pattern is mostly used for testing pruposes.
+/// folded away before running narrow type emulation, and this pattern should
+/// only run for cases that can't be folded.
 struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   using OpConversionPattern::OpConversionPattern;
-  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefType newTy =
-        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    MemRefType newTy = dyn_cast<MemRefType>(
+        getTypeConverter()->convertType(subViewOp.getType()));
     if (!newTy) {
       return rewriter.notifyMatchFailure(
-          op->getLoc(),
-          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+          subViewOp->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}",
+                        subViewOp.getType()));
-    // Only support offset for 1-D subview.
-    if (op.getType().getRank() != 1) {
+    Location loc = subViewOp.getLoc();
+    Type convertedElementType = newTy.getElementType();
+    Type oldElementType = subViewOp.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0)
       return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with rank > 1 is not supported");
+          subViewOp, "only dstBits % srcBits == 0 supported");
+    // Only support stride of 1.
+    if (llvm::any_of(subViewOp.getStaticStrides(),
+                     [](int64_t stride) { return stride != 1; })) {
+      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
+                                         "stride != 1 is not supported");
+    }
+    auto sizes = subViewOp.getStaticSizes();
+    int64_t lastOffset = subViewOp.getStaticOffsets().back();
+    // Only support static sizes and offsets.
+    if (llvm::any_of(
+            sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+        lastOffset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          subViewOp->getLoc(), "dynamic size or offset is not supported");
-    return convertCastingOp(rewriter, adaptor, op, newTy);
+    // Transform the offsets, sizes and strides according to the emulation.
+    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, subViewOp.getViewSource());
+    OpFoldResult linearizedIndices;
+    auto strides = stridedMetadata.getConstifiedMixedStrides();
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            subViewOp.getMixedSizes(), strides,
+            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
+                           rewriter));
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
MaheshRavishankar wrote:

I don't think just checking the strides of the subview is enough... you have to check the strides of the memref type of the result. Those strides need to be contiguous?


More information about the Mlir-commits mailing list