[Mlir-commits] [mlir] [mlir][MemRef] Extend `memref.subview` sub-byte type emulation support (PR #89488)
Diego Caballero
llvmlistbot at llvm.org
Fri Apr 26 07:17:33 PDT 2024
================
@@ -404,29 +387,68 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
/// Emulating narrow ints on subview have limited support, supporting only
/// static offset and size and stride of 1. Ideally, the subview should be
-/// folded away before running narrow type emulation, and this pattern would
-/// never run. This pattern is mostly used for testing pruposes.
+/// folded away before running narrow type emulation, and this pattern should
+/// only run for cases that can't be folded.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+ matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefType newTy =
- dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ MemRefType newTy = dyn_cast<MemRefType>(
+ getTypeConverter()->convertType(subViewOp.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
- op->getLoc(),
- llvm::formatv("failed to convert memref type: {0}", op.getType()));
+ subViewOp->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}",
+ subViewOp.getType()));
}
- // Only support offset for 1-D subview.
- if (op.getType().getRank() != 1) {
+ Location loc = subViewOp.getLoc();
+ Type convertedElementType = newTy.getElementType();
+ Type oldElementType = subViewOp.getType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0)
return rewriter.notifyMatchFailure(
- op->getLoc(), "subview with rank > 1 is not supported");
+ subViewOp, "only dstBits % srcBits == 0 supported");
+
+ // Only support stride of 1.
+ if (llvm::any_of(subViewOp.getStaticStrides(),
+ [](int64_t stride) { return stride != 1; })) {
+ return rewriter.notifyMatchFailure(subViewOp->getLoc(),
+ "stride != 1 is not supported");
+ }
+
+ auto sizes = subViewOp.getStaticSizes();
+ int64_t lastOffset = subViewOp.getStaticOffsets().back();
+ // Only support static sizes and offsets.
+ if (llvm::any_of(
+ sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+ lastOffset == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(
+ subViewOp->getLoc(), "dynamic size or offset is not supported");
}
- return convertCastingOp(rewriter, adaptor, op, newTy);
+ // Transform the offsets, sizes and strides according to the emulation.
+ auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, subViewOp.getViewSource());
+
+ OpFoldResult linearizedIndices;
+ auto strides = stridedMetadata.getConstifiedMixedStrides();
+ memref::LinearizedMemRefInfo linearizedInfo;
+ std::tie(linearizedInfo, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, dstBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ subViewOp.getMixedSizes(), strides,
+ getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
+ rewriter));
+
+ rewriter.replaceOpWithNewOp<memref::SubViewOp>(
----------------
dcaballe wrote:
Hey, I tried with a few tests like these:
```
// Subview of an identity-layout i4 memref with a dynamic offset (%idx) on the
// outermost dim and a non-unit step (4) on the innermost dim, so the result
// layout carries stride 4 on that dim. Then loads element [0, 0, 0, 0] of the view.
func.func @memref_subview_dynamic_offset_i4_1(%idx : index) -> i4 {
%c0 = arith.constant 0 : index
%arr = memref.alloc() : memref<512x64x8x16xi4>
%subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 4] [1, 1, 1, 4] : memref<512x64x8x16xi4>
to memref<16x64x8x4xi4, strided<[8192, 128, 16, 4], offset: ?>>
%ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x4xi4, strided<[8192, 128, 16, 4], offset: ?>>
return %ld : i4
}
// Same source memref and dynamic outer offset, but with a non-unit step (16)
// on dim 1 instead. That dim's result stride becomes 2048 (16 * 128), while
// the innermost stride stays 1. Then loads element [0, 0, 0, 0] of the view.
func.func @memref_subview_dynamic_offset_i4_2(%idx : index) -> i4 {
%c0 = arith.constant 0 : index
%arr = memref.alloc() : memref<512x64x8x16xi4>
%subview = memref.subview %arr[%idx, 0, 0, 0] [16, 4, 8, 16] [1, 16, 1, 1] : memref<512x64x8x16xi4>
to memref<16x4x8x16xi4, strided<[8192, 2048, 16, 1], offset: ?>>
%ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x4x8x16xi4, strided<[8192, 2048, 16, 1], offset: ?>>
return %ld : i4
}
```
and all of them seem to be covered by the existing rules. Can you think of any other examples?
> you have to check the strides of the memref type of the result. Those strides need to be contiguous?
Note that the type of the new subview, `newTy`, comes from the emulation converter (L398). The converter checks the strides of the original memref and then linearizes the shape, which produces a memref with a single unit stride. I'm not sure what else I could check.
https://github.com/llvm/llvm-project/pull/89488
More information about the Mlir-commits
mailing list