[Mlir-commits] [mlir] [MLIR] Make 1-D memref flattening a prerequisite for vector narrow type emulation (PR #157771)
    Alan Li 
    llvmlistbot at llvm.org
       
    Fri Sep 12 12:05:53 PDT 2025
    
    
  
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/157771
>From 686a25b9e2da02f2d89c18305f4ecadf011ed731 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 9 Sep 2025 19:24:24 -0400
Subject: [PATCH 1/4] First updates.
---
 .../Dialect/MemRef/Transforms/Transforms.h    |  4 ++
 .../MemRef/Transforms/FlattenMemRefs.cpp      | 21 +++++--
 .../Transforms/VectorEmulateNarrowType.cpp    | 59 ++++++++++++-------
 3 files changed, 58 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 33e3d94f02b1c..e7751df724f9c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -145,6 +145,10 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
 /// ```
 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
+/// Patterns for flattening multi-dimensional memref operations into
+/// one-dimensional memref operations.
+void populateFlattenVectorMemRefPatterns(RewritePatternSet &patterns);
+void populateFlattenMemRefOpsPatterns(RewritePatternSet &patterns);
 void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
 
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 42be847811d52..d658d147a0a3a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -271,12 +271,8 @@ struct FlattenMemrefsPass
 
 } // namespace
 
-void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
-  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
-                  MemRefRewritePattern<memref::StoreOp>,
-                  MemRefRewritePattern<memref::AllocOp>,
-                  MemRefRewritePattern<memref::AllocaOp>,
-                  MemRefRewritePattern<vector::LoadOp>,
+void memref::populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns) {
+  patterns.insert<MemRefRewritePattern<vector::LoadOp>,
                   MemRefRewritePattern<vector::StoreOp>,
                   MemRefRewritePattern<vector::TransferReadOp>,
                   MemRefRewritePattern<vector::TransferWriteOp>,
@@ -284,3 +280,16 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
                   MemRefRewritePattern<vector::MaskedStoreOp>>(
       patterns.getContext());
 }
+
+void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
+  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+                  MemRefRewritePattern<memref::StoreOp>,
+                  MemRefRewritePattern<memref::AllocOp>,
+                  MemRefRewritePattern<memref::AllocaOp>>(
+      patterns.getContext());
+}
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+  populateFlattenMemrefOpsPatterns(patterns);
+  populateFlattenVectorMemrefPatterns(patterns);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f78e579d6c099..2ea17dbe2f53e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -556,7 +556,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
     if (op.getValueToStore().getType().getRank() != 1)
       return rewriter.notifyMatchFailure(op,
                                          "only 1-D vectors are supported ATM");
@@ -817,7 +816,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.maskedstore` operations on narrow element types to work
+/// with wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
+/// and storing with an adjusted mask, since each `i8` container element holds
+/// two `i4` values.
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -826,10 +831,10 @@ struct ConvertVectorMaskedStore final
   matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
+    // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
     if (op.getValueToStore().getType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in vector.maskedstore op must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -931,18 +936,27 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.load` on narrow element types to work with
+/// wider, byte-aligned container types by adjusting load sizes and using
+/// bitcasting.
+///
+/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
+/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` and bitcasting
+/// back, since each `i8` container holds two `i4` values.
+///
+/// There are cases where the number of elements to load is not byte-aligned. In
+/// those cases, loads are converted to byte-aligned, byte-sized loads and the
+/// target vector is extracted from the loaded vector.
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    // See #115653
+    // Prerequisites:  memref in the vector.load op is flattened into 1-D.
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -961,8 +975,6 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
-    // Here only the 1-D vector load is considered, and the N-D memref types
-    // should be linearized.
     // For example, to emulate i4 to i8, the following op:
     //
     // %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
@@ -1037,7 +1049,12 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+/// Converts `vector.maskedload` operations on narrow element types to work with
+/// wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
+/// bitcasting, since each `i8` container element holds two `i4` values.
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1045,10 +1062,9 @@ struct ConvertVectorMaskedLoad final
   LogicalResult
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // See #115653
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
 
@@ -1229,7 +1245,6 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
 
   int elemsPerMultiByte = multiByteBits / subByteBits;
 
-  // TODO: This is a bit too restrictive for vectors rank > 1.
   return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
 }
 
@@ -1246,10 +1261,11 @@ struct ConvertVectorTransferRead final
   matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    // See #115653
+    // Prerequisite: memref in the vector.transfer_read op is flattened into
+    // 1-D.
     if (op.getVectorType().getRank() != 1)
-      return rewriter.notifyMatchFailure(op,
-                                         "only 1-D vectors are supported ATM");
+      return rewriter.notifyMatchFailure(
+          op, "Memref in emulated vector ops must be flattened beforehand.");
 
     auto loc = op.getLoc();
     auto containerElemTy =
@@ -2228,6 +2244,9 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns, bool disableAtomicRMW) {
 
+  // As a prerequisite, make sure memrefs in vector ops are linearized.
+  memref::populateFlattenVectorMemrefPatterns(patterns);
+
   // Populate `vector.*` conversion patterns.
   // TODO: #119553 support atomicity
   patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
>From b061ab9c0ab90bebd83e219e936237bd751a54d2 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 9 Sep 2025 22:00:52 -0400
Subject: [PATCH 2/4] Forgot to push
---
 mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h      | 4 ++--
 .../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 2 ++
 2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index e7751df724f9c..6b1cd02fec5b4 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -147,8 +147,8 @@ void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
 /// Patterns for flattening multi-dimensional memref operations into
 /// one-dimensional memref operations.
-void populateFlattenVectorMemRefPatterns(RewritePatternSet &patterns);
-void populateFlattenMemRefOpsPatterns(RewritePatternSet &patterns);
+void populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns);
+void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
 void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
 
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 2ea17dbe2f53e..61b2df68c0f7c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -38,6 +38,8 @@
 #include <cstdint>
 #include <optional>
 
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+
 using namespace mlir;
 
 #define DEBUG_TYPE "vector-narrow-type-emulation"
>From f523d84b758647657fb1157d274204834d0a8b6c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 10 Sep 2025 14:05:28 -0400
Subject: [PATCH 3/4] Adding new tests
---
 .../Vector/Transforms/VectorRewritePatterns.h |  6 ++
 .../Transforms/VectorEmulateNarrowType.cpp    | 13 ++--
 ...atten-memref-and-emulate-narrow-types.mlir | 38 +++++++++++
 .../Dialect/MemRef/TestEmulateNarrowType.cpp  | 63 +++++++++++++++++++
 4 files changed, 115 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 0138f477cadea..66a14bc23a5ed 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -383,6 +383,12 @@ void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns, bool disableAtomicRMW = false);
 
+/// Populates patterns for both memref flattening and vector narrow type
+/// emulation.
+void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns);
+
 /// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
 /// vector operations comprising `shuffle` and `bitwise` ops.
 /// Warning: these patterns currently only work for little endian targets.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 61b2df68c0f7c..09a369f297ac0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -955,7 +955,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   LogicalResult
   matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Prerequisites:  memref in the vector.load op is flattened into 1-D.
+    // Prerequisite: memref in the vector.load op is flattened into 1-D.
     if (op.getVectorType().getRank() != 1)
       return rewriter.notifyMatchFailure(
           op, "Memref in emulated vector ops must be flattened beforehand.");
@@ -2245,10 +2245,6 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns, bool disableAtomicRMW) {
-
-  // As a prerequisite, make sure memrefs in vector ops are linearized.
-  memref::populateFlattenVectorMemrefPatterns(patterns);
-
   // Populate `vector.*` conversion patterns.
   // TODO: #119553 support atomicity
   patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
@@ -2287,3 +2283,10 @@ void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
 }
+
+void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  memref::populateFlattenVectorMemrefPatterns(patterns);
+  vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+}
diff --git a/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
new file mode 100644
index 0000000000000..a0a038c728f59
--- /dev/null
+++ b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt --test-memref-flatten-and-vector-narrow-type-emulation --split-input-file %s | FileCheck %s
+
+// -----
+
+func.func @vector_load_2d_i4(%arg0: index, %arg1: index) -> vector<8xi4> {
+    %0 = memref.alloc() : memref<4x8xi4>
+    %1 = vector.load %0[%arg0, %arg1] : memref<4x8xi4>, vector<8xi4>
+    return %1 : vector<8xi4>
+}
+//      CHECK: func @vector_load_2d_i4
+//      CHECK:   vector.load
+// CHECK-SAME:   memref<16xi8>
+
+// -----
+
+func.func @vector_maskedload_2d_i4(%arg0: index, %arg1: index, %passthru: vector<8xi4>) -> vector<8xi4> {
+    %0 = memref.alloc() : memref<4x8xi4>
+    %mask = vector.constant_mask [6] : vector<8xi1>
+    %1 = vector.maskedload %0[%arg0, %arg1], %mask, %passthru :
+      memref<4x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+    return %1 : vector<8xi4>
+}
+//      CHECK: func @vector_maskedload_2d_i4(
+//      CHECK:   vector.maskedload
+// CHECK-SAME:   memref<16xi8>
+
+// -----
+
+func.func @vector_maskedstore_2d_i4(%arg0: index, %arg1: index, %value: vector<8xi4>) {
+    %0 = memref.alloc() : memref<4x8xi4>
+    %mask = vector.constant_mask [5] : vector<8xi1>
+    vector.maskedstore %0[%arg0, %arg1], %mask, %value :
+      memref<4x8xi4>, vector<8xi1>, vector<8xi4>
+    return
+}
+//      CHECK: func @vector_maskedstore_2d_i4(
+//      CHECK:   vector.maskedstore
+// CHECK-SAME:   memref<16xi8>
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index ba2ea40e83d96..ef631eeec5bb5 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 
@@ -126,10 +127,72 @@ struct TestEmulateNarrowTypePass
                      "normal sequence"),
       llvm::cl::init(false)};
 };
+
+struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass
+    : public PassWrapper<TestMemRefFlattenAndVectorNarrowTypeEmulationPass,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestMemRefFlattenAndVectorNarrowTypeEmulationPass)
+
+  TestMemRefFlattenAndVectorNarrowTypeEmulationPass() = default;
+  TestMemRefFlattenAndVectorNarrowTypeEmulationPass(
+      const TestMemRefFlattenAndVectorNarrowTypeEmulationPass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
+                vector::VectorDialect, affine::AffineDialect>();
+  }
+
+  StringRef getArgument() const final {
+    return "test-memref-flatten-and-vector-narrow-type-emulation";
+  }
+
+  StringRef getDescription() const final {
+    return "Test MemRef flattening and vector narrow type emulation patterns";
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = &getContext();
+
+    // Create a type converter for narrow type emulation (8-bit)
+    arith::NarrowTypeEmulationConverter typeConverter(8);
+
+    // Add conversions for memref types with i4 elements
+    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+
+    ConversionTarget target(*ctx);
+    target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+      return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+    });
+    auto opLegalCallback = [&typeConverter](Operation *op) {
+      return typeConverter.isLegal(op);
+    };
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
+    target.addDynamicallyLegalDialect<
+        arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
+        affine::AffineDialect>(opLegalCallback);
+
+    RewritePatternSet patterns(ctx);
+
+    // Populate all necessary patterns for narrow type emulation and flattening
+    arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
+    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+    memref::populateFlattenVectorMemrefPatterns(patterns);
+
+    // Apply partial conversion
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
 } // namespace
 
 namespace mlir::test {
 void registerTestEmulateNarrowTypePass() {
   PassRegistration<TestEmulateNarrowTypePass>();
+  PassRegistration<TestMemRefFlattenAndVectorNarrowTypeEmulationPass>();
 }
 } // namespace mlir::test
>From 6edd712154fb2afd3302e4e1f6d027b66beb5b53 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Fri, 12 Sep 2025 10:52:46 -0700
Subject: [PATCH 4/4] Fix according to comments.
---
 .../Dialect/MemRef/Transforms/Transforms.h    |  2 +-
 .../Vector/Transforms/VectorRewritePatterns.h |  6 ++-
 .../MemRef/Transforms/FlattenMemRefs.cpp      |  5 ++-
 .../Transforms/VectorEmulateNarrowType.cpp    | 10 ++---
 ...atten-memref-and-emulate-narrow-types.mlir | 40 +++++++++++++++----
 .../Dialect/MemRef/TestEmulateNarrowType.cpp  |  6 +--
 6 files changed, 49 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 6b1cd02fec5b4..8b76930aed35a 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -147,7 +147,7 @@ void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
 /// Patterns for flattening multi-dimensional memref operations into
 /// one-dimensional memref operations.
-void populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns);
+void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns);
 void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
 void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 66a14bc23a5ed..08f439222a9a0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -383,8 +383,12 @@ void populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns, bool disableAtomicRMW = false);
 
-/// Populates patterns for both memref flattening and vector narrow type
+/// Populates patterns for both MemRef flattening and Vector narrow type
 /// emulation.
+///
+/// Patterns for narrow-type-emulation require "flattened" MemRef(s), so this
+/// composite populate* method can be used for narrow-type-emulation for Ops
+/// operating on MemRef(s) that are rank > 1.
 void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
     arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index d658d147a0a3a..1208fddf37e0b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -271,7 +271,8 @@ struct FlattenMemrefsPass
 
 } // namespace
 
-void memref::populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns) {
+void memref::populateFlattenVectorOpsOnMemrefPatterns(
+    RewritePatternSet &patterns) {
   patterns.insert<MemRefRewritePattern<vector::LoadOp>,
                   MemRefRewritePattern<vector::StoreOp>,
                   MemRefRewritePattern<vector::TransferReadOp>,
@@ -291,5 +292,5 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
 
 void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
   populateFlattenMemrefOpsPatterns(patterns);
-  populateFlattenVectorMemrefPatterns(patterns);
+  populateFlattenVectorOpsOnMemrefPatterns(patterns);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 09a369f297ac0..197bcf6eb3ff1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -823,8 +823,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 /// bitcasting.
 ///
 /// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
-/// and storing with an adjusted mask, since each `i8` container element holds
-/// two `i4` values.
+/// (each `i8` container element holds two `i4` values) and storing with an
+/// adjusted mask.
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -943,8 +943,8 @@ struct ConvertVectorMaskedStore final
 /// bitcasting.
 ///
 /// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
-/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` and bitcasting
-/// back, since each `i8` container holds two `i4` values.
+/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
+/// container holds two `i4` values) and bitcasting back.
 ///
 /// There are cases where the number of elements to load is not byte-aligned. In
 /// those cases, loads are converted to byte-aligned, byte-sized loads and the
@@ -2287,6 +2287,6 @@ void vector::populateVectorTransposeNarrowTypeRewritePatterns(
 void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
     arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
-  memref::populateFlattenVectorMemrefPatterns(patterns);
+  memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
   vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
 }
diff --git a/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
index a0a038c728f59..ad89589c0e717 100644
--- a/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
@@ -1,6 +1,13 @@
 // RUN: mlir-opt --test-memref-flatten-and-vector-narrow-type-emulation --split-input-file %s | FileCheck %s
 
-// -----
+// This test verifies that narrow-type-emulation works correctly for
+// rank > 1 memrefs by combining memref flattening with vector narrow type
+// emulation patterns. 
+//
+// The patterns tested here demonstrate the composition of two transformations,
+// memref flattening for vector ops and vector op narrow type emulation.
+//
+// TODO: Support `vector.transfer_write` operation.
 
 func.func @vector_load_2d_i4(%arg0: index, %arg1: index) -> vector<8xi4> {
     %0 = memref.alloc() : memref<4x8xi4>
@@ -8,8 +15,7 @@ func.func @vector_load_2d_i4(%arg0: index, %arg1: index) -> vector<8xi4> {
     return %1 : vector<8xi4>
 }
 //      CHECK: func @vector_load_2d_i4
-//      CHECK:   vector.load
-// CHECK-SAME:   memref<16xi8>
+//      CHECK:   vector.load {{.*}} memref<16xi8>
 
 // -----
 
@@ -21,8 +27,7 @@ func.func @vector_maskedload_2d_i4(%arg0: index, %arg1: index, %passthru: vector
     return %1 : vector<8xi4>
 }
 //      CHECK: func @vector_maskedload_2d_i4(
-//      CHECK:   vector.maskedload
-// CHECK-SAME:   memref<16xi8>
+//      CHECK:   vector.maskedload {{.*}} memref<16xi8>
 
 // -----
 
@@ -34,5 +39,26 @@ func.func @vector_maskedstore_2d_i4(%arg0: index, %arg1: index, %value: vector<8
     return
 }
 //      CHECK: func @vector_maskedstore_2d_i4(
-//      CHECK:   vector.maskedstore
-// CHECK-SAME:   memref<16xi8>
+//      CHECK:   vector.maskedstore {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_store_2d_i4(%arg0: index, %arg1: index, %value: vector<8xi4>) {
+    %0 = memref.alloc() : memref<4x8xi4>
+    vector.store %value, %0[%arg0, %arg1] : memref<4x8xi4>, vector<8xi4>
+    return
+}
+//      CHECK: func @vector_store_2d_i4(
+//      CHECK:   vector.store {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_transfer_read_2d_i4(%arg0: index, %arg1: index, %padding: i4) -> vector<8xi4> {
+    %0 = memref.alloc() : memref<4x8xi4>
+    %1 = vector.transfer_read %0[%arg0, %arg1], %padding {in_bounds = [true]} : memref<4x8xi4>, vector<8xi4>
+    return %1 : vector<8xi4>
+}
+//      CHECK: func @vector_transfer_read_2d_i4(
+//      CHECK-SAME: %{{.*}}: index, %{{.*}}: index, %[[PADDING_I4:.*]]: i4)
+//      CHECK:   %[[PADDING_I8:.*]] = arith.extui %[[PADDING_I4]] : i4 to i8
+//      CHECK:   vector.transfer_read {{.*}}, %[[PADDING_I8]] : memref<16xi8>, vector<4xi8>
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index ef631eeec5bb5..a0ecd574b7b9c 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -177,11 +177,9 @@ struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass
 
     RewritePatternSet patterns(ctx);
 
-    // Populate all necessary patterns for narrow type emulation and flattening
-    arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
     memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
-    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
-    memref::populateFlattenVectorMemrefPatterns(patterns);
+    vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+        typeConverter, patterns);
 
     // Apply partial conversion
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
    
    
More information about the Mlir-commits
mailing list