[Mlir-commits] [mlir] [mlir][MemRef] Extend `memref.subview` sub-byte type emulation support (PR #89488)
Diego Caballero
llvmlistbot at llvm.org
Mon Apr 22 07:39:11 PDT 2024
================
@@ -404,29 +387,68 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
/// Emulating narrow ints on subview have limited support, supporting only
/// static offset and size and stride of 1. Ideally, the subview should be
-/// folded away before running narrow type emulation, and this pattern would
-/// never run. This pattern is mostly used for testing pruposes.
+/// folded away before running narrow type emulation, and this pattern should
+/// only run for cases that can't be folded.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+ matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefType newTy =
- dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ MemRefType newTy = dyn_cast<MemRefType>(
+ getTypeConverter()->convertType(subViewOp.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
- op->getLoc(),
- llvm::formatv("failed to convert memref type: {0}", op.getType()));
+ subViewOp->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}",
+ subViewOp.getType()));
}
- // Only support offset for 1-D subview.
- if (op.getType().getRank() != 1) {
+ Location loc = subViewOp.getLoc();
+ Type convertedElementType = newTy.getElementType();
+ Type oldElementType = subViewOp.getType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0)
return rewriter.notifyMatchFailure(
- op->getLoc(), "subview with rank > 1 is not supported");
+ subViewOp, "only dstBits % srcBits == 0 supported");
+
+ // Only support stride of 1.
+ if (llvm::any_of(subViewOp.getStaticStrides(),
+ [](int64_t stride) { return stride != 1; })) {
+ return rewriter.notifyMatchFailure(subViewOp->getLoc(),
+ "stride != 1 is not supported");
+ }
+
+ auto sizes = subViewOp.getStaticSizes();
+ int64_t lastOffset = subViewOp.getStaticOffsets().back();
+ // Only support static sizes and offsets.
+ if (llvm::any_of(
+ sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+ lastOffset == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(
+ subViewOp->getLoc(), "dynamic size or offset is not supported");
}
- return convertCastingOp(rewriter, adaptor, op, newTy);
+ // Transform the offsets, sizes and strides according to the emulation.
+ auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, subViewOp.getViewSource());
+
+ OpFoldResult linearizedIndices;
+ auto strides = stridedMetadata.getConstifiedMixedStrides();
+ memref::LinearizedMemRefInfo linearizedInfo;
+ std::tie(linearizedInfo, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ rewriter, loc, srcBits, dstBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ subViewOp.getMixedSizes(), strides,
+ getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
+ rewriter));
+
+ rewriter.replaceOpWithNewOp<memref::SubViewOp>(
----------------
dcaballe wrote:
Thanks! Lines 417-421 check that the subview has only unit strides. Is there anything else needed? I'm using `strides.back()` (L450) because all the strides are known to be one at that point, and therefore the stride of the new subview is also one.
https://github.com/llvm/llvm-project/pull/89488
More information about the Mlir-commits mailing list