[Mlir-commits] [mlir] [mlir][MemRef] Extend `memref.subview` sub-byte type emulation support (PR #89488)

Diego Caballero llvmlistbot at llvm.org
Sat Apr 20 06:39:57 PDT 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/89488

From 38fe8a668d0cb86c50c2aedfac1d0f4ba53aac91 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Sat, 20 Apr 2024 04:23:01 +0000
Subject: [PATCH 1/2] [mlir][MemRef] Extend `memref.subview` sub-byte type
 emulation support

In some cases (see https://github.com/iree-org/iree/issues/16285), `memref.subview`
ops can't be folded into transfer ops, so sub-byte type emulation fails. This issue
has been blocking a few things, including the enablement of vector flattening
transformations (https://github.com/iree-org/iree/pull/16456). This PR extends the
existing sub-byte type emulation support for `memref.subview` to handle
multi-dimensional subviews with dynamic offsets, addressing some of the
`memref.subview` cases that can't be folded.
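
As an example, a multi-dimensional subview with a dynamic offset such as the
following (taken from the new test added in this PR):

  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 16] [1, 1, 1, 1]
      : memref<512x64x8x16xi4>
      to memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>

is now rewritten into a rank-1 subview of the i8 backing memref with a
linearized dynamic offset. The following is a sketch of the emulated form
(SSA names are illustrative; see the CHECK lines in the test for the exact
output):

  %subview = memref.subview %alloc[%idx_lin] [65536] [1]
      : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>>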
---
 .../mlir/Dialect/Vector/IR/VectorOps.h        |   4 +
 .../MemRef/Transforms/EmulateNarrowType.cpp   | 162 ++++++++++--------
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp |   4 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |   7 +
 .../Dialect/MemRef/emulate-narrow-type.mlir   |  23 +++
 5 files changed, 127 insertions(+), 73 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 4603953cb40fa5..ddeda1fe1c692c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -145,6 +145,10 @@ inline bool isReductionIterator(Attribute attr) {
 /// constant operations.
 SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values);
 
+/// Returns the integer number in `foldResult`. `foldResult` is expected to
+/// be a constant integer attribute.
+int64_t getAsInteger(OpFoldResult foldResult);
+
 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
 /// be constant operations.
 SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 8236a4c475f17c..8325da357f34c6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
 #include <cassert>
 #include <type_traits>
 
@@ -34,62 +32,6 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
-/// type. The result MemRefType of the old op must have a rank and stride of 1,
-/// with static offset and size. The number of bits in the offset must evenly
-/// divide the bitwidth of the new converted type.
-template <typename MemRefOpTy>
-static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
-                                      typename MemRefOpTy::Adaptor adaptor,
-                                      MemRefOpTy op, MemRefType newTy) {
-  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
-                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
-                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
-
-  auto convertedElementType = newTy.getElementType();
-  auto oldElementType = op.getType().getElementType();
-  int srcBits = oldElementType.getIntOrFloatBitWidth();
-  int dstBits = convertedElementType.getIntOrFloatBitWidth();
-  if (dstBits % srcBits != 0) {
-    return rewriter.notifyMatchFailure(op,
-                                       "only dstBits % srcBits == 0 supported");
-  }
-
-  // Only support stride of 1.
-  if (llvm::any_of(op.getStaticStrides(),
-                   [](int64_t stride) { return stride != 1; })) {
-    return rewriter.notifyMatchFailure(op->getLoc(),
-                                       "stride != 1 is not supported");
-  }
-
-  auto sizes = op.getStaticSizes();
-  int64_t offset = op.getStaticOffset(0);
-  // Only support static sizes and offsets.
-  if (llvm::any_of(sizes,
-                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
-      offset == ShapedType::kDynamic) {
-    return rewriter.notifyMatchFailure(
-        op->getLoc(), "dynamic size or offset is not supported");
-  }
-
-  int elementsPerByte = dstBits / srcBits;
-  if (offset % elementsPerByte != 0) {
-    return rewriter.notifyMatchFailure(
-        op->getLoc(), "offset not multiple of elementsPerByte is not "
-                      "supported");
-  }
-
-  SmallVector<int64_t> size;
-  if (sizes.size())
-    size.push_back(ceilDiv(sizes[0], elementsPerByte));
-  offset = offset / elementsPerByte;
-
-  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
-                                          *adaptor.getODSOperands(0).begin(),
-                                          offset, size, op.getStaticStrides());
-  return success();
-}
-
 /// When data is loaded/stored in `targetBits` granularity, but is used in
 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
 /// treated as an array of elements of width `sourceBits`.
@@ -337,7 +279,48 @@ struct ConvertMemRefReinterpretCast final
           op->getLoc(), "subview with rank > 1 is not supported");
     }
 
-    return convertCastingOp(rewriter, adaptor, op, newTy);
+    Type convertedElementType = newTy.getElementType();
+    Type oldElementType = op.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    // Only support stride of 1.
+    if (llvm::any_of(op.getStaticStrides(),
+                     [](int64_t stride) { return stride != 1; })) {
+      return rewriter.notifyMatchFailure(op->getLoc(),
+                                         "stride != 1 is not supported");
+    }
+
+    auto sizes = op.getStaticSizes();
+    int64_t offset = op.getStaticOffset(0);
+    // Only support static sizes and offsets.
+    if (llvm::any_of(
+            sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+        offset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "dynamic size or offset is not supported");
+    }
+
+    int elementsPerByte = dstBits / srcBits;
+    if (offset % elementsPerByte != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "offset not multiple of elementsPerByte is not "
+                        "supported");
+    }
+
+    SmallVector<int64_t> size;
+    if (!sizes.empty())
+      size.push_back(ceilDiv(sizes[0], elementsPerByte));
+    offset = offset / elementsPerByte;
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
+        op.getStaticStrides());
+    return success();
   }
 };
 
@@ -404,29 +387,68 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
 
-/// Emulating narrow ints on subview have limited support, supporting only
-/// static offset and size and stride of 1. Ideally, the subview should be
-/// folded away before running narrow type emulation, and this pattern would
-/// never run. This pattern is mostly used for testing pruposes.
+/// Emulating narrow ints on subviews has limited support: only static sizes,
+/// a static innermost offset, and strides of 1 are supported. Ideally, the
+/// subview should be folded away before running narrow type emulation, and
+/// this pattern should only run for cases that can't be folded.
 struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefType newTy =
-        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    MemRefType newTy = dyn_cast<MemRefType>(
+        getTypeConverter()->convertType(subViewOp.getType()));
     if (!newTy) {
       return rewriter.notifyMatchFailure(
-          op->getLoc(),
-          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+          subViewOp->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}",
+                        subViewOp.getType()));
     }
 
-    // Only support offset for 1-D subview.
-    if (op.getType().getRank() != 1) {
+    Location loc = subViewOp.getLoc();
+    Type convertedElementType = newTy.getElementType();
+    Type oldElementType = subViewOp.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0)
       return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with rank > 1 is not supported");
+          subViewOp, "only dstBits % srcBits == 0 supported");
+
+    // Only support stride of 1.
+    if (llvm::any_of(subViewOp.getStaticStrides(),
+                     [](int64_t stride) { return stride != 1; })) {
+      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
+                                         "stride != 1 is not supported");
+    }
+
+    auto sizes = subViewOp.getStaticSizes();
+    int64_t lastOffset = subViewOp.getStaticOffsets().back();
+    // Only support static sizes and offsets.
+    if (llvm::any_of(
+            sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+        lastOffset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          subViewOp->getLoc(), "dynamic size or offset is not supported");
     }
 
-    return convertCastingOp(rewriter, adaptor, op, newTy);
+    // Transform the offsets, sizes and strides according to the emulation.
+    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, subViewOp.getViewSource());
+
+    OpFoldResult linearizedIndices;
+    auto strides = stridedMetadata.getConstifiedMixedStrides();
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            subViewOp.getMixedSizes(), strides,
+            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
+                           rewriter));
+
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
+        linearizedInfo.linearizedSize, strides.back());
+    return success();
   }
 };
 
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 556a82de2166f7..588612b55fbd41 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -67,7 +67,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
   AffineExpr mulMap = builder.getAffineConstantExpr(1);
 
   SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
-  SmallVector<OpFoldResult> sizeValues(sourceRank);
 
   for (unsigned i = 0; i < sourceRank; ++i) {
     unsigned offsetIdx = 2 * i;
@@ -78,8 +77,7 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     mulMap = mulMap * symbols[i];
   }
 
-  // Adjust linearizedIndices, size and offset by the scale factor (dstBits /
-  // srcBits).
+  // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
   addMulMap = addMulMap.floorDiv(scaler);
   mulMap = mulMap.floorDiv(scaler);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e6425879cc67f..4f62dffa48a935 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -291,6 +291,13 @@ SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
   return ints;
 }
 
+/// Returns the integer number in `foldResult`. `foldResult` is expected to
+/// be a constant integer attribute.
+int64_t vector::getAsInteger(OpFoldResult foldResult) {
+  assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
+  return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
+}
+
 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
 /// be constant operations.
 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index fd37b7ff0a2713..7e4002b8fff549 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -177,6 +177,29 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
 
 // -----
 
+func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
+  %c0 = arith.constant 0 : index
+  %arr = memref.alloc() : memref<512x64x8x16xi4>
+  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 16] [1, 1, 1, 1] : memref<512x64x8x16xi4>
+                                                                            to memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
+  %ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
+  return %ld : i4
+}
+
+// CHECK-LABEL:   func.func @memref_subview_dynamic_offset_i4(
+// CHECK:           %[[ALLOC:.*]] = memref.alloc() : memref<2097152xi8>
+// CHECK:           %[[IDX:.*]] = affine.apply
+// CHECK:           %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [65536] [1] : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>>
+// CHECK:           memref.load %[[SUBVIEW]]
+
+// CHECK32-LABEL:   func.func @memref_subview_dynamic_offset_i4(
+// CHECK32:           %[[ALLOC:.*]] = memref.alloc() : memref<524288xi32>
+// CHECK32:           %[[IDX:.*]] = affine.apply
+// CHECK32:           %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [16384] [1] : memref<524288xi32> to memref<16384xi32, strided<[1], offset: ?>>
+// CHECK32:           memref.load %[[SUBVIEW]]
+
+// -----
+
 func.func @reinterpret_cast_memref_load_0D() -> i4 {
     %0 = memref.alloc() : memref<5xi4>
     %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>

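For reference, the constants in the CHECK lines of the new test follow
directly from the i4-to-i8 emulation (two i4 elements per byte): the
512x64x8x16xi4 allocation holds 512*64*8*16 = 4194304 i4 elements, i.e. a
2097152xi8 backing buffer, and the 16x64x8x16 subview spans 131072 i4
elements, i.e. 65536 bytes. The dynamic offset is linearized the same way:
an element offset of %idx * 8192 becomes a byte offset of %idx * 4096,
which is the value the matched affine.apply computes. Under CHECK32 (i4 in
i32, eight elements per container), the same counts divide by 8 instead:
524288 and 16384 i32 elements.
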
From 5fc88bde67d1521be233c69be29f7f2593c5cac0 Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Sat, 20 Apr 2024 13:39:27 +0000
Subject: [PATCH 2/2] Remove `getAsInteger`

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.h | 4 ----
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp        | 7 -------
 2 files changed, 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index ddeda1fe1c692c..4603953cb40fa5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -145,10 +145,6 @@ inline bool isReductionIterator(Attribute attr) {
 /// constant operations.
 SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values);
 
-/// Returns the integer number in `foldResult`. `foldResult` is expected to
-/// be a constant integer attribute.
-int64_t getAsInteger(OpFoldResult foldResult);
-
 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
 /// be constant operations.
 SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4f62dffa48a935..3e6425879cc67f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -291,13 +291,6 @@ SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
   return ints;
 }
 
-/// Returns the integer number in `foldResult`. `foldResult` is expected to
-/// be a constant integer attribute.
-int64_t vector::getAsInteger(OpFoldResult foldResult) {
-  assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
-  return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
-}
-
 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
 /// be constant operations.
 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {


