[Mlir-commits] [mlir] [MLIR] Make 1-D memref flattening a prerequisite for vector narrow type emulation (PR #157771)

Alan Li llvmlistbot at llvm.org
Tue Sep 9 19:01:13 PDT 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/157771

>From 686a25b9e2da02f2d89c18305f4ecadf011ed731 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 9 Sep 2025 19:24:24 -0400
Subject: [PATCH 1/2] First updates.

---
 .../Dialect/MemRef/Transforms/Transforms.h    |  4 ++
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 21 +++++--
 .../Transforms/VectorEmulateNarrowType.cpp    | 59 ++++++++++++-------
 3 files changed, 58 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 33e3d94f02b1c..e7751df724f9c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -145,6 +145,10 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
 /// ```
 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
+/// Patterns for flattening multi-dimensional memref operations into
+/// one-dimensional memref operations.
+void populateFlattenVectorMemRefPatterns(RewritePatternSet &patterns);
+void populateFlattenMemRefOpsPatterns(RewritePatternSet &patterns);
 void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
 
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 42be847811d52..d658d147a0a3a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -271,12 +271,8 @@ struct FlattenMemrefsPass
 
 } // namespace
 
-void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
-  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
-                  MemRefRewritePattern<memref::StoreOp>,
-                  MemRefRewritePattern<memref::AllocOp>,
-                  MemRefRewritePattern<memref::AllocaOp>,
-                  MemRefRewritePattern<vector::LoadOp>,
+void memref::populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns) {
+  patterns.insert<MemRefRewritePattern<vector::LoadOp>,
                   MemRefRewritePattern<vector::StoreOp>,
                   MemRefRewritePattern<vector::TransferReadOp>,
                   MemRefRewritePattern<vector::TransferWriteOp>,
@@ -284,3 +280,16 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
                   MemRefRewritePattern<vector::MaskedStoreOp>>(
       patterns.getContext());
 }
+
+void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
+  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+                  MemRefRewritePattern<memref::StoreOp>,
+                  MemRefRewritePattern<memref::AllocOp>,
+                  MemRefRewritePattern<memref::AllocaOp>>(
+      patterns.getContext());
+}
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+  populateFlattenMemrefOpsPatterns(patterns);
+  populateFlattenVectorMemrefPatterns(patterns);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f78e579d6c099..2ea17dbe2f53e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -556,7 +556,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
     if (op.getValueToStore().getType().getRank() != 1)
       return rewriter.notifyMatchFailure(op,
                                          "only 1-D vectors are supported ATM");
@@ -817,7 +816,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.maskedstore` operations on narrow element types to work
+/// with wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
+/// and storing with an adjusted mask, since each `i8` container element holds
+/// two `i4` values.
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -826,10 +831,10 @@ struct ConvertVectorMaskedStore final
   matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
+    // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
     if (op.getValueToStore().getType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in vector.maskedstore op must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -931,18 +936,27 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.load` on narrow element types to work with
+/// wider, byte-aligned container types by adjusting load sizes and using
+/// bitcasting.
+///
+/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
+/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` and bitcasting
+/// back, since each `i8` container holds two `i4` values.
+///
+/// There are cases where the number of elements to load is not byte-aligned. In
+/// those cases, loads are converted to byte-aligned, byte-sized loads and the
+/// target vector is extracted from the loaded vector.
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    // See #115653
+    // Prerequisite: memref in the vector.load op is flattened into 1-D.
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -961,8 +975,6 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
-    // Here only the 1-D vector load is considered, and the N-D memref types
-    // should be linearized.
     // For example, to emulate i4 to i8, the following op:
     //
     // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
@@ -1037,7 +1049,12 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.maskedload` operations on narrow element types to work with
+/// wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
+/// bitcasting, since each `i8` container element holds two `i4` values.
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1045,10 +1062,9 @@ struct ConvertVectorMaskedLoad final
   LogicalResult
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // See #115653
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
 
@@ -1229,7 +1245,6 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
 
   int elemsPerMultiByte = multiByteBits / subByteBits;
 
-  // TODO: This is a bit too restrictive for vectors rank > 1.
   return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
 }
 
@@ -1246,10 +1261,11 @@ struct ConvertVectorTransferRead final
   matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
+    // Prerequisite: memref in the vector.transfer_read op is flattened into
+    // 1-D.
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -2228,6 +2244,9 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns, bool disableAtomicRMW) {
 
+  // As a prerequisite, make sure memrefs in vector ops are linearized.
+  memref::populateFlattenVectorMemrefPatterns(patterns);
+
   // Populate `vector.*` conversion patterns.
   // TODO: #119553 support atomicity
   patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,

>From b061ab9c0ab90bebd83e219e936237bd751a54d2 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 9 Sep 2025 22:00:52 -0400
Subject: [PATCH 2/2] Forgot to push

---
 mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h      | 4 ++--
 .../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 2 ++
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index e7751df724f9c..6b1cd02fec5b4 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -147,8 +147,8 @@ void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
 /// Patterns for flattening multi-dimensional memref operations into
 /// one-dimensional memref operations.
-void populateFlattenVectorMemRefPatterns(RewritePatternSet &patterns);
-void populateFlattenMemRefOpsPatterns(RewritePatternSet &patterns);
+void populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns);
+void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
 void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
 
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 2ea17dbe2f53e..61b2df68c0f7c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -38,6 +38,8 @@
 #include <cstdint>
 #include <optional>
 
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+
 using namespace mlir;
 
 #define DEBUG_TYPE "vector-narrow-type-emulation"



More information about the Mlir-commits mailing list