[Mlir-commits] [mlir] [mlir][memref] Fix emulate narrow types for strided memref offset (PR #68181)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 3 22:58:58 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Changes

This patch fixes the offset calculation for strided memrefs when emulating narrow types.

As a side effect, this patch also adds support for 1-D subviews with static sizes, static offsets, and a stride of 1, which is needed for testing: the emulate narrow types pass had no test coverage for strided memrefs before this patch.
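
For illustration, here is a minimal before/after sketch for an i4 source emulated with an i8 load/store type (elementsPerByte = 2); the subview's offset and size are simply divided by elementsPerByte, mirroring the new test case in the diff:

```mlir
// Before emulation: two i4 elements are packed into each byte.
%subview = memref.subview %arr[32] [32] [1]
    : memref<128xi4> to memref<32xi4, strided<[1], offset: 32>>

// After emulation with an i8 backing type: offset 32/2 = 16, size 32/2 = 16.
%subview = memref.subview %alloc[16] [16] [1]
    : memref<64xi8> to memref<16xi8, strided<[1], offset: 16>>
```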

---
Full diff: https://github.com/llvm/llvm-project/pull/68181.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+86-5) 
- (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+19) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 2a524ceb9db887b..453a18ff3c254e6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -209,6 +209,74 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
     return success();
   }
 };
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefSubview
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType =
+        cast<MemRefType>(getTypeConverter()->convertType(op.getSourceType()));
+    auto convertedElementType = convertedType.getElementType();
+    auto oldElementType = op.getSourceType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    MemRefType newTy =
+        cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    // Only 1-D subviews are supported.
+    if (op.getType().getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "subview with rank > 1 is not supported");
+    }
+
+    // Only support stride of 1.
+    if (op.getStaticStride(0) != 1) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "subview with stride != 1 is not supported");
+    }
+
+    auto size = op.getStaticSize(0);
+    auto offset = op.getStaticOffset(0);
+    // Only support static sizes and offsets.
+    if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "subview with dynamic size or offset is not supported");
+    }
+
+    int elementsPerByte = dstBits / srcBits;
+    if (size % elementsPerByte != 0 || offset % elementsPerByte != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          "subview with size or offset not multiple of elementsPerByte is not "
+          "supported");
+    }
+
+    size = size / elementsPerByte;
+    offset = offset / elementsPerByte;
+
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+        op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
+        op.getStaticStrides());
+    return success();
+  }
+};
+
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -220,9 +288,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `memref.*` conversion patterns.
-  patterns
-      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
-          typeConverter, patterns.getContext());
+  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+      typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
 
@@ -271,9 +339,22 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
           return std::nullopt;
 
         StridedLayoutAttr layoutAttr;
+        // A strided layout is only needed when the offset is non-zero; the
+        // stride is always 1.
         if (offset != 0) {
-          layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
-                                              ArrayRef<int64_t>{1});
+          if (offset == ShapedType::kDynamic) {
+            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
+                                                ArrayRef<int64_t>{1});
+          } else {
+            // The offset in bits (offset * width) must be a multiple of the
+            // load/store width; if so, convert it to load/store units.
+            if ((offset * width) % loadStoreWidth != 0)
+              return std::nullopt;
+            offset = (offset * width) / loadStoreWidth;
+
+            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
+                                                ArrayRef<int64_t>{1});
+          }
         }
 
         return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index c0050d8c510d53f..6ed97f05aa7cff2 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -155,3 +155,22 @@ func.func @rank_zero_memref() -> i4 {
 //       CHECK32:   %[[LOAD:.+]] = memref.load %[[ALLOC]][] : memref<i32>
 //       CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
 //       CHECK32:   return %[[TRUNC]]
+
+// -----
+
+func.func @memref_strided_i4(%idx : index) -> i4 {
+  %arr = memref.alloc() : memref<128xi4>
+  %subview = memref.subview %arr[32] [32] [1] : memref<128xi4> to memref<32xi4, strided<[1], offset:32>>
+  %1 = memref.load %subview[%idx] : memref<32xi4, strided<[1], offset:32>>
+  return %1 : i4
+}
+
+// CHECK-LABEL: func @memref_strided_i4
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<64xi8>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16] [16] [1] : memref<64xi8> to memref<16xi8, strided<[1], offset: 16>>
+//       CHECK:   %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// CHECK32-LABEL: func @memref_strided_i4
+//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
+//       CHECK32:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
+//       CHECK32:   %[[LOAD:.+]] = memref.load %[[SUBVIEW]]

``````````



https://github.com/llvm/llvm-project/pull/68181

