[Mlir-commits] [mlir] [mlir][MemRef] Extend `memref.subview` sub-byte type emulation support (PR #89488)

Diego Caballero llvmlistbot at llvm.org
Fri Apr 26 07:17:33 PDT 2024


================
@@ -404,29 +387,68 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
 
 /// Emulating narrow ints on subview have limited support, supporting only
 /// static offset and size and stride of 1. Ideally, the subview should be
-/// folded away before running narrow type emulation, and this pattern would
-/// never run. This pattern is mostly used for testing pruposes.
+/// folded away before running narrow type emulation, and this pattern should
+/// only run for cases that can't be folded.
 struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefType newTy =
-        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    MemRefType newTy = dyn_cast<MemRefType>(
+        getTypeConverter()->convertType(subViewOp.getType()));
     if (!newTy) {
       return rewriter.notifyMatchFailure(
-          op->getLoc(),
-          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+          subViewOp->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}",
+                        subViewOp.getType()));
     }
 
-    // Only support offset for 1-D subview.
-    if (op.getType().getRank() != 1) {
+    Location loc = subViewOp.getLoc();
+    Type convertedElementType = newTy.getElementType();
+    Type oldElementType = subViewOp.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0)
       return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with rank > 1 is not supported");
+          subViewOp, "only dstBits % srcBits == 0 supported");
+
+    // Only support stride of 1.
+    if (llvm::any_of(subViewOp.getStaticStrides(),
+                     [](int64_t stride) { return stride != 1; })) {
+      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
+                                         "stride != 1 is not supported");
+    }
+
+    auto sizes = subViewOp.getStaticSizes();
+    int64_t lastOffset = subViewOp.getStaticOffsets().back();
+    // Only support static sizes and offsets.
+    if (llvm::any_of(
+            sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+        lastOffset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          subViewOp->getLoc(), "dynamic size or offset is not supported");
     }
 
-    return convertCastingOp(rewriter, adaptor, op, newTy);
+    // Transform the offsets, sizes and strides according to the emulation.
+    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, subViewOp.getViewSource());
+
+    OpFoldResult linearizedIndices;
+    auto strides = stridedMetadata.getConstifiedMixedStrides();
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            subViewOp.getMixedSizes(), strides,
+            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
+                           rewriter));
+
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
----------------
dcaballe wrote:

Hey, I tried with a few tests like these:

```
func.func @memref_subview_dynamic_offset_i4_1(%idx : index) -> i4 {                                                                                                                                                  
  %c0 = arith.constant 0 : index                                                                                                                                                                                     
  %arr = memref.alloc() : memref<512x64x8x16xi4>                                                                                                                                                                     
  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 4] [1, 1, 1, 4] : memref<512x64x8x16xi4>                                                                                                                 
                                                                            to memref<16x64x8x4xi4, strided<[8192, 128, 16, 4], offset: ?>>                                                                          
  %ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x4xi4, strided<[8192, 128, 16, 4], offset: ?>>                                                                                                      
  return %ld : i4                                                                                                                                                                                                    
}                                                                                                                                                                                                                    
                                                                                                                                                                                                                     
func.func @memref_subview_dynamic_offset_i4_2(%idx : index) -> i4 {                                                                                                                                                  
  %c0 = arith.constant 0 : index                                                                                                                                                                                     
  %arr = memref.alloc() : memref<512x64x8x16xi4>                                                                                                                                                                     
  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 4, 8, 16] [1, 16, 1, 1] : memref<512x64x8x16xi4>                                                                                                                
                                                                            to memref<16x4x8x16xi4, strided<[8192, 2048, 16, 1], offset: ?>>                                                                         
  %ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x4x8x16xi4, strided<[8192, 2048, 16, 1], offset: ?>>                                                                                                     
  return %ld : i4                                                                                                                                                                                                    
} 
```

and all of them seem to be covered with the existing rules. Can you think of any other example?

> you have to check the strides of the memref type of the result. Those strides need to be contiguous?
 
Note that the type of the new subview, `newTy`, comes from the emulation type converter (L398), where we already verify the strides of the original memref and then linearize the shape, which also yields a memref with a single unit stride. I'm not sure what else I could check here.

https://github.com/llvm/llvm-project/pull/89488


More information about the Mlir-commits mailing list