[Mlir-commits] [mlir] [mlir][memref] Fix offset update in emulating narrow type for strided memref (PR #68043)
    llvmlistbot at llvm.org 
    Mon Oct  2 15:11:07 PDT 2023

llvmbot wrote:
@llvm/pr-subscribers-mlir-memref
<details>
<summary>Changes</summary>
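During narrow-type emulation, the memref type converter copied the offset of a strided source memref into the converted type unchanged, even though the converted memref has elements of `loadStoreWidth` bits rather than `width` bits. This patch rescales a static, nonzero offset to `(offset * width) / loadStoreWidth` and fails the conversion (returns `std::nullopt`) when `offset * width` is not a multiple of `loadStoreWidth`. Dynamic offsets are passed through as before.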
---
Full diff: https://github.com/llvm/llvm-project/pull/68043.diff
1 file affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+15-2) 
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 2a524ceb9db887b..2a9b27debaece3f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -271,9 +271,22 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
           return std::nullopt;
 
         StridedLayoutAttr layoutAttr;
+        // If the offset is 0, we do not need a strided layout as the stride is
+        // 1, so we only use the strided layout if the offset is not 0.
         if (offset != 0) {
-          layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
-                                              ArrayRef<int64_t>{1});
+          if (offset == ShapedType::kDynamic) {
+            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
+                                                ArrayRef<int64_t>{1});
+          } else {
+            // Check if the number of bytes are a multiple of the loadStoreWidth
+            // and if so, divide it by the loadStoreWidth to get the offset.
+            if ((offset * width) % loadStoreWidth != 0)
+              return std::nullopt;
+            offset = (offset * width) / loadStoreWidth;
+
+            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
+                                                ArrayRef<int64_t>{1});
+          }
         }
 
         return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
``````````
</details>
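
The offset arithmetic in the diff can be tried in isolation. The following is a minimal standalone sketch, not code from the patch: `scaleOffset` is a hypothetical helper name, and `kDynamicOffset` is a local stand-in for `ShapedType::kDynamic`, so the snippet compiles without MLIR. It assumes, as the diff does, that `offset` counts `width`-bit elements and that the result should count `loadStoreWidth`-bit elements.

```cpp
#include <cstdint>
#include <limits>
#include <optional>

// Local stand-in for ShapedType::kDynamic (assumption for this sketch only).
constexpr int64_t kDynamicOffset = std::numeric_limits<int64_t>::min();

// Hypothetical helper, not part of MLIR. Converts an offset counted in
// `width`-bit elements into one counted in `loadStoreWidth`-bit elements.
// Returns std::nullopt when the offset does not land on a
// `loadStoreWidth`-bit boundary, mirroring the early return in the patch.
std::optional<int64_t> scaleOffset(int64_t offset, int64_t width,
                                   int64_t loadStoreWidth) {
  // A dynamic offset is unknown at compile time, so it is passed through.
  if (offset == kDynamicOffset)
    return offset;

  int64_t bits = offset * width;
  if (bits % loadStoreWidth != 0)
    return std::nullopt;
  return bits / loadStoreWidth;
}
```

For example, with `width = 4` and `loadStoreWidth = 8`, an offset of 4 elements becomes 2, while an offset of 3 elements (12 bits) is rejected because it does not fall on an 8-bit boundary. In the patch itself, an offset of 0 never reaches this computation because no strided layout is attached in that case.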
https://github.com/llvm/llvm-project/pull/68043