[Mlir-commits] [mlir] [MLIR] Make 1-D memref flattening a prerequisite for vector narrow type emulation (PR #157771)
Alan Li
llvmlistbot at llvm.org
Fri Sep 12 12:01:39 PDT 2025
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/157771
From 686a25b9e2da02f2d89c18305f4ecadf011ed731 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 9 Sep 2025 19:24:24 -0400
Subject: [PATCH 1/4] First updates.
---
.../Dialect/MemRef/Transforms/Transforms.h | 4 ++
.../MemRef/Transforms/FlattenMemRefs.cpp | 21 +++++--
.../Transforms/VectorEmulateNarrowType.cpp | 59 ++++++++++++-------
3 files changed, 58 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 33e3d94f02b1c..e7751df724f9c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -145,6 +145,10 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
+/// Patterns for flattening multi-dimensional memref operations into
+/// one-dimensional memref operations.
+void populateFlattenVectorMemRefPatterns(RewritePatternSet &patterns);
+void populateFlattenMemRefOpsPatterns(RewritePatternSet &patterns);
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 42be847811d52..d658d147a0a3a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -271,12 +271,8 @@ struct FlattenMemrefsPass
} // namespace
-void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
- patterns.insert<MemRefRewritePattern<memref::LoadOp>,
- MemRefRewritePattern<memref::StoreOp>,
- MemRefRewritePattern<memref::AllocOp>,
- MemRefRewritePattern<memref::AllocaOp>,
- MemRefRewritePattern<vector::LoadOp>,
+void memref::populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns) {
+ patterns.insert<MemRefRewritePattern<vector::LoadOp>,
MemRefRewritePattern<vector::StoreOp>,
MemRefRewritePattern<vector::TransferReadOp>,
MemRefRewritePattern<vector::TransferWriteOp>,
@@ -284,3 +280,16 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
MemRefRewritePattern<vector::MaskedStoreOp>>(
patterns.getContext());
}
+
+void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
+ patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+ MemRefRewritePattern<memref::StoreOp>,
+ MemRefRewritePattern<memref::AllocOp>,
+ MemRefRewritePattern<memref::AllocaOp>>(
+ patterns.getContext());
+}
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+ populateFlattenMemrefOpsPatterns(patterns);
+ populateFlattenVectorMemrefPatterns(patterns);
+}
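For illustration, the vector-op flattening patterns above rewrite multi-dimensional vector memory ops onto a linearized 1-D view of the source memref. A minimal sketch (the 1-D view and the linearized index are derived from the source memref's shape and strides; SSA names are illustrative):

    // Before:
    %v = vector.load %m[%i, %j] : memref<4x8xi4>, vector<8xi4>
    // After flattening (sketch):
    %v = vector.load %flat[%lin] : memref<32xi4>, vector<8xi4>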
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f78e579d6c099..2ea17dbe2f53e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -556,7 +556,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
if (op.getValueToStore().getType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
@@ -817,7 +816,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+/// Converts `vector.maskedstore` operations on narrow element types to work
+/// with wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
+/// and storing with an adjusted mask, since each `i8` container element holds
+/// two `i4` values.
struct ConvertVectorMaskedStore final
: OpConversionPattern<vector::MaskedStoreOp> {
using OpConversionPattern::OpConversionPattern;
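For illustration, a minimal sketch of the rewrite described in the doc comment above, for the i4-in-i8 case (SSA names and the mask-compression step are illustrative; bytes that are only partially covered by the mask require a read-modify-write):

    // Before (1-D):
    vector.maskedstore %m[%idx], %mask, %val
        : memref<12xi4>, vector<6xi1>, vector<6xi4>
    // After (sketch):
    %old  = vector.maskedload %m8[%lin], %new_mask, %pad
        : memref<6xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
    %old4 = vector.bitcast %old : vector<3xi8> to vector<6xi4>
    %sel  = arith.select %mask, %val, %old4 : vector<6xi1>, vector<6xi4>
    %new  = vector.bitcast %sel : vector<6xi4> to vector<3xi8>
    vector.maskedstore %m8[%lin], %new_mask, %new
        : memref<6xi8>, vector<3xi1>, vector<3xi8>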
@@ -826,10 +831,10 @@ struct ConvertVectorMaskedStore final
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
+ // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
if (op.getValueToStore().getType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in vector.maskedstore op must be flattened beforehand.");
auto loc = op.getLoc();
auto containerElemTy =
@@ -931,18 +936,27 @@ struct ConvertVectorMaskedStore final
// ConvertVectorLoad
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+/// Converts `vector.load` on narrow element types to work with
+/// wider, byte-aligned container types by adjusting load sizes and using
+/// bitcasting.
+///
+/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
+/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` and bitcasting
+/// back, since each `i8` container holds two `i4` values.
+///
+/// There are cases where the number of elements to load is not byte-aligned. In
+/// those cases, loads are converted to byte-aligned, byte-sized loads and the
+/// target vector is extracted from the loaded vector.
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
- // See #115653
+ // Prerequisites: memref in the vector.load op is flattened into 1-D.
if (op.getVectorType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in emulated vector ops must be flattened beforehand.");
auto loc = op.getLoc();
auto containerElemTy =
@@ -961,8 +975,6 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
- // Here only the 1-D vector load is considered, and the N-D memref types
- // should be linearized.
// For example, to emulate i4 to i8, the following op:
//
// %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
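For reference, a sketch of the emulated form that the in-code example above leads into (the memref<3x4xi4> is linearized to memref<6xi8>, two i4 values per i8; names illustrative):

    %1 = vector.load %0[%lin] : memref<6xi8>, vector<2xi8>
    %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>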
@@ -1037,7 +1049,12 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+/// Converts `vector.maskedload` operations on narrow element types to work with
+/// wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
+/// bitcasting, since each `i8` container element holds two `i4` values.
struct ConvertVectorMaskedLoad final
: OpConversionPattern<vector::MaskedLoadOp> {
using OpConversionPattern::OpConversionPattern;
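For illustration, a minimal sketch of the maskedload rewrite for the i4-in-i8 case (SSA names and mask compression are illustrative; the final select restores passthru lanes inside partially loaded bytes):

    // Before (1-D):
    %r = vector.maskedload %m[%idx], %mask, %pass
        : memref<12xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
    // After (sketch):
    %pass8 = vector.bitcast %pass : vector<6xi4> to vector<3xi8>
    %l  = vector.maskedload %m8[%lin], %new_mask, %pass8
        : memref<6xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
    %l4 = vector.bitcast %l : vector<3xi8> to vector<6xi4>
    %r  = arith.select %mask, %l4, %pass : vector<6xi1>, vector<6xi4>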
@@ -1045,10 +1062,9 @@ struct ConvertVectorMaskedLoad final
LogicalResult
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
if (op.getVectorType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in emulated vector ops must be flattened beforehand.");
auto loc = op.getLoc();
@@ -1229,7 +1245,6 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
int elemsPerMultiByte = multiByteBits / subByteBits;
- // TODO: This is a bit too restrictive for vectors rank > 1.
return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}
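Worked example for the check above: with i4 elements in an i8 container, elemsPerMultiByte = 8 / 4 = 2, so vector<6xi4> fits (6 % 2 == 0) while vector<5xi4> does not.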
@@ -1246,10 +1261,11 @@ struct ConvertVectorTransferRead final
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
+ // Prerequisite: memref in the vector.transfer_read op is flattened into
+ // 1-D.
if (op.getVectorType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in emulated vector ops must be flattened beforehand.");
auto loc = op.getLoc();
auto containerElemTy =
@@ -2228,6 +2244,9 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns, bool disableAtomicRMW) {
+ // As a prerequisite, make sure memrefs in vector ops are linearized.
+ memref::populateFlattenVectorMemrefPatterns(patterns);
+
// Populate `vector.*` conversion patterns.
// TODO: #119553 support atomicity
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
From b061ab9c0ab90bebd83e219e936237bd751a54d2 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 9 Sep 2025 22:00:52 -0400
Subject: [PATCH 2/4] Forgot to push
---
mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h | 4 ++--
.../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 2 ++
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index e7751df724f9c..6b1cd02fec5b4 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -147,8 +147,8 @@ void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
/// Patterns for flattening multi-dimensional memref operations into
/// one-dimensional memref operations.
-void populateFlattenVectorMemRefPatterns(RewritePatternSet &patterns);
-void populateFlattenMemRefOpsPatterns(RewritePatternSet &patterns);
+void populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns);
+void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 2ea17dbe2f53e..61b2df68c0f7c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -38,6 +38,8 @@
#include <cstdint>
#include <optional>
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+
using namespace mlir;
#define DEBUG_TYPE "vector-narrow-type-emulation"
From f523d84b758647657fb1157d274204834d0a8b6c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Wed, 10 Sep 2025 14:05:28 -0400
Subject: [PATCH 3/4] Adding new tests
---
.../Vector/Transforms/VectorRewritePatterns.h | 6 ++
.../Transforms/VectorEmulateNarrowType.cpp | 13 ++--
...atten-memref-and-emulate-narrow-types.mlir | 38 +++++++++++
.../Dialect/MemRef/TestEmulateNarrowType.cpp | 63 +++++++++++++++++++
4 files changed, 115 insertions(+), 5 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 0138f477cadea..66a14bc23a5ed 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -383,6 +383,12 @@ void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns, bool disableAtomicRMW = false);
+/// Populates patterns for both memref flattening and vector narrow type
+/// emulation.
+void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+ arith::NarrowTypeEmulationConverter &typeConverter,
+ RewritePatternSet &patterns);
+
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
/// vector operations comprising `shuffle` and `bitwise` ops.
/// Warning: these patterns currently only work for little endian targets.
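A sketch of what the composite populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns entry point achieves on a 2-D narrow-type load, mirroring the tests added below (linearized index computation elided; names illustrative):

    // Input:
    %v = vector.load %m[%i, %j] : memref<4x8xi4>, vector<8xi4>
    // After flattening + narrow type emulation:
    %w = vector.load %m8[%lin] : memref<16xi8>, vector<4xi8>
    %v = vector.bitcast %w : vector<4xi8> to vector<8xi4>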
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 61b2df68c0f7c..09a369f297ac0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -955,7 +955,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
LogicalResult
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Prerequisites: memref in the vector.load op is flattened into 1-D.
+ // Prerequisite: memref in the vector.load op is flattened into 1-D.
if (op.getVectorType().getRank() != 1)
return rewriter.notifyMatchFailure(
op, "Memref in emulated vector ops must be flattened beforehand.");
@@ -2245,10 +2245,6 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns, bool disableAtomicRMW) {
-
- // As a prerequisite, make sure memrefs in vector ops are linearized.
- memref::populateFlattenVectorMemrefPatterns(patterns);
-
// Populate `vector.*` conversion patterns.
// TODO: #119553 support atomicity
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
@@ -2287,3 +2283,10 @@ void vector::populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
}
+
+void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+ arith::NarrowTypeEmulationConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ memref::populateFlattenVectorMemrefPatterns(patterns);
+ vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+}
diff --git a/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
new file mode 100644
index 0000000000000..a0a038c728f59
--- /dev/null
+++ b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt --test-memref-flatten-and-vector-narrow-type-emulation --split-input-file %s | FileCheck %s
+
+// -----
+
+func.func @vector_load_2d_i4(%arg0: index, %arg1: index) -> vector<8xi4> {
+ %0 = memref.alloc() : memref<4x8xi4>
+ %1 = vector.load %0[%arg0, %arg1] : memref<4x8xi4>, vector<8xi4>
+ return %1 : vector<8xi4>
+}
+// CHECK: func @vector_load_2d_i4
+// CHECK: vector.load
+// CHECK-SAME: memref<16xi8>
+
+// -----
+
+func.func @vector_maskedload_2d_i4(%arg0: index, %arg1: index, %passthru: vector<8xi4>) -> vector<8xi4> {
+ %0 = memref.alloc() : memref<4x8xi4>
+ %mask = vector.constant_mask [6] : vector<8xi1>
+ %1 = vector.maskedload %0[%arg0, %arg1], %mask, %passthru :
+ memref<4x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+ return %1 : vector<8xi4>
+}
+// CHECK: func @vector_maskedload_2d_i4(
+// CHECK: vector.maskedload
+// CHECK-SAME: memref<16xi8>
+
+// -----
+
+func.func @vector_maskedstore_2d_i4(%arg0: index, %arg1: index, %value: vector<8xi4>) {
+ %0 = memref.alloc() : memref<4x8xi4>
+ %mask = vector.constant_mask [5] : vector<8xi1>
+ vector.maskedstore %0[%arg0, %arg1], %mask, %value :
+ memref<4x8xi4>, vector<8xi1>, vector<8xi4>
+ return
+}
+// CHECK: func @vector_maskedstore_2d_i4(
+// CHECK: vector.maskedstore
+// CHECK-SAME: memref<16xi8>
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index ba2ea40e83d96..ef631eeec5bb5 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@@ -126,10 +127,72 @@ struct TestEmulateNarrowTypePass
"normal sequence"),
llvm::cl::init(false)};
};
+
+struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass
+ : public PassWrapper<TestMemRefFlattenAndVectorNarrowTypeEmulationPass,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestMemRefFlattenAndVectorNarrowTypeEmulationPass)
+
+ TestMemRefFlattenAndVectorNarrowTypeEmulationPass() = default;
+ TestMemRefFlattenAndVectorNarrowTypeEmulationPass(
+ const TestMemRefFlattenAndVectorNarrowTypeEmulationPass &pass)
+ : PassWrapper(pass) {}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry
+ .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
+ vector::VectorDialect, affine::AffineDialect>();
+ }
+
+ StringRef getArgument() const final {
+ return "test-memref-flatten-and-vector-narrow-type-emulation";
+ }
+
+ StringRef getDescription() const final {
+ return "Test MemRef flattening and vector narrow type emulation patterns";
+ }
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ MLIRContext *ctx = &getContext();
+
+ // Create a type converter for narrow type emulation (8-bit)
+ arith::NarrowTypeEmulationConverter typeConverter(8);
+
+ // Add conversions for memref types with i4 elements
+ memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+
+ ConversionTarget target(*ctx);
+ target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+ return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+ });
+ auto opLegalCallback = [&typeConverter](Operation *op) {
+ return typeConverter.isLegal(op);
+ };
+ target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
+ target.addDynamicallyLegalDialect<
+ arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
+ affine::AffineDialect>(opLegalCallback);
+
+ RewritePatternSet patterns(ctx);
+
+ // Populate all necessary patterns for narrow type emulation and flattening
+ arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
+ memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+ vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
+ memref::populateFlattenVectorMemrefPatterns(patterns);
+
+ // Apply partial conversion
+ if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
} // namespace
namespace mlir::test {
void registerTestEmulateNarrowTypePass() {
PassRegistration<TestEmulateNarrowTypePass>();
+ PassRegistration<TestMemRefFlattenAndVectorNarrowTypeEmulationPass>();
}
} // namespace mlir::test
From 2fe656ddea8d1c8ae958073b0a255d685ba4c053 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Fri, 12 Sep 2025 10:52:46 -0700
Subject: [PATCH 4/4] Fix according to comments.
---
.../Dialect/MemRef/Transforms/Transforms.h | 2 +-
.../Vector/Transforms/VectorRewritePatterns.h | 6 ++-
.../MemRef/Transforms/FlattenMemRefs.cpp | 5 ++-
.../Transforms/VectorEmulateNarrowType.cpp | 10 ++---
...atten-memref-and-emulate-narrow-types.mlir | 40 +++++++++++++++----
.../Dialect/MemRef/TestEmulateNarrowType.cpp | 6 +--
6 files changed, 49 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 6b1cd02fec5b4..8b76930aed35a 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -147,7 +147,7 @@ void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
/// Patterns for flattening multi-dimensional memref operations into
/// one-dimensional memref operations.
-void populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns);
+void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns);
void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 66a14bc23a5ed..c510e3c67325a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -383,8 +383,12 @@ void populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns, bool disableAtomicRMW = false);
-/// Populates patterns for both memref flattening and vector narrow type
+/// Populates patterns for both MemRef flattening and Vector narrow type
/// emulation.
+///
+/// Patterns for narrow-type emulation require "flattened" (i.e. 1-D) MemRefs,
+/// so this composite populate* method can be used to emulate narrow types for
+/// Ops operating on MemRefs of rank > 1.
void populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index d658d147a0a3a..1208fddf37e0b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -271,7 +271,8 @@ struct FlattenMemrefsPass
} // namespace
-void memref::populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns) {
+void memref::populateFlattenVectorOpsOnMemrefPatterns(
+ RewritePatternSet &patterns) {
patterns.insert<MemRefRewritePattern<vector::LoadOp>,
MemRefRewritePattern<vector::StoreOp>,
MemRefRewritePattern<vector::TransferReadOp>,
@@ -291,5 +292,5 @@ void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
populateFlattenMemrefOpsPatterns(patterns);
- populateFlattenVectorMemrefPatterns(patterns);
+ populateFlattenVectorOpsOnMemrefPatterns(patterns);
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 09a369f297ac0..5bec5d3a8f847 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -823,8 +823,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
/// bitcasting.
///
/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
-/// and storing with an adjusted mask, since each `i8` container element holds
-/// two `i4` values.
+/// (each `i8` container element holds two `i4` values) and storing with an
+/// adjusted mask.
struct ConvertVectorMaskedStore final
: OpConversionPattern<vector::MaskedStoreOp> {
using OpConversionPattern::OpConversionPattern;
@@ -943,8 +943,8 @@ struct ConvertVectorMaskedStore final
/// bitcasting.
///
/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
-/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` and bitcasting
-/// back, since each `i8` container holds two `i4` values.
+/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` (each `i8`
+/// container holds two `i4` values) and bitcasting back.
///
/// There are cases where the number of elements to load is not byte-aligned. In
/// those cases, loads are converted to byte-aligned, byte-sized loads and the
@@ -2287,6 +2287,6 @@ void vector::populateVectorTransposeNarrowTypeRewritePatterns(
void vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns) {
- memref::populateFlattenVectorMemrefPatterns(patterns);
+ memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
}
diff --git a/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
index a0a038c728f59..ad89589c0e717 100644
--- a/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir
@@ -1,6 +1,13 @@
// RUN: mlir-opt --test-memref-flatten-and-vector-narrow-type-emulation --split-input-file %s | FileCheck %s
-// -----
+// This test verifies that narrow-type-emulation works correctly for
+// rank > 1 memrefs by combining memref flattening with vector narrow type
+// emulation patterns.
+//
+// The patterns tested here demonstrate the composition of two transformations:
+// memref flattening for vector ops and vector narrow type emulation.
+//
+// TODO: Support `vector.transfer_write` operation.
func.func @vector_load_2d_i4(%arg0: index, %arg1: index) -> vector<8xi4> {
%0 = memref.alloc() : memref<4x8xi4>
@@ -8,8 +15,7 @@ func.func @vector_load_2d_i4(%arg0: index, %arg1: index) -> vector<8xi4> {
return %1 : vector<8xi4>
}
// CHECK: func @vector_load_2d_i4
-// CHECK: vector.load
-// CHECK-SAME: memref<16xi8>
+// CHECK: vector.load {{.*}} memref<16xi8>
// -----
@@ -21,8 +27,7 @@ func.func @vector_maskedload_2d_i4(%arg0: index, %arg1: index, %passthru: vector
return %1 : vector<8xi4>
}
// CHECK: func @vector_maskedload_2d_i4(
-// CHECK: vector.maskedload
-// CHECK-SAME: memref<16xi8>
+// CHECK: vector.maskedload {{.*}} memref<16xi8>
// -----
@@ -34,5 +39,26 @@ func.func @vector_maskedstore_2d_i4(%arg0: index, %arg1: index, %value: vector<8
return
}
// CHECK: func @vector_maskedstore_2d_i4(
-// CHECK: vector.maskedstore
-// CHECK-SAME: memref<16xi8>
+// CHECK: vector.maskedstore {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_store_2d_i4(%arg0: index, %arg1: index, %value: vector<8xi4>) {
+ %0 = memref.alloc() : memref<4x8xi4>
+ vector.store %value, %0[%arg0, %arg1] : memref<4x8xi4>, vector<8xi4>
+ return
+}
+// CHECK: func @vector_store_2d_i4(
+// CHECK: vector.store {{.*}} memref<16xi8>
+
+// -----
+
+func.func @vector_transfer_read_2d_i4(%arg0: index, %arg1: index, %padding: i4) -> vector<8xi4> {
+ %0 = memref.alloc() : memref<4x8xi4>
+ %1 = vector.transfer_read %0[%arg0, %arg1], %padding {in_bounds = [true]} : memref<4x8xi4>, vector<8xi4>
+ return %1 : vector<8xi4>
+}
+// CHECK: func @vector_transfer_read_2d_i4(
+// CHECK-SAME: %{{.*}}: index, %{{.*}}: index, %[[PADDING_I4:.*]]: i4)
+// CHECK: %[[PADDING_I8:.*]] = arith.extui %[[PADDING_I4]] : i4 to i8
+// CHECK: vector.transfer_read {{.*}}, %[[PADDING_I8]] : memref<16xi8>, vector<4xi8>
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
index ef631eeec5bb5..8ba018fec9f74 100644
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -177,11 +177,9 @@ struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass
RewritePatternSet patterns(ctx);
- // Populate all necessary patterns for narrow type emulation and flattening
- arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
- vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
- memref::populateFlattenVectorMemrefPatterns(patterns);
+ vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
+     typeConverter, patterns);
+
// Apply partial conversion
if (failed(applyPartialConversion(op, target, std::move(patterns))))