[Mlir-commits] [mlir] [mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass (PR #137389)
Lily Orth-Smith
llvmlistbot at llvm.org
Wed Apr 30 15:50:00 PDT 2025
https://github.com/electriclilies updated https://github.com/llvm/llvm-project/pull/137389
From de7365648aabbd0834b4d50667938be211139f2f Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Fri, 25 Apr 2025 19:51:30 +0000
Subject: [PATCH 01/11] Fix how we get alignment for memrefs
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 19 +++++++++++++++----
1 file changed, 15 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 076e5512f375b..d5b312d6929be 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -70,15 +70,26 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
- Type elementTy = typeConverter.convertType(memrefType.getElementType());
- if (!elementTy)
+  // If the shape is statically known, use the alignment of a VectorType with
+  // the same total size and element type. Otherwise, fall back to the
+  // alignment of the element type alone.
+ Type convertedType;
+ if (memrefType.hasStaticShape()) {
+ convertedType = typeConverter.convertType(VectorType::get(
+ memrefType.getNumElements(), memrefType.getElementType()));
+ } else {
+ convertedType = typeConverter.convertType(memrefType.getElementType());
+ }
+
+ if (!convertedType)
return failure();
// TODO: this should use the MLIR data layout when it becomes available and
// stop depending on translation.
llvm::LLVMContext llvmContext;
- align = LLVM::TypeToLLVMIRTranslator(llvmContext)
- .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
+ align =
+ LLVM::TypeToLLVMIRTranslator(llvmContext)
+ .getPreferredAlignment(convertedType, typeConverter.getDataLayout());
return success();
}
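
For context: both alignment strategies this series toggles between reduce to the same data-layout query; only the type being queried differs. A minimal sketch of the common shape (an assumed helper for illustration, not part of the patch):

  static unsigned preferredAlign(const LLVMTypeConverter &conv, Type mlirTy) {
    // Convert the MLIR type to its LLVM dialect equivalent first.
    Type converted = conv.convertType(mlirTy);
    if (!converted)
      return 0; // treat as "could not resolve alignment"
    llvm::LLVMContext llvmContext;
    return LLVM::TypeToLLVMIRTranslator(llvmContext)
        .getPreferredAlignment(converted, conv.getDataLayout());
  }
  // preferredAlign(conv, memrefType.getElementType())  -> element alignment
  // preferredAlign(conv, VectorType::get(...))         -> vector alignment
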
From efa44d5d6f6d3aeec92c5207bc5a2e1d27a26ac8 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Tue, 29 Apr 2025 22:51:13 +0000
Subject: [PATCH 02/11] Use flag to determine whether we use vector alignment
or memref alignment
---
mlir/include/mlir/Conversion/Passes.td | 5 +
.../VectorToLLVM/ConvertVectorToLLVM.h | 3 +-
.../Vector/TransformOps/VectorTransformOps.td | 3 +-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 126 +++++++++++++-----
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 3 +-
.../TransformOps/VectorTransformOps.cpp | 3 +-
6 files changed, 103 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..3f64e719a3fcb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1394,6 +1394,11 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"true",
"Allows compiler to assume vector indices fit in 32-bit if that "
"yields faster code">,
+ Option<"useVectorAlignment", "use-vector-alignment",
+ "bool", /*default=*/"false",
+ "Use the preferred alignment of a vector type in load/store "
+ "operations instead of the alignment of the element type of the "
+ "memref">,
Option<"amx", "enable-amx",
"bool", /*default=*/"false",
"Enables the use of AMX dialect while lowering the vector "
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 1e29bfeb9c392..f6b09deb4e44c 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -22,7 +22,8 @@ void populateVectorToLLVMMatrixConversionPatterns(
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
+ bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
+ bool useVectorAlignment = false);
namespace vector {
void registerConvertVectorToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 14cbbac99d9ae..299f198e4ab9c 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -32,7 +32,8 @@ def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
let arguments = (ins
DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
- DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
+ DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices,
+ DefaultValuedAttr<BoolAttr, "false">:$use_vector_alignment);
let assemblyFormat = "attr-dict";
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d5b312d6929be..7aa6a4257477a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -67,29 +67,33 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
}
+// Helper that returns data layout alignment of a vector.
+LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
+ VectorType vectorType, unsigned &align) {
+ Type convertedVectorTy = typeConverter.convertType(vectorType);
+ if (!convertedVectorTy)
+ return failure();
+
+ llvm::LLVMContext llvmContext;
+ align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+ .getPreferredAlignment(convertedVectorTy,
+ typeConverter.getDataLayout());
+
+ return success();
+}
+
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
-  // If the shape is statically known, use the alignment of a VectorType with
-  // the same total size and element type. Otherwise, fall back to the
-  // alignment of the element type alone.
- Type convertedType;
- if (memrefType.hasStaticShape()) {
- convertedType = typeConverter.convertType(VectorType::get(
- memrefType.getNumElements(), memrefType.getElementType()));
- } else {
- convertedType = typeConverter.convertType(memrefType.getElementType());
- }
-
- if (!convertedType)
+ Type elementTy = typeConverter.convertType(memrefType.getElementType());
+ if (!elementTy)
return failure();
// TODO: this should use the MLIR data layout when it becomes available and
// stop depending on translation.
llvm::LLVMContext llvmContext;
- align =
- LLVM::TypeToLLVMIRTranslator(llvmContext)
- .getPreferredAlignment(convertedType, typeConverter.getDataLayout());
+ align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+ .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
return success();
}
@@ -235,6 +239,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
template <class LoadOrStoreOp>
class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
public:
+ explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -251,8 +259,17 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
- return failure();
+ if (useVectorAlignment) {
+ if (failed(
+ getVectorAlignment(*this->getTypeConverter(), vectorTy, align)))
+ return rewriter.notifyMatchFailure(
+ loadOrStoreOp, "could not resolve vector alignment");
+ } else {
+ if (failed(
+ getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
+ return rewriter.notifyMatchFailure(
+ loadOrStoreOp, "could not resolve memref alignment");
+ }
// Resolve address.
auto vtype = cast<VectorType>(
@@ -263,12 +280,19 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
rewriter);
return success();
}
+
+private:
+ const bool useVectorAlignment;
};
/// Conversion pattern for a vector.gather.
class VectorGatherOpConversion
: public ConvertOpToLLVMPattern<vector::GatherOp> {
public:
+ explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -289,9 +313,15 @@ class VectorGatherOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
- return rewriter.notifyMatchFailure(gather,
- "could not resolve memref alignment");
+ if (useVectorAlignment) {
+ if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+ return rewriter.notifyMatchFailure(
+ gather, "could not resolve vector alignment");
+ } else {
+ if (failed(
+ getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+ return rewriter.notifyMatchFailure(
+ gather, "could not resolve memref alignment");
}
// Resolve address.
@@ -308,12 +338,20 @@ class VectorGatherOpConversion
adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
return success();
}
+
+private:
+ const bool useVectorAlignment;
};
/// Conversion pattern for a vector.scatter.
class VectorScatterOpConversion
: public ConvertOpToLLVMPattern<vector::ScatterOp> {
public:
+ explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
+ bool useVectorAlign)
+ : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
+ useVectorAlignment(useVectorAlign) {}
+
using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
LogicalResult
@@ -333,9 +371,15 @@ class VectorScatterOpConversion
// Resolve alignment.
unsigned align;
- if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
- return rewriter.notifyMatchFailure(scatter,
- "could not resolve memref alignment");
+ if (useVectorAlignment) {
+ if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+ return rewriter.notifyMatchFailure(
+ scatter, "could not resolve vector alignment");
+ } else {
+ if (failed(
+ getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+ return rewriter.notifyMatchFailure(
+ scatter, "could not resolve memref alignment");
}
// Resolve address.
@@ -351,6 +395,9 @@ class VectorScatterOpConversion
rewriter.getI32IntegerAttr(align));
return success();
}
+
+private:
+ const bool useVectorAlignment;
};
/// Conversion pattern for a vector.expandload.
@@ -1939,7 +1986,8 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool reassociateFPReductions, bool force32BitVectorIndices) {
+ bool reassociateFPReductions, bool force32BitVectorIndices,
+ bool useVectorAlignment) {
// This function populates only ConversionPatterns, not RewritePatterns.
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
@@ -1948,18 +1996,24 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
- VectorTypeCastOpConversion, VectorScaleOpConversion,
- VectorLoadStoreConversion<vector::LoadOp>,
- VectorLoadStoreConversion<vector::MaskedLoadOp>,
- VectorLoadStoreConversion<vector::StoreOp>,
- VectorLoadStoreConversion<vector::MaskedStoreOp>,
- VectorGatherOpConversion, VectorScatterOpConversion,
- VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
- VectorSplatOpLowering, VectorSplatNdOpLowering,
- VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
- MaskedReductionOpConversion, VectorInterleaveOpLowering,
- VectorDeinterleaveOpLowering, VectorFromElementsLowering,
- VectorScalableStepOpLowering>(converter);
+ VectorTypeCastOpConversion, VectorScaleOpConversion>(ctx),
+ patterns.add<VectorLoadStoreConversion<vector::LoadOp>>(
+ ctx, useVectorAlignment),
+ patterns.add<VectorLoadStoreConversion<vector::MaskedLoadOp>>(
+ ctx, useVectorAlignment),
+ patterns.add<VectorLoadStoreConversion<vector::StoreOp>>(
+ ctx, useVectorAlignment),
+ patterns.add<VectorLoadStoreConversion<vector::MaskedStoreOp>>(
+ ctx, useVectorAlignment),
+ patterns.add<VectorGatherOpConversion>(ctx, useVectorAlignment),
+ patterns.add<VectorScatterOpConversion>(ctx, useVectorAlignment),
+ patterns.add<VectorExpandLoadOpConversion,
+ VectorCompressStoreOpConversion, VectorSplatOpLowering,
+ VectorSplatNdOpLowering, VectorScalableInsertOpLowering,
+ VectorScalableExtractOpLowering, MaskedReductionOpConversion,
+ VectorInterleaveOpLowering, VectorDeinterleaveOpLowering,
+ VectorFromElementsLowering, VectorScalableStepOpLowering>(
+ converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..0ee6dce9ee94b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -92,7 +92,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
- converter, patterns, reassociateFPReductions, force32BitVectorIndices);
+ converter, patterns, reassociateFPReductions, force32BitVectorIndices,
+ useVectorAlignment);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
// Architecture specific augmentations.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 1be436dd7bf41..125c3d918284c 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -34,7 +34,8 @@ void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
populateVectorToLLVMConversionPatterns(
static_cast<LLVMTypeConverter &>(typeConverter), patterns,
- getReassociateFpReductions(), getForce_32bitVectorIndices());
+ getReassociateFpReductions(), getForce_32bitVectorIndices(),
+ getUseVectorAlignment());
}
LogicalResult
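
As a usage sketch (hypothetical caller, not part of the patch; only the parameter names come from the signature above), a downstream pipeline that wants vector-preferred alignment threads the new flag through the populate entry point:

  // `converter` is an LLVMTypeConverter, `ctx` its MLIRContext.
  RewritePatternSet patterns(ctx);
  populateVectorToLLVMConversionPatterns(
      converter, patterns,
      /*reassociateFPReductions=*/false,
      /*force32BitVectorIndices=*/false,
      /*useVectorAlignment=*/true);
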
From c9cc43c128ae73f3e38011321e511b2d338911cd Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 00:15:26 +0000
Subject: [PATCH 03/11] Add test for use-vector-alignment flag
---
.../VectorToLLVM/use-vector-alignment.mlir | 144 ++++++++++++++++++
1 file changed, 144 insertions(+)
create mode 100644 mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
new file mode 100644
index 0000000000000..791fe81fcc652
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -0,0 +1,144 @@
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=1' | FileCheck %s --check-prefix=VEC-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=0' | FileCheck %s --check-prefix=MEMREF-ALIGN
+
+
+//===----------------------------------------------------------------------===//
+// vector.load
+//===----------------------------------------------------------------------===//
+
+func.func @load(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+ %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// VEC-ALIGN-LABEL: func @load
+// VEC-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// VEC-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]] : i64
+// VEC-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
+// VEC-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// VEC-ALIGN: llvm.load %[[GEP]] {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32>
+
+// MEMREF-ALIGN-LABEL: func @load
+// MEMREF-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// MEMREF-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]] : i64
+// MEMREF-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
+// MEMREF-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// MEMREF-ALIGN: llvm.load %[[GEP]] {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.store
+//===----------------------------------------------------------------------===//
+
+func.func @store(%memref : memref<200x100xf32>, %i : index, %j : index) {
+ %val = arith.constant dense<11.0> : vector<4xf32>
+ vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
+ return
+}
+
+// VEC-ALIGN-LABEL: func @store
+// VEC-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// VEC-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]] : i64
+// VEC-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
+// VEC-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// VEC-ALIGN: llvm.store %{{.*}}, %[[GEP]] {alignment = 32 : i64} : vector<4xf32>, !llvm.ptr
+
+// MEMREF-ALIGN-LABEL: func @store
+// MEMREF-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// MEMREF-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]] : i64
+// MEMREF-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
+// MEMREF-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// MEMREF-ALIGN: llvm.store %{{.*}}, %[[GEP]] {alignment = 4 : i64} : vector<4xf32>, !llvm.ptr
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.maskedload
+//===----------------------------------------------------------------------===//
+
+func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
+ %c0 = arith.constant 0: index
+ %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %0 : vector<16xf32>
+}
+
+// VEC-ALIGN-LABEL: func @masked_load
+// VEC-ALIGN: %[[CO:.*]] = arith.constant 0 : index
+// VEC-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
+// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// VEC-ALIGN: return %[[L]] : vector<16xf32>
+
+// MEMREF-ALIGN-LABEL: func @masked_load
+// MEMREF-ALIGN: %[[CO:.*]] = arith.constant 0 : index
+// MEMREF-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
+// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// MEMREF-ALIGN: return %[[L]] : vector<16xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.maskedstore
+//===----------------------------------------------------------------------===//
+
+func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
+ %c0 = arith.constant 0: index
+ vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ return
+}
+
+// VEC-ALIGN-LABEL: func @masked_store
+// VEC-ALIGN: %[[CO:.*]] = arith.constant 0 : index
+// VEC-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
+// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+
+// MEMREF-ALIGN-LABEL: func @masked_store
+// MEMREF-ALIGN: %[[CO:.*]] = arith.constant 0 : index
+// MEMREF-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
+// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+ %0 = arith.constant 0: index
+ vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+ return
+}
+
+// VEC-ALIGN-LABEL: func @scatter
+// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+// MEMREF-ALIGN-LABEL: func @scatter
+// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.gather
+//===----------------------------------------------------------------------===//
+
+func.func @gather(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+ %0 = arith.constant 0: index
+ %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+ return %1 : vector<3xf32>
+}
+
+// VEC-ALIGN-LABEL: func @gather
+// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// VEC-ALIGN: return %[[G]] : vector<3xf32>
+
+// MEMREF-ALIGN-LABEL: func @gather
+// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// MEMREF-ALIGN: return %[[G]] : vector<3xf32>
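
For the @load case, the arithmetic behind the expected alignments (a hedged reading of the data layouts; note that later patches in the series re-pair the RUN-line layouts with the prefixes that actually depend on them):

  constexpr unsigned vecBits = 8 * 32;    // vector<8xf32> = 256 bits
  constexpr unsigned vecAlign = 256 / 8;  // 256-bit vector alignment -> 32 bytes
  constexpr unsigned eltAlign = 32 / 8;   // f32 alignment            -> 4 bytes
  static_assert(vecAlign == 32 && eltAlign == 4, "matches the CHECK lines");
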
From 7a709c24ec7205a7f3c850501dbca85d48275364 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 01:50:17 +0000
Subject: [PATCH 04/11] Try to fix timeout
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 22 ++++++++++++-------
1 file changed, 14 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 7aa6a4257477a..05b01acb00623 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -261,14 +261,16 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
unsigned align;
if (useVectorAlignment) {
if (failed(
- getVectorAlignment(*this->getTypeConverter(), vectorTy, align)))
+ getVectorAlignment(*this->getTypeConverter(), vectorTy, align))) {
return rewriter.notifyMatchFailure(
loadOrStoreOp, "could not resolve vector alignment");
+ }
} else {
if (failed(
- getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
+ getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) {
return rewriter.notifyMatchFailure(
loadOrStoreOp, "could not resolve memref alignment");
+ }
}
// Resolve address.
@@ -314,14 +316,16 @@ class VectorGatherOpConversion
// Resolve alignment.
unsigned align;
if (useVectorAlignment) {
- if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+ if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align))) {
return rewriter.notifyMatchFailure(
gather, "could not resolve vector alignment");
+ }
} else {
- if (failed(
- getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+ if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefType,
+ align))) {
return rewriter.notifyMatchFailure(
gather, "could not resolve memref alignment");
+ }
}
// Resolve address.
@@ -372,14 +376,16 @@ class VectorScatterOpConversion
// Resolve alignment.
unsigned align;
if (useVectorAlignment) {
- if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+ if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align))) {
return rewriter.notifyMatchFailure(
scatter, "could not resolve vector alignment");
+ }
} else {
- if (failed(
- getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+ if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefType,
+ align))) {
return rewriter.notifyMatchFailure(
scatter, "could not resolve memref alignment");
+ }
}
// Resolve address.
From e9593e4142abf6cb021965c24e57b707e061bc57 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 17:54:46 +0000
Subject: [PATCH 05/11] change failure
---
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 05b01acb00623..ecf6504c6e2da 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -262,14 +262,12 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
if (useVectorAlignment) {
if (failed(
getVectorAlignment(*this->getTypeConverter(), vectorTy, align))) {
- return rewriter.notifyMatchFailure(
- loadOrStoreOp, "could not resolve vector alignment");
+ return failure();
}
} else {
if (failed(
getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) {
- return rewriter.notifyMatchFailure(
- loadOrStoreOp, "could not resolve memref alignment");
+ return failure();
}
}
From 29598e8d052a6d9fe77912a3883da0575eaa6ef4 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 18:04:05 +0000
Subject: [PATCH 06/11] Fix how patterns are added to the pass
---
mlir/include/mlir/Conversion/Passes.td | 5 +-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 104 +++++++++---------
.../VectorToLLVM/use-vector-alignment.mlir | 28 ++---
3 files changed, 68 insertions(+), 69 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3f64e719a3fcb..4ab2f300d0ccc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1397,8 +1397,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
Option<"useVectorAlignment", "use-vector-alignment",
"bool", /*default=*/"false",
"Use the preferred alignment of a vector type in load/store "
- "operations instead of the alignment of the element type of the "
- "memref">,
+ "operations instead of the alignment of the element type of the "
+           "memref. This flag is intended for use with hardware which requires "
+ "vector alignment.">,
Option<"amx", "enable-amx",
"bool", /*default=*/"false",
"Enables the use of AMX dialect while lowering the vector "
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index ecf6504c6e2da..6ebac30690959 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -97,6 +97,21 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
return success();
}
+LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
+ VectorType vectorType,
+ MemRefType memrefType, unsigned &align,
+ bool useVectorAlignment) {
+ if (useVectorAlignment) {
+ if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
+ return failure();
+ }
+ } else {
+ if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
+ return failure();
+ }
+ }
+}
+
// Check if the last stride is non-unit and has a valid memory space.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
const LLVMTypeConverter &converter) {
@@ -259,17 +274,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
// Resolve alignment.
unsigned align;
- if (useVectorAlignment) {
- if (failed(
- getVectorAlignment(*this->getTypeConverter(), vectorTy, align))) {
- return failure();
- }
- } else {
- if (failed(
- getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) {
- return failure();
- }
- }
+ if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
+ memRefTy, align, useVectorAlignment)))
+ return rewriter.notifyMatchFailure(loadOrStoreOp,
+ "could not resolve alignment");
// Resolve address.
auto vtype = cast<VectorType>(
@@ -282,6 +290,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
}
private:
+ // If true, use the preferred alignment of the vector type.
+ // If false, use the preferred alignment of the element type
+ // of the memref. This flag is intended for use with hardware
+ // backends that require alignment of vector operations.
const bool useVectorAlignment;
};
@@ -313,18 +325,9 @@ class VectorGatherOpConversion
// Resolve alignment.
unsigned align;
- if (useVectorAlignment) {
- if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align))) {
- return rewriter.notifyMatchFailure(
- gather, "could not resolve vector alignment");
- }
- } else {
- if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefType,
- align))) {
- return rewriter.notifyMatchFailure(
- gather, "could not resolve memref alignment");
- }
- }
+ if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+ memRefType, align, useVectorAlignment)))
+ return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -342,6 +345,10 @@ class VectorGatherOpConversion
}
private:
+ // If true, use the preferred alignment of the vector type.
+ // If false, use the preferred alignment of the element type
+ // of the memref. This flag is intended for use with hardware
+ // backends that require alignment of vector operations.
const bool useVectorAlignment;
};
@@ -373,18 +380,10 @@ class VectorScatterOpConversion
// Resolve alignment.
unsigned align;
- if (useVectorAlignment) {
- if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align))) {
- return rewriter.notifyMatchFailure(
- scatter, "could not resolve vector alignment");
- }
- } else {
- if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefType,
- align))) {
- return rewriter.notifyMatchFailure(
- scatter, "could not resolve memref alignment");
- }
- }
+ if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+ memRefType, align, useVectorAlignment)))
+ return rewriter.notifyMatchFailure(scatter,
+ "could not resolve alignment");
// Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -401,6 +400,10 @@ class VectorScatterOpConversion
}
private:
+ // If true, use the preferred alignment of the vector type.
+ // If false, use the preferred alignment of the element type
+ // of the memref. This flag is intended for use with hardware
+ // backends that require alignment of vector operations.
const bool useVectorAlignment;
};
@@ -1996,28 +1999,23 @@ void mlir::populateVectorToLLVMConversionPatterns(
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
+ patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
+ VectorLoadStoreConversion<vector::MaskedLoadOp>,
+ VectorLoadStoreConversion<vector::StoreOp>,
+ VectorLoadStoreConversion<vector::MaskedStoreOp>,
+ VectorGatherOpConversion, VectorScatterOpConversion>(
+ converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
- VectorTypeCastOpConversion, VectorScaleOpConversion>(ctx),
- patterns.add<VectorLoadStoreConversion<vector::LoadOp>>(
- ctx, useVectorAlignment),
- patterns.add<VectorLoadStoreConversion<vector::MaskedLoadOp>>(
- ctx, useVectorAlignment),
- patterns.add<VectorLoadStoreConversion<vector::StoreOp>>(
- ctx, useVectorAlignment),
- patterns.add<VectorLoadStoreConversion<vector::MaskedStoreOp>>(
- ctx, useVectorAlignment),
- patterns.add<VectorGatherOpConversion>(ctx, useVectorAlignment),
- patterns.add<VectorScatterOpConversion>(ctx, useVectorAlignment),
- patterns.add<VectorExpandLoadOpConversion,
- VectorCompressStoreOpConversion, VectorSplatOpLowering,
- VectorSplatNdOpLowering, VectorScalableInsertOpLowering,
- VectorScalableExtractOpLowering, MaskedReductionOpConversion,
- VectorInterleaveOpLowering, VectorDeinterleaveOpLowering,
- VectorFromElementsLowering, VectorScalableStepOpLowering>(
- converter);
+ VectorTypeCastOpConversion, VectorScaleOpConversion,
+ VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+ VectorSplatOpLowering, VectorSplatNdOpLowering,
+ VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
+ MaskedReductionOpConversion, VectorInterleaveOpLowering,
+ VectorDeinterleaveOpLowering, VectorFromElementsLowering,
+ VectorScalableStepOpLowering>(converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
index 791fe81fcc652..96187d92afb2f 100644
--- a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -1,13 +1,13 @@
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=1' | FileCheck %s --check-prefix=VEC-ALIGN
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=0' | FileCheck %s --check-prefix=MEMREF-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=0' | FileCheck %s --check-prefix=MEMREF-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=1' | FileCheck %s --check-prefix=VEC-ALIGN
//===----------------------------------------------------------------------===//
// vector.load
//===----------------------------------------------------------------------===//
-func.func @load(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
- %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+ %0 = vector.load %base[%i, %j] : memref<200x100xf32>, vector<8xf32>
return %0 : vector<8xf32>
}
@@ -31,9 +31,9 @@ func.func @load(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector
// vector.store
//===----------------------------------------------------------------------===//
-func.func @store(%memref : memref<200x100xf32>, %i : index, %j : index) {
+func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {
%val = arith.constant dense<11.0> : vector<4xf32>
- vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
+ vector.store %val, %base[%i, %j] : memref<200x100xf32>, vector<4xf32>
return
}
@@ -57,9 +57,9 @@ func.func @store(%memref : memref<200x100xf32>, %i : index, %j : index) {
// vector.maskedload
//===----------------------------------------------------------------------===//
-func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
+func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0: index
- %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %0 : vector<16xf32>
}
@@ -83,9 +83,9 @@ func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector
// vector.maskedstore
//===----------------------------------------------------------------------===//
-func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
+func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
%c0 = arith.constant 0: index
- vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+ vector.maskedstore %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}
@@ -107,9 +107,9 @@ func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vecto
// vector.scatter
//===----------------------------------------------------------------------===//
-func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) {
%0 = arith.constant 0: index
- vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+ vector.scatter %base[%0][%index], %mask, %value : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
return
}
@@ -127,9 +127,9 @@ func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi
// vector.gather
//===----------------------------------------------------------------------===//
-func.func @gather(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> {
%0 = arith.constant 0: index
- %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+ %1 = vector.gather %base[%0][%index], %mask, %passthru : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
return %1 : vector<3xf32>
}
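
One wrinkle in this patch: the new getVectorToLLVMAlignment helper has no return on its success path, so control can flow off the end of a non-void function; PATCH 09 below adds the missing `return success()`. An equivalent, more compact formulation (a sketch, not the committed code) sidesteps the problem by forwarding directly:

  static LogicalResult
  getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
                           VectorType vectorType, MemRefType memrefType,
                           unsigned &align, bool useVectorAlignment) {
    // Every path returns the LogicalResult of the underlying query.
    return useVectorAlignment
               ? getVectorAlignment(typeConverter, vectorType, align)
               : getMemRefAlignment(typeConverter, memrefType, align);
  }
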
From b483329ab3cc6446c3fb35e5f3c001051fdad552 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 20:51:17 +0000
Subject: [PATCH 07/11] :)
---
mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 91e5358622b69..4c2e68866f5bf 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -17,7 +17,7 @@
// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi64>
// CMP64: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi64>
-// CMP64: return %[[T4]] : vector<11xi1>
+// CMP64: return %[[T4]] : vector<11xi1> :)
func.func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
%0 = vector.create_mask %arg0 : vector<11xi1>
From 199e5c71cc46d7835318255790c285386c800abf Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 20:51:35 +0000
Subject: [PATCH 08/11] Fix
---
mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 4c2e68866f5bf..91e5358622b69 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -17,7 +17,7 @@
// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi64>
// CMP64: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi64>
-// CMP64: return %[[T4]] : vector<11xi1> :)
+// CMP64: return %[[T4]] : vector<11xi1>
func.func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
%0 = vector.create_mask %arg0 : vector<11xi1>
From 9ef0047c294b1b11e9bab946f22995b8e5a60199 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 21:21:08 +0000
Subject: [PATCH 09/11] fix vector alignment helper
---
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 1 +
mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir | 4 ++--
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6ebac30690959..839a8761ba4f4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -110,6 +110,7 @@ LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
return failure();
}
}
+ return success();
}
// Check if the last stride is non-unit and has a valid memory space.
diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
index 96187d92afb2f..bb8448eab35e5 100644
--- a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=0' | FileCheck %s --check-prefix=MEMREF-ALIGN
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=1' | FileCheck %s --check-prefix=VEC-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
//===----------------------------------------------------------------------===//
From d5fb52c3197e8931e654440cad4a3fea72d82395 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 21:50:19 +0000
Subject: [PATCH 10/11] fix test
---
mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
index bb8448eab35e5..37fa0fd447131 100644
--- a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v128:256-v256:256-v512:256' %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
//===----------------------------------------------------------------------===//
@@ -115,7 +115,7 @@ func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3x
// VEC-ALIGN-LABEL: func @scatter
// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
-// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 16 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
// MEMREF-ALIGN-LABEL: func @scatter
// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
@@ -135,7 +135,7 @@ func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi
// VEC-ALIGN-LABEL: func @gather
// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
-// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// VEC-ALIGN: return %[[G]] : vector<3xf32>
// MEMREF-ALIGN-LABEL: func @gather
From 876725398b14cc6c9691a39396e9450e91428ce5 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 22:49:41 +0000
Subject: [PATCH 11/11] test setting f32 alignment
---
mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
index 37fa0fd447131..c1e449de50d5c 100644
--- a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:64' %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v128:256-v256:256-v512:256' %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
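
After the full series, the test's RUN header reads as follows (reconstructed from the hunks above):

  // RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:64' %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
  // RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v128:256-v256:256-v512:256' %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN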