[Mlir-commits] [mlir] [mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass (PR #137389)
Lily Orth-Smith
llvmlistbot at llvm.org
Fri May 2 09:13:14 PDT 2025
https://github.com/electriclilies updated https://github.com/llvm/llvm-project/pull/137389
>From 27a25cd3f3469fa95f77fd51f8c6afc450a2908f Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Thu, 1 May 2025 19:44:12 +0000
Subject: [PATCH 1/2] Squashed changes
---
mlir/include/mlir/Conversion/Passes.td | 6 +
.../VectorToLLVM/ConvertVectorToLLVM.h | 3 +-
.../Vector/TransformOps/VectorTransformOps.td | 3 +-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 104 +++++++++++++++---
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 3 +-
.../TransformOps/VectorTransformOps.cpp | 3 +-
.../VectorToLLVM/use-vector-alignment.mlir | 102 +++++++++++++++++
7 files changed, 205 insertions(+), 19 deletions(-)
create mode 100644 mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..4ab2f300d0ccc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1394,6 +1394,12 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"true",
"Allows compiler to assume vector indices fit in 32-bit if that "
"yields faster code">,
+ Option<"useVectorAlignment", "use-vector-alignment",
+ "bool", /*default=*/"false",
+ "Use the preferred alignment of a vector type in load/store "
+ "operations instead of the alignment of the element type of the "
+ "memref. This flag is intended for use with hardware which requires"
+ "vector alignment.">,
Option<"amx", "enable-amx",
"bool", /*default=*/"false",
"Enables the use of AMX dialect while lowering the vector "
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 1e29bfeb9c392..f6b09deb4e44c 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -22,7 +22,8 @@ void populateVectorToLLVMMatrixConversionPatterns(
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
+ bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
+ bool useVectorAlignment = false);
namespace vector {
void registerConvertVectorToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 14cbbac99d9ae..299f198e4ab9c 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -32,7 +32,8 @@ def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
let arguments = (ins
DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
- DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
+ DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices,
+ DefaultValuedAttr<BoolAttr, "false">:$use_vector_alignment);
let assemblyFormat = "attr-dict";
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 076e5512f375b..5296013189b9e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -67,6 +67,21 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}
+// Helper that returns data layout alignment of a vector.
+LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
+ VectorType vectorType, unsigned &align) {
+ Type convertedVectorTy = typeConverter.convertType(vectorType);
+ if (!convertedVectorTy)
+ return failure();
+
+ llvm::LLVMContext llvmContext;
+ align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+ .getPreferredAlignment(convertedVectorTy,
+ typeConverter.getDataLayout());
+
+ return success();
+}
+
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
@@ -82,6 +97,28 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
return success();
}
+// Helper to resolve the alignment for vector load/store, gather and scatter
+// ops. If useVectorAlignment is true, get the preferred alignment for the
+// vector type in the operation. This option is used for hardware backends with
+// vectorization. Otherwise, use the preferred alignment of the element type of
+// the memref. Note that if you choose to use vector alignment, the shape of the
+// vector type must be resolved before the ConvertVectorToLLVM pass is run.
+LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
+ VectorType vectorType,
+ MemRefType memrefType, unsigned &align,
+ bool useVectorAlignment) {
+ if (useVectorAlignment) {
+ if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
+ return failure();
+ }
+ } else {
+ if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
+ return failure();
+ }
+ }
+ return success();
+}
+
// Check if the last stride is non-unit and has a valid memory space.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
const LLVMTypeConverter &converter) {
@@ -224,6 +261,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
+ explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -240,8 +281,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
- return failure();
+ if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
+ memRefTy, align, useVectorAlignment)))
+ return rewriter.notifyMatchFailure(loadOrStoreOp,
+ "could not resolve alignment");
// Resolve address.
auto vtype = cast<VectorType>(
@@ -252,12 +295,23 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
rewriter);
return success();
}
+
+private:
+ // If true, use the preferred alignment of the vector type.
+ // If false, use the preferred alignment of the element type
+ // of the memref. This flag is intended for use with hardware
+ // backends that require alignment of vector operations.
+ const bool useVectorAlignment;
};
/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
: public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
+ explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -278,10 +332,9 @@ class VectorGatherOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
- return rewriter.notifyMatchFailure(gather,
- "could not resolve memref alignment");
- }
+ if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+ memRefType, align, useVectorAlignment)))
+ return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -297,12 +350,24 @@ class VectorGatherOpConversion
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
return success();
}
+
+private:
+ // If true, use the preferred alignment of the vector type.
+ // If false, use the preferred alignment of the element type
+ // of the memref. This flag is intended for use with hardware
+ // backends that require alignment of vector operations.
+ const bool useVectorAlignment;
};
/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
+ explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
+
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -322,10 +387,10 @@ class VectorScatterOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
+ if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+ memRefType, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(scatter,
- "could not resolve memref alignment");
- }
+ "could not resolve alignment");
// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -340,6 +405,13 @@ class VectorScatterOpConversion
rewriter.getI32IntegerAttr(align));
return success();
}
+
+private:
+ // If true, use the preferred alignment of the vector type.
+ // If false, use the preferred alignment of the element type
+ // of the memref. This flag is intended for use with hardware
+ // backends that require alignment of vector operations.
+ const bool useVectorAlignment;
};
/// Conversion pattern for a vector.expandload.
@@ -1928,21 +2000,23 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions, bool force32BitVectorIndices) {
+ bool reassociateFPReductions, bool force32BitVectorIndices,
+ bool useVectorAlignment) {
// This function populates only ConversionPatterns, not RewritePatterns.
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
+ patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
+ VectorLoadStoreConversion<vector::MaskedLoadOp>,
+ VectorLoadStoreConversion<vector::StoreOp>,
+ VectorLoadStoreConversion<vector::MaskedStoreOp>,
+ VectorGatherOpConversion, VectorScatterOpConversion>(
+ converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
- VectorLoadStoreConversion<vector::LoadOp>,
- VectorLoadStoreConversion<vector::MaskedLoadOp>,
- VectorLoadStoreConversion<vector::StoreOp>,
- VectorLoadStoreConversion<vector::MaskedStoreOp>,
- VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..0ee6dce9ee94b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -92,7 +92,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
- converter, patterns, reassociateFPReductions, force32BitVectorIndices);
+ converter, patterns, reassociateFPReductions, force32BitVectorIndices,
+ useVectorAlignment);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
// Architecture specific augmentations.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 1be436dd7bf41..125c3d918284c 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -34,7 +34,8 @@ void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
populateVectorToLLVMConversionPatterns(
static_cast<LLVMTypeConverter &>(typeConverter), patterns,
- getReassociateFpReductions(), getForce_32bitVectorIndices());
+ getReassociateFpReductions(), getForce_32bitVectorIndices(),
+ getUseVectorAlignment());
}
LogicalResult
diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
new file mode 100644
index 0000000000000..3fa248656cf3a
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
+// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
+
+
+//===----------------------------------------------------------------------===//
+// vector.load
+//===----------------------------------------------------------------------===//
+
+func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+ %0 = vector.load %base[%i, %j] : memref<200x100xf32>, vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// ALL-LABEL: func @load
+
+// VEC-ALIGN: llvm.load %{{.*}} {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32>
+// MEMREF-ALIGN: llvm.load %{{.*}} {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.store
+//===----------------------------------------------------------------------===//
+
+func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {
+ %val = arith.constant dense<11.0> : vector<4xf32>
+ vector.store %val, %base[%i, %j] : memref<200x100xf32>, vector<4xf32>
+ return
+}
+
+// ALL-LABEL: func @store
+
+// VEC-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 16 : i64} : vector<4xf32>, !llvm.ptr
+// MEMREF-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 4 : i64} : vector<4xf32>, !llvm.ptr
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.maskedload
+//===----------------------------------------------------------------------===//
+
+func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = arith.constant 0: index
+ %0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
+// ALL-LABEL: func @masked_load
+
+// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.maskedstore
+//===----------------------------------------------------------------------===//
+
+func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+ %c0 = arith.constant 0: index
+ vector.maskedstore %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ return
+}
+
+// ALL-LABEL: func @masked_store
+
+// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) {
+ %0 = arith.constant 0: index
+ vector.scatter %base[%0][%index], %mask, %value : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+ return
+}
+
+// ALL-LABEL: func @scatter
+
+// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.gather
+//===----------------------------------------------------------------------===//
+
+func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> {
+ %0 = arith.constant 0: index
+ %1 = vector.gather %base[%0][%index], %mask, %passthru : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+ return %1 : vector<3xf32>
+}
+
+// ALL-LABEL: func @gather
+
+// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
>From 47a9f6383dda7bab7e92094fa36d0962c51708db Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Fri, 2 May 2025 16:11:44 +0000
Subject: [PATCH 2/2] Update comment
---
mlir/include/mlir/Conversion/Passes.td | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 4ab2f300d0ccc..f9c0f982c2118 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1399,7 +1399,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"Use the preferred alignment of a vector type in load/store "
"operations instead of the alignment of the element type of the "
"memref. This flag is intended for use with hardware which requires"
- "vector alignment.">,
+      " vector alignment, or in application contexts where it is known all "
+      "vector accesses are naturally aligned.">,
Option<"amx", "enable-amx",
"bool", /*default=*/"false",
"Enables the use of AMX dialect while lowering the vector "
More information about the Mlir-commits
mailing list