[Mlir-commits] [mlir] [MLIR] Make 1-D memref flattening a prerequisite for vector narrow type emulation (PR #157771)
Alan Li
llvmlistbot at llvm.org
Tue Sep 9 19:01:13 PDT 2025
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/157771
From 686a25b9e2da02f2d89c18305f4ecadf011ed731 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 9 Sep 2025 19:24:24 -0400
Subject: [PATCH 1/2] First updates.
---
.../Dialect/MemRef/Transforms/Transforms.h | 4 ++
.../MemRef/Transforms/FlattenMemRefs.cpp | 21 +++++--
.../Transforms/VectorEmulateNarrowType.cpp | 59 ++++++++++++-------
3 files changed, 58 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 33e3d94f02b1c..e7751df724f9c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -145,6 +145,10 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
+/// Patterns for flattening multi-dimensional memref operations into
+/// one-dimensional memref operations.
+void populateFlattenVectorMemRefPatterns(RewritePatternSet &patterns);
+void populateFlattenMemRefOpsPatterns(RewritePatternSet &patterns);
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 42be847811d52..d658d147a0a3a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -271,12 +271,8 @@ struct FlattenMemrefsPass
} // namespace
-void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
- patterns.insert<MemRefRewritePattern<memref::LoadOp>,
- MemRefRewritePattern<memref::StoreOp>,
- MemRefRewritePattern<memref::AllocOp>,
- MemRefRewritePattern<memref::AllocaOp>,
- MemRefRewritePattern<vector::LoadOp>,
+void memref::populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns) {
+ patterns.insert<MemRefRewritePattern<vector::LoadOp>,
MemRefRewritePattern<vector::StoreOp>,
MemRefRewritePattern<vector::TransferReadOp>,
MemRefRewritePattern<vector::TransferWriteOp>,
@@ -284,3 +280,16 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
MemRefRewritePattern<vector::MaskedStoreOp>>(
patterns.getContext());
}
+
+void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
+ patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+ MemRefRewritePattern<memref::StoreOp>,
+ MemRefRewritePattern<memref::AllocOp>,
+ MemRefRewritePattern<memref::AllocaOp>>(
+ patterns.getContext());
+}
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+ populateFlattenMemrefOpsPatterns(patterns);
+ populateFlattenVectorMemrefPatterns(patterns);
+}
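For reference, a minimal sketch of how the split entry points can be used on their own (the helper below is hypothetical and not part of this patch; it assumes the upstream greedy driver applyPatternsGreedily):

    #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    // Hypothetical helper: flatten only vector ops on multi-dimensional
    // memrefs, leaving memref.load/store/alloc/alloca untouched. Calling
    // populateFlattenMemrefsPatterns instead would register both sets.
    static void flattenVectorMemrefOps(Operation *root) {
      RewritePatternSet patterns(root->getContext());
      memref::populateFlattenVectorMemrefPatterns(patterns);
      (void)applyPatternsGreedily(root, std::move(patterns));
    }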
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f78e579d6c099..2ea17dbe2f53e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -556,7 +556,6 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
if (op.getValueToStore().getType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");
@@ -817,7 +816,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+/// Converts `vector.maskedstore` operations on narrow element types to work
+/// with wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Storing `vector<6xi4>` is emulated by bitcasting to `vector<3xi8>`
+/// and storing with an adjusted mask, since each `i8` container element holds
+/// two `i4` values.
struct ConvertVectorMaskedStore final
: OpConversionPattern<vector::MaskedStoreOp> {
using OpConversionPattern::OpConversionPattern;
@@ -826,10 +831,10 @@ struct ConvertVectorMaskedStore final
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
+ // Prerequisite: memref in the vector.maskedstore op is flattened into 1-D.
if (op.getValueToStore().getType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in vector.maskedstore op must be flattened beforehand.");
auto loc = op.getLoc();
auto containerElemTy =
@@ -931,18 +936,27 @@ struct ConvertVectorMaskedStore final
// ConvertVectorLoad
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+/// Converts `vector.load` on narrow element types to work with
+/// wider, byte-aligned container types by adjusting load sizes and using
+/// bitcasting.
+///
+/// Example: `vector.load` of `vector<4xi4>` from `memref<3x4xi4>` is emulated
+/// by loading `vector<2xi8>` from the linearized `memref<6xi8>` and bitcasting
+/// back, since each `i8` container holds two `i4` values.
+///
+/// There are cases where the number of elements to load is not byte-aligned. In
+/// those cases, loads are converted to byte-aligned, byte-sized loads and the
+/// target vector is extracted from the loaded vector.
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
-
- // See #115653
+ // Prerequisite: memref in the vector.load op is flattened into 1-D.
if (op.getVectorType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in emulated vector ops must be flattened beforehand.");
auto loc = op.getLoc();
auto containerElemTy =
@@ -961,8 +975,6 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
- // Here only the 1-D vector load is considered, and the N-D memref types
- // should be linearized.
// For example, to emulate i4 to i8, the following op:
//
// %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
@@ -1037,7 +1049,12 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//
-// TODO: Document-me
+/// Converts `vector.maskedload` operations on narrow element types to work with
+/// wider, byte-aligned container types by adjusting the mask and using
+/// bitcasting.
+///
+/// Example: Loading `vector<6xi4>` is emulated by loading `vector<3xi8>` and
+/// bitcasting, since each `i8` container element holds two `i4` values.
struct ConvertVectorMaskedLoad final
: OpConversionPattern<vector::MaskedLoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -1045,10 +1062,9 @@ struct ConvertVectorMaskedLoad final
LogicalResult
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
if (op.getVectorType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in emulated vector ops must be flattened beforehand.");
auto loc = op.getLoc();
@@ -1229,7 +1245,6 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
int elemsPerMultiByte = multiByteBits / subByteBits;
- // TODO: This is a bit too restrictive for vectors rank > 1.
return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
}
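To make the divisibility check concrete, the arithmetic for i4 elements in an i8 container works out as below (a standalone sketch with the values inlined, not the actual helper):

    // Each i8 container element holds 8 / 4 = 2 i4 elements.
    int subByteBits = 4, multiByteBits = 8;
    int elemsPerMultiByte = multiByteBits / subByteBits; // 2
    // A trailing dimension of 6 fits losslessly (6 % 2 == 0), e.g.
    // vector<6xi4> -> vector<3xi8>; a trailing dimension of 5 does not.
    bool fits = (6 % elemsPerMultiByte) == 0; // true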
@@ -1246,10 +1261,11 @@ struct ConvertVectorTransferRead final
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // See #115653
+ // Prerequisite: memref in the vector.transfer_read op is flattened into
+ // 1-D.
if (op.getVectorType().getRank() != 1)
- return rewriter.notifyMatchFailure(op,
- "only 1-D vectors are supported ATM");
+ return rewriter.notifyMatchFailure(
+ op, "Memref in emulated vector ops must be flattened beforehand.");
auto loc = op.getLoc();
auto containerElemTy =
@@ -2228,6 +2244,9 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
const arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns, bool disableAtomicRMW) {
+ // As a prerequisite, make sure memrefs in vector ops are linearized.
+ memref::populateFlattenVectorMemrefPatterns(patterns);
+
// Populate `vector.*` conversion patterns.
// TODO: #119553 support atomicity
patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
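With flattening registered as part of the emulation pattern set, callers no longer have to run it as a separate pre-pass. A hedged sketch of the caller side (the converter setup mirrors the existing narrow-type test passes; the function name is illustrative):

    #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
    #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"

    using namespace mlir;

    void populateI4ToI8Emulation(RewritePatternSet &patterns) {
      // Emulate sub-byte element types with byte-aligned i8 containers.
      arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
      // This call now also registers the memref-flattening patterns for
      // vector ops, so a vector.load of vector<4xi4> from memref<3x4xi4>
      // is linearized first and then emulated as a vector<2xi8> load from
      // memref<6xi8>.
      vector::populateVectorNarrowTypeEmulationPatterns(
          typeConverter, patterns, /*disableAtomicRMW=*/false);
    }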
From b061ab9c0ab90bebd83e219e936237bd751a54d2 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Tue, 9 Sep 2025 22:00:52 -0400
Subject: [PATCH 2/2] Forgot to push
---
mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h | 4 ++--
.../lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 2 ++
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index e7751df724f9c..6b1cd02fec5b4 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -147,8 +147,8 @@ void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
/// Patterns for flattening multi-dimensional memref operations into
/// one-dimensional memref operations.
-void populateFlattenVectorMemRefPatterns(RewritePatternSet &patterns);
-void populateFlattenMemRefOpsPatterns(RewritePatternSet &patterns);
+void populateFlattenVectorMemrefPatterns(RewritePatternSet &patterns);
+void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns);
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 2ea17dbe2f53e..61b2df68c0f7c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -38,6 +38,8 @@
#include <cstdint>
#include <optional>
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+
using namespace mlir;
#define DEBUG_TYPE "vector-narrow-type-emulation"