[Mlir-commits] [mlir] 3715de9 - [mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass (#137389)
llvmlistbot at llvm.org
Fri May 2 12:54:29 PDT 2025
Author: Lily Orth-Smith
Date: 2025-05-02T12:54:26-07:00
New Revision: 3715de976e41654db83a0aef5be50790024639ea
URL: https://github.com/llvm/llvm-project/commit/3715de976e41654db83a0aef5be50790024639ea
DIFF: https://github.com/llvm/llvm-project/commit/3715de976e41654db83a0aef5be50790024639ea.diff
LOG: [mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass (#137389)
In ConvertVectorToLLVM, the only way to set the alignment of
`vector.gather`, `vector.scatter`, and the `vector.load/store` ops was
to derive it from the element type of the memref. However, this is
insufficient for hardware backends requiring alignment of vector types.
This PR introduces the `use-vector-alignment` option to the
`ConvertVectorToLLVMPass`, which makes the pass use the preferred
alignment of each operation's vector type instead of the alignment of
the memref's element type.
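For example (mirroring the RUN lines of the new test below), the option can
be set on the mlir-opt command line; `input.mlir` is a placeholder file name:

  mlir-opt input.mlir --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file

With the flag enabled, a `vector.load` of `vector<8xf32>` lowers to an
`llvm.load` with `alignment = 32` (the preferred alignment of the vector
type) instead of `alignment = 4` (the alignment of the `f32` element type).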
---------
Co-authored-by: Lily Orth-Smith <lorthsmith at microsoft.com>
Added:
mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index ed88ada9778d2..f7b0112b32a33 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1411,6 +1411,13 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"true",
"Allows compiler to assume vector indices fit in 32-bit if that "
"yields faster code">,
+ Option<"useVectorAlignment", "use-vector-alignment",
+ "bool", /*default=*/"false",
+ "Use the preferred alignment of a vector type in load/store "
+ "operations instead of the alignment of the element type of the "
+ "memref. This flag is intended for use with hardware which requires"
+ "vector alignment, or in application contexts where it is known all "
+ "vector access are naturally aligned. ">,
Option<"amx", "enable-amx",
"bool", /*default=*/"false",
"Enables the use of AMX dialect while lowering the vector "
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 1e29bfeb9c392..f6b09deb4e44c 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -22,7 +22,8 @@ void populateVectorToLLVMMatrixConversionPatterns(
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
+ bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
+ bool useVectorAlignment = false);
namespace vector {
void registerConvertVectorToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 14cbbac99d9ae..299f198e4ab9c 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -32,7 +32,8 @@ def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
let arguments = (ins
DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
- DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
+ DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices,
+ DefaultValuedAttr<BoolAttr, "false">:$use_vector_alignment);
let assemblyFormat = "attr-dict";
}
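As a sketch of how the new attribute is expected to surface in the transform
dialect (the enclosing `transform.apply_conversion_patterns` boilerplate is
omitted, and the syntax is inferred from the tablegen definition above, whose
`attr-dict` assembly format carries the options):

  transform.apply_conversion_patterns.vector.vector_to_llvm {use_vector_alignment = true}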
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 076e5512f375b..5296013189b9e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -67,6 +67,21 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}
+// Helper that returns data layout alignment of a vector.
+LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
+ VectorType vectorType, unsigned &align) {
+ Type convertedVectorTy = typeConverter.convertType(vectorType);
+ if (!convertedVectorTy)
+ return failure();
+
+ llvm::LLVMContext llvmContext;
+ align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+ .getPreferredAlignment(convertedVectorTy,
+ typeConverter.getDataLayout());
+
+ return success();
+}
+
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
@@ -82,6 +97,28 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
return success();
}
+// Helper to resolve the alignment for vector load/store, gather and scatter
+// ops. If useVectorAlignment is true, get the preferred alignment for the
+// vector type in the operation. This option is used for hardware backends
+// that require vector alignment. Otherwise, use the preferred alignment of
+// the element type of the memref. Note that if you choose to use vector
+// alignment, the shape of the vector type must be resolved before the
+// ConvertVectorToLLVM pass is run.
+LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
+ VectorType vectorType,
+ MemRefType memrefType, unsigned &align,
+ bool useVectorAlignment) {
+ if (useVectorAlignment) {
+ if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
+ return failure();
+ }
+ } else {
+ if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
+ return failure();
+ }
+ }
+ return success();
+}
+
// Check if the last stride is non-unit and has a valid memory space.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
const LLVMTypeConverter &converter) {
@@ -224,6 +261,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
+ explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -240,8 +281,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
- return failure();
+ if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
+ memRefTy, align, useVectorAlignment)))
+ return rewriter.notifyMatchFailure(loadOrStoreOp,
+ "could not resolve alignment");
// Resolve address.
auto vtype = cast<VectorType>(
@@ -252,12 +295,23 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
rewriter);
return success();
}
+
+private:
+ // If true, use the preferred alignment of the vector type.
+ // If false, use the preferred alignment of the element type
+ // of the memref. This flag is intended for use with hardware
+ // backends that require alignment of vector operations.
+ const bool useVectorAlignment;
};
/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
: public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
+ explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -278,10 +332,9 @@ class VectorGatherOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
- return rewriter.notifyMatchFailure(gather,
- "could not resolve memref alignment");
- }
+ if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+ memRefType, align, useVectorAlignment)))
+ return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -297,12 +350,24 @@ class VectorGatherOpConversion
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
return success();
}
+
+private:
+ // If true, use the preferred alignment of the vector type.
+ // If false, use the preferred alignment of the element type
+ // of the memref. This flag is intended for use with hardware
+ // backends that require alignment of vector operations.
+ const bool useVectorAlignment;
};
/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
+ explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
+
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -322,10 +387,10 @@ class VectorScatterOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
+ if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+ memRefType, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(scatter,
- "could not resolve memref alignment");
- }
+ "could not resolve alignment");
// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -340,6 +405,13 @@ class VectorScatterOpConversion
rewriter.getI32IntegerAttr(align));
return success();
}
+
+private:
+ // If true, use the preferred alignment of the vector type.
+ // If false, use the preferred alignment of the element type
+ // of the memref. This flag is intended for use with hardware
+ // backends that require alignment of vector operations.
+ const bool useVectorAlignment;
};
/// Conversion pattern for a vector.expandload.
@@ -1928,21 +2000,23 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions, bool force32BitVectorIndices) {
+ bool reassociateFPReductions, bool force32BitVectorIndices,
+ bool useVectorAlignment) {
// This function populates only ConversionPatterns, not RewritePatterns.
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
+ patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
+ VectorLoadStoreConversion<vector::MaskedLoadOp>,
+ VectorLoadStoreConversion<vector::StoreOp>,
+ VectorLoadStoreConversion<vector::MaskedStoreOp>,
+ VectorGatherOpConversion, VectorScatterOpConversion>(
+ converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
- VectorLoadStoreConversion<vector::LoadOp>,
- VectorLoadStoreConversion<vector::MaskedLoadOp>,
- VectorLoadStoreConversion<vector::StoreOp>,
- VectorLoadStoreConversion<vector::MaskedStoreOp>,
- VectorGatherOpConversion, VectorScatterOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatOpLowering, VectorSplatNdOpLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
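A minimal caller-side sketch of the extended populate API; it assumes an
existing `LLVMTypeConverter converter` and `RewritePatternSet patterns`
already set up by the caller:

  // Request vector-type alignment rather than memref element-type alignment.
  mlir::populateVectorToLLVMConversionPatterns(
      converter, patterns,
      /*reassociateFPReductions=*/false,
      /*force32BitVectorIndices=*/false,
      /*useVectorAlignment=*/true);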
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..0ee6dce9ee94b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -92,7 +92,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
- converter, patterns, reassociateFPReductions, force32BitVectorIndices);
+ converter, patterns, reassociateFPReductions, force32BitVectorIndices,
+ useVectorAlignment);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
// Architecture specific augmentations.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 1be436dd7bf41..125c3d918284c 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -34,7 +34,8 @@ void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
populateVectorToLLVMConversionPatterns(
static_cast<LLVMTypeConverter &>(typeConverter), patterns,
- getReassociateFpReductions(), getForce_32bitVectorIndices());
+ getReassociateFpReductions(), getForce_32bitVectorIndices(),
+ getUseVectorAlignment());
}
LogicalResult
diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
new file mode 100644
index 0000000000000..3fa248656cf3a
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
+// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
+
+
+//===----------------------------------------------------------------------===//
+// vector.load
+//===----------------------------------------------------------------------===//
+
+func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+ %0 = vector.load %base[%i, %j] : memref<200x100xf32>, vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// ALL-LABEL: func @load
+
+// VEC-ALIGN: llvm.load %{{.*}} {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32>
+// MEMREF-ALIGN: llvm.load %{{.*}} {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.store
+//===----------------------------------------------------------------------===//
+
+func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {
+ %val = arith.constant dense<11.0> : vector<4xf32>
+ vector.store %val, %base[%i, %j] : memref<200x100xf32>, vector<4xf32>
+ return
+}
+
+// ALL-LABEL: func @store
+
+// VEC-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 16 : i64} : vector<4xf32>, !llvm.ptr
+// MEMREF-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 4 : i64} : vector<4xf32>, !llvm.ptr
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.maskedload
+//===----------------------------------------------------------------------===//
+
+func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = arith.constant 0: index
+ %0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
+// ALL-LABEL: func @masked_load
+
+// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.maskedstore
+//===----------------------------------------------------------------------===//
+
+func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+ %c0 = arith.constant 0: index
+ vector.maskedstore %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ return
+}
+
+// ALL-LABEL: func @masked_store
+
+// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) {
+ %0 = arith.constant 0: index
+ vector.scatter %base[%0][%index], %mask, %value : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+ return
+}
+
+// ALL-LABEL: func @scatter
+
+// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.gather
+//===----------------------------------------------------------------------===//
+
+func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> {
+ %0 = arith.constant 0: index
+ %1 = vector.gather %base[%0][%index], %mask, %passthru : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+ return %1 : vector<3xf32>
+}
+
+// ALL-LABEL: func @gather
+
+// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
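A note on the numbers in the checks above: with the default data layout (none
is specified in these tests), the preferred alignment of a vector type is its
natural, size-based alignment, so `vector<8xf32>` gets 8 x 4 bytes = 32,
`vector<4xf32>` gets 16, and `vector<16xf32>` gets 64, while the memref path
always uses the 4-byte alignment of the `f32` element type.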