[Mlir-commits] [mlir] [mlir] Fix MemRefType alignment in ConvertVectorToLLVM (PR #137389)
Lily Orth-Smith
llvmlistbot at llvm.org
Tue Apr 29 16:04:08 PDT 2025
https://github.com/electriclilies updated https://github.com/llvm/llvm-project/pull/137389
>From de7365648aabbd0834b4d50667938be211139f2f Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Fri, 25 Apr 2025 19:51:30 +0000
Subject: [PATCH 1/2] Fix how we get alignment for memrefs
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 19 +++++++++++++++----
1 file changed, 15 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 076e5512f375b..d5b312d6929be 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -70,15 +70,26 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
- Type elementTy = typeConverter.convertType(memrefType.getElementType());
- if (!elementTy)
+ // If shape is statically known, assign MemRefTypes to the alignment of a
+ // VectorType with the same size and dtype. Otherwise, fall back to the
+ // alignment of the element type.
+ Type convertedType;
+ if (memrefType.hasStaticShape()) {
+ convertedType = typeConverter.convertType(VectorType::get(
+ memrefType.getNumElements(), memrefType.getElementType()));
+ } else {
+ convertedType = typeConverter.convertType(memrefType.getElementType());
+ }
+
+ if (!convertedType)
return failure();
// TODO: this should use the MLIR data layout when it becomes available and
// stop depending on translation.
llvm::LLVMContext llvmContext;
- align = LLVM::TypeToLLVMIRTranslator(llvmContext)
- .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
+ align =
+ LLVM::TypeToLLVMIRTranslator(llvmContext)
+ .getPreferredAlignment(convertedType, typeConverter.getDataLayout());
return success();
}
>From 0baf2d2b5960ab9292a319ce37f588f7737273e6 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Tue, 29 Apr 2025 22:51:13 +0000
Subject: [PATCH 2/2] Use flag to determine whether we use vector alignment or
memref alignment
---
mlir/include/mlir/Conversion/Passes.td | 5 +
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 128 +++++++++++++-----
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 3 +-
3 files changed, 99 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..3f64e719a3fcb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1394,6 +1394,11 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"true",
"Allows compiler to assume vector indices fit in 32-bit if that "
"yields faster code">,
+ Option<"useVectorAlignment", "use-vector-alignment",
+ "bool", /*default=*/"false",
+ "Use the preferred alignment of the vector type instead of the "
+ "alignment of the memref element type in load/store/gather/scatter "
+ "lowerings. Only safe when all such accesses are vector-aligned">,
Option<"amx", "enable-amx",
"bool", /*default=*/"false",
"Enables the use of AMX dialect while lowering the vector "
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d5b312d6929be..86590c2bc3666 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -67,29 +67,33 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}
+// Helper that returns the preferred data layout alignment of a vector type.
+LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
+ VectorType vectorType, unsigned &align) {
+ Type convertedVectorTy = typeConverter.convertType(vectorType);
+ if (!convertedVectorTy)
+ return failure();
+
+ llvm::LLVMContext llvmContext;
+ align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+ .getPreferredAlignment(convertedVectorTy,
+ typeConverter.getDataLayout());
+
+ return success();
+}
+
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
- // If shape is statically known, assign MemRefTypes to the alignment of a
- // VectorType with the same size and dtype. Otherwise, fall back to the
- // alignment of the element type.
- Type convertedType;
- if (memrefType.hasStaticShape()) {
- convertedType = typeConverter.convertType(VectorType::get(
- memrefType.getNumElements(), memrefType.getElementType()));
- } else {
- convertedType = typeConverter.convertType(memrefType.getElementType());
- }
-
- if (!convertedType)
+ Type elementTy = typeConverter.convertType(memrefType.getElementType());
+ if (!elementTy)
return failure();
// TODO: this should use the MLIR data layout when it becomes available and
// stop depending on translation.
llvm::LLVMContext llvmContext;
- align =
- LLVM::TypeToLLVMIRTranslator(llvmContext)
- .getPreferredAlignment(convertedType, typeConverter.getDataLayout());
+ align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+ .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
return success();
}
@@ -235,6 +239,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
+ explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -251,8 +259,17 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
- return failure();
+ if (useVectorAlignment) {
+ if (failed(
+ getVectorAlignment(*this->getTypeConverter(), vectorTy, align)))
+ return rewriter.notifyMatchFailure(
+ loadOrStoreOp, "could not resolve vector alignment");
+ } else {
+ if (failed(
+ getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
+ return rewriter.notifyMatchFailure(
+ loadOrStoreOp, "could not resolve memref alignment");
+ }
// Resolve address.
auto vtype = cast<VectorType>(
@@ -263,12 +280,19 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
rewriter);
return success();
}
+
+private:
+ const bool useVectorAlignment; // Use vector rather than element alignment.
};
/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
: public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
+ explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -289,9 +313,15 @@ class VectorGatherOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
- return rewriter.notifyMatchFailure(gather,
- "could not resolve memref alignment");
+ if (useVectorAlignment) {
+ if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+ return rewriter.notifyMatchFailure(
+ gather, "could not resolve vector alignment");
+ } else {
+ if (failed(
+ getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+ return rewriter.notifyMatchFailure(
+ gather, "could not resolve memref alignment");
}
// Resolve address.
@@ -308,12 +338,20 @@ class VectorGatherOpConversion
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
return success();
}
+
+private:
+ const bool useVectorAlignment; // Use vector rather than element alignment.
};
/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
+ explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
+
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -333,9 +371,15 @@ class VectorScatterOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
- return rewriter.notifyMatchFailure(scatter,
- "could not resolve memref alignment");
+ if (useVectorAlignment) {
+ if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+ return rewriter.notifyMatchFailure(
+ scatter, "could not resolve vector alignment");
+ } else {
+ if (failed(
+ getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+ return rewriter.notifyMatchFailure(
+ scatter, "could not resolve memref alignment");
}
// Resolve address.
@@ -351,6 +395,9 @@ class VectorScatterOpConversion
rewriter.getI32IntegerAttr(align));
return success();
}
+
+private:
+ const bool useVectorAlignment; // Use vector rather than element alignment.
};
/// Conversion pattern for a vector.expandload.
@@ -1939,7 +1989,8 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions, bool force32BitVectorIndices) {
+ bool reassociateFPReductions, bool force32BitVectorIndices,
+ bool useVectorAlignment) {
// This function populates only ConversionPatterns, not RewritePatterns.
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
@@ -1948,18 +1999,20 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
- VectorLoadStoreConversion<vector::LoadOp>,
- VectorLoadStoreConversion<vector::MaskedLoadOp>,
- VectorLoadStoreConversion<vector::StoreOp>,
- VectorLoadStoreConversion<vector::MaskedStoreOp>,
- VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
VectorScalableStepOpLowering>(converter);
+ // Memory-access patterns whose alignment source is `useVectorAlignment`.
+ patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
+ VectorLoadStoreConversion<vector::MaskedLoadOp>,
+ VectorLoadStoreConversion<vector::StoreOp>,
+ VectorLoadStoreConversion<vector::MaskedStoreOp>,
+ VectorGatherOpConversion, VectorScatterOpConversion>(
+ converter, useVectorAlignment);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..0ee6dce9ee94b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -92,7 +92,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
- converter, patterns, reassociateFPReductions, force32BitVectorIndices);
+ converter, patterns, reassociateFPReductions, force32BitVectorIndices,
+ useVectorAlignment);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
// Architecture specific augmentations.
More information about the Mlir-commits
mailing list