[Mlir-commits] [mlir] [mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass (PR #137389)

Lily Orth-Smith llvmlistbot at llvm.org
Wed Apr 30 15:50:00 PDT 2025


https://github.com/electriclilies updated https://github.com/llvm/llvm-project/pull/137389

>From de7365648aabbd0834b4d50667938be211139f2f Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Fri, 25 Apr 2025 19:51:30 +0000
Subject: [PATCH 01/11] Fix how we get alignment for memrefs

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 076e5512f375b..d5b312d6929be 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -70,15 +70,26 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
 // Helper that returns data layout alignment of a memref.
 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                  MemRefType memrefType, unsigned &align) {
-  Type elementTy = typeConverter.convertType(memrefType.getElementType());
-  if (!elementTy)
+  // If shape is statically known, assign MemRefTypes to the alignment of a
+  // VectorType with the same size and dtype. Otherwise, fall back to the
+  // alignment of the element type.
+  Type convertedType;
+  if (memrefType.hasStaticShape()) {
+    convertedType = typeConverter.convertType(VectorType::get(
+        memrefType.getNumElements(), memrefType.getElementType()));
+  } else {
+    convertedType = typeConverter.convertType(memrefType.getElementType());
+  }
+
+  if (!convertedType)
     return failure();
 
   // TODO: this should use the MLIR data layout when it becomes available and
   // stop depending on translation.
   llvm::LLVMContext llvmContext;
-  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
-              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
+  align =
+      LLVM::TypeToLLVMIRTranslator(llvmContext)
+          .getPreferredAlignment(convertedType, typeConverter.getDataLayout());
   return success();
 }
 

>From efa44d5d6f6d3aeec92c5207bc5a2e1d27a26ac8 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Tue, 29 Apr 2025 22:51:13 +0000
Subject: [PATCH 02/11] Use flag to determine whether we use vector alignment
 or memref alignment

---
 mlir/include/mlir/Conversion/Passes.td        |   5 +
 .../VectorToLLVM/ConvertVectorToLLVM.h        |   3 +-
 .../Vector/TransformOps/VectorTransformOps.td |   3 +-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 126 +++++++++++++-----
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   3 +-
 .../TransformOps/VectorTransformOps.cpp       |   3 +-
 6 files changed, 103 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..3f64e719a3fcb 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1394,6 +1394,11 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"true",
            "Allows compiler to assume vector indices fit in 32-bit if that "
      "yields faster code">,
+    Option<"useVectorAlignment", "use-vector-alignment",
+           "bool", /*default=*/"false",
+           "Use the preferred alignment of a vector type in load/store "
+      "operations instead of the alignment of the element type of the "
+      "memref">,
     Option<"amx", "enable-amx",
            "bool", /*default=*/"false",
            "Enables the use of AMX dialect while lowering the vector "
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 1e29bfeb9c392..f6b09deb4e44c 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -22,7 +22,8 @@ void populateVectorToLLVMMatrixConversionPatterns(
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
+    bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
+    bool useVectorAlignment = false);
 
 namespace vector {
 void registerConvertVectorToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 14cbbac99d9ae..299f198e4ab9c 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -32,7 +32,8 @@ def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
 
   let arguments = (ins
       DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
-      DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
+      DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices,
+      DefaultValuedAttr<BoolAttr, "false">:$use_vector_alignment);
   let assemblyFormat = "attr-dict";
 }
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d5b312d6929be..7aa6a4257477a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -67,29 +67,33 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
   return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
 }
 
+// Helper that returns data layout alignment of a vector.
+LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
+                                 VectorType vectorType, unsigned &align) {
+  Type convertedVectorTy = typeConverter.convertType(vectorType);
+  if (!convertedVectorTy)
+    return failure();
+
+  llvm::LLVMContext llvmContext;
+  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+              .getPreferredAlignment(convertedVectorTy,
+                                     typeConverter.getDataLayout());
+
+  return success();
+}
+
 // Helper that returns data layout alignment of a memref.
 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                  MemRefType memrefType, unsigned &align) {
-  // If shape is statically known, assign MemRefTypes to the alignment of a
-  // VectorType with the same size and dtype. Otherwise, fall back to the
-  // alignment of the element type.
-  Type convertedType;
-  if (memrefType.hasStaticShape()) {
-    convertedType = typeConverter.convertType(VectorType::get(
-        memrefType.getNumElements(), memrefType.getElementType()));
-  } else {
-    convertedType = typeConverter.convertType(memrefType.getElementType());
-  }
-
-  if (!convertedType)
+  Type elementTy = typeConverter.convertType(memrefType.getElementType());
+  if (!elementTy)
     return failure();
 
   // TODO: this should use the MLIR data layout when it becomes available and
   // stop depending on translation.
   llvm::LLVMContext llvmContext;
-  align =
-      LLVM::TypeToLLVMIRTranslator(llvmContext)
-          .getPreferredAlignment(convertedType, typeConverter.getDataLayout());
+  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
   return success();
 }
 
@@ -235,6 +239,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
 template <class LoadOrStoreOp>
 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 public:
+  explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
+                                     bool useVectorAlign)
+      : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
@@ -251,8 +259,17 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
-      return failure();
+    if (useVectorAlignment) {
+      if (failed(
+              getVectorAlignment(*this->getTypeConverter(), vectorTy, align)))
+        return rewriter.notifyMatchFailure(
+            loadOrStoreOp, "could not resolve vector alignment");
+    } else {
+      if (failed(
+              getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
+        return rewriter.notifyMatchFailure(
+            loadOrStoreOp, "could not resolve memref alignment");
+    }
 
     // Resolve address.
     auto vtype = cast<VectorType>(
@@ -263,12 +280,19 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
                          rewriter);
     return success();
   }
+
+private:
+  const bool useVectorAlignment;
 };
 
 /// Conversion pattern for a vector.gather.
 class VectorGatherOpConversion
     : public ConvertOpToLLVMPattern<vector::GatherOp> {
 public:
+  explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
+                                    bool useVectorAlign)
+      : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
@@ -289,9 +313,15 @@ class VectorGatherOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
-      return rewriter.notifyMatchFailure(gather,
-                                         "could not resolve memref alignment");
+    if (useVectorAlignment) {
+      if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+        return rewriter.notifyMatchFailure(
+            gather, "could not resolve vector alignment");
+    } else {
+      if (failed(
+              getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+        return rewriter.notifyMatchFailure(
+            gather, "could not resolve memref alignment");
     }
 
     // Resolve address.
@@ -308,12 +338,20 @@ class VectorGatherOpConversion
         adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
     return success();
   }
+
+private:
+  const bool useVectorAlignment;
 };
 
 /// Conversion pattern for a vector.scatter.
 class VectorScatterOpConversion
     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
 public:
+  explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
+                                     bool useVectorAlign)
+      : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
+
   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
@@ -333,9 +371,15 @@ class VectorScatterOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
-      return rewriter.notifyMatchFailure(scatter,
-                                         "could not resolve memref alignment");
+    if (useVectorAlignment) {
+      if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+        return rewriter.notifyMatchFailure(
+            scatter, "could not resolve vector alignment");
+    } else {
+      if (failed(
+              getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+        return rewriter.notifyMatchFailure(
+            scatter, "could not resolve memref alignment");
     }
 
     // Resolve address.
@@ -351,6 +395,9 @@ class VectorScatterOpConversion
         rewriter.getI32IntegerAttr(align));
     return success();
   }
+
+private:
+  const bool useVectorAlignment;
 };
 
 /// Conversion pattern for a vector.expandload.
@@ -1939,7 +1986,8 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions, bool force32BitVectorIndices) {
+    bool reassociateFPReductions, bool force32BitVectorIndices,
+    bool useVectorAlignment) {
   // This function populates only ConversionPatterns, not RewritePatterns.
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
@@ -1948,18 +1996,24 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorExtractElementOpConversion, VectorExtractOpConversion,
                VectorFMAOp1DConversion, VectorInsertElementOpConversion,
                VectorInsertOpConversion, VectorPrintOpConversion,
-               VectorTypeCastOpConversion, VectorScaleOpConversion,
-               VectorLoadStoreConversion<vector::LoadOp>,
-               VectorLoadStoreConversion<vector::MaskedLoadOp>,
-               VectorLoadStoreConversion<vector::StoreOp>,
-               VectorLoadStoreConversion<vector::MaskedStoreOp>,
-               VectorGatherOpConversion, VectorScatterOpConversion,
-               VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
-               VectorSplatOpLowering, VectorSplatNdOpLowering,
-               VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
-               MaskedReductionOpConversion, VectorInterleaveOpLowering,
-               VectorDeinterleaveOpLowering, VectorFromElementsLowering,
-               VectorScalableStepOpLowering>(converter);
+               VectorTypeCastOpConversion, VectorScaleOpConversion>(ctx),
+      patterns.add<VectorLoadStoreConversion<vector::LoadOp>>(
+          ctx, useVectorAlignment),
+      patterns.add<VectorLoadStoreConversion<vector::MaskedLoadOp>>(
+          ctx, useVectorAlignment),
+      patterns.add<VectorLoadStoreConversion<vector::StoreOp>>(
+          ctx, useVectorAlignment),
+      patterns.add<VectorLoadStoreConversion<vector::MaskedStoreOp>>(
+          ctx, useVectorAlignment),
+      patterns.add<VectorGatherOpConversion>(ctx, useVectorAlignment),
+      patterns.add<VectorScatterOpConversion>(ctx, useVectorAlignment),
+      patterns.add<VectorExpandLoadOpConversion,
+                   VectorCompressStoreOpConversion, VectorSplatOpLowering,
+                   VectorSplatNdOpLowering, VectorScalableInsertOpLowering,
+                   VectorScalableExtractOpLowering, MaskedReductionOpConversion,
+                   VectorInterleaveOpLowering, VectorDeinterleaveOpLowering,
+                   VectorFromElementsLowering, VectorScalableStepOpLowering>(
+          converter);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..0ee6dce9ee94b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -92,7 +92,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
-      converter, patterns, reassociateFPReductions, force32BitVectorIndices);
+      converter, patterns, reassociateFPReductions, force32BitVectorIndices,
+      useVectorAlignment);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
 
   // Architecture specific augmentations.
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 1be436dd7bf41..125c3d918284c 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -34,7 +34,8 @@ void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns) {
   populateVectorToLLVMConversionPatterns(
       static_cast<LLVMTypeConverter &>(typeConverter), patterns,
-      getReassociateFpReductions(), getForce_32bitVectorIndices());
+      getReassociateFpReductions(), getForce_32bitVectorIndices(),
+      getUseVectorAlignment());
 }
 
 LogicalResult

>From c9cc43c128ae73f3e38011321e511b2d338911cd Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 00:15:26 +0000
Subject: [PATCH 03/11] Add test for use-vector-alignment flag

---
 .../VectorToLLVM/use-vector-alignment.mlir    | 144 ++++++++++++++++++
 1 file changed, 144 insertions(+)
 create mode 100644 mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir

diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
new file mode 100644
index 0000000000000..791fe81fcc652
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -0,0 +1,144 @@
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=1' | FileCheck %s --check-prefix=VEC-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=0' | FileCheck %s --check-prefix=MEMREF-ALIGN
+
+
+//===----------------------------------------------------------------------===//
+// vector.load
+//===----------------------------------------------------------------------===//
+
+func.func @load(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// VEC-ALIGN-LABEL: func @load
+// VEC-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// VEC-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]]  : i64
+// VEC-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}}  : i64
+// VEC-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// VEC-ALIGN: llvm.load %[[GEP]] {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32>
+
+// MEMREF-ALIGN-LABEL: func @load
+// MEMREF-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// MEMREF-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]]  : i64
+// MEMREF-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}}  : i64
+// MEMREF-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// MEMREF-ALIGN: llvm.load %[[GEP]] {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.store
+//===----------------------------------------------------------------------===//
+
+func.func @store(%memref : memref<200x100xf32>, %i : index, %j : index) {
+  %val = arith.constant dense<11.0> : vector<4xf32>
+  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
+  return
+}
+
+// VEC-ALIGN-LABEL: func @store
+// VEC-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// VEC-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]]  : i64
+// VEC-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}}  : i64
+// VEC-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// VEC-ALIGN: llvm.store %{{.*}}, %[[GEP]] {alignment = 32 : i64} :  vector<4xf32>, !llvm.ptr
+
+// MEMREF-ALIGN-LABEL: func @store
+// MEMREF-ALIGN: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// MEMREF-ALIGN: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]]  : i64
+// MEMREF-ALIGN: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}}  : i64
+// MEMREF-ALIGN: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// MEMREF-ALIGN: llvm.store %{{.*}}, %[[GEP]] {alignment = 4 : i64} :  vector<4xf32>, !llvm.ptr
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.maskedload
+//===----------------------------------------------------------------------===//
+
+func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0: index
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// VEC-ALIGN-LABEL: func @masked_load
+// VEC-ALIGN: %[[CO:.*]] = arith.constant 0 : index
+// VEC-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
+// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// VEC-ALIGN: return %[[L]] : vector<16xf32>
+
+// MEMREF-ALIGN-LABEL: func @masked_load
+// MEMREF-ALIGN: %[[CO:.*]] = arith.constant 0 : index
+// MEMREF-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
+// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// MEMREF-ALIGN: return %[[L]] : vector<16xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.maskedstore
+//===----------------------------------------------------------------------===//
+
+func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
+  %c0 = arith.constant 0: index
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+// VEC-ALIGN-LABEL: func @masked_store
+// VEC-ALIGN: %[[CO:.*]] = arith.constant 0 : index
+// VEC-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
+// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+
+// MEMREF-ALIGN-LABEL: func @masked_store
+// MEMREF-ALIGN: %[[CO:.*]] = arith.constant 0 : index
+// MEMREF-ALIGN: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
+// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+  %0 = arith.constant 0: index
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  return
+}
+
+// VEC-ALIGN-LABEL: func @scatter
+// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+// MEMREF-ALIGN-LABEL: func @scatter
+// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.gather
+//===----------------------------------------------------------------------===//
+
+func.func @gather(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// VEC-ALIGN-LABEL: func @gather
+// VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// VEC-ALIGN: return %[[G]] : vector<3xf32>
+
+// MEMREF-ALIGN-LABEL: func @gather
+// MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// MEMREF-ALIGN: return %[[G]] : vector<3xf32>

>From 7a709c24ec7205a7f3c850501dbca85d48275364 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 01:50:17 +0000
Subject: [PATCH 04/11] Try to fix timeout

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 22 ++++++++++++-------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 7aa6a4257477a..05b01acb00623 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -261,14 +261,16 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
     unsigned align;
     if (useVectorAlignment) {
       if (failed(
-              getVectorAlignment(*this->getTypeConverter(), vectorTy, align)))
+              getVectorAlignment(*this->getTypeConverter(), vectorTy, align))) {
         return rewriter.notifyMatchFailure(
             loadOrStoreOp, "could not resolve vector alignment");
+      }
     } else {
       if (failed(
-              getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
+              getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) {
         return rewriter.notifyMatchFailure(
             loadOrStoreOp, "could not resolve memref alignment");
+      }
     }
 
     // Resolve address.
@@ -314,14 +316,16 @@ class VectorGatherOpConversion
     // Resolve alignment.
     unsigned align;
     if (useVectorAlignment) {
-      if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+      if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align))) {
         return rewriter.notifyMatchFailure(
             gather, "could not resolve vector alignment");
+      }
     } else {
-      if (failed(
-              getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+      if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefType,
+                                    align))) {
         return rewriter.notifyMatchFailure(
             gather, "could not resolve memref alignment");
+      }
     }
 
     // Resolve address.
@@ -372,14 +376,16 @@ class VectorScatterOpConversion
     // Resolve alignment.
     unsigned align;
     if (useVectorAlignment) {
-      if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+      if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align))) {
         return rewriter.notifyMatchFailure(
             scatter, "could not resolve vector alignment");
+      }
     } else {
-      if (failed(
-              getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+      if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefType,
+                                    align))) {
         return rewriter.notifyMatchFailure(
             scatter, "could not resolve memref alignment");
+      }
     }
 
     // Resolve address.

>From e9593e4142abf6cb021965c24e57b707e061bc57 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 17:54:46 +0000
Subject: [PATCH 05/11] change failure

---
 mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 05b01acb00623..ecf6504c6e2da 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -262,14 +262,12 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
     if (useVectorAlignment) {
       if (failed(
               getVectorAlignment(*this->getTypeConverter(), vectorTy, align))) {
-        return rewriter.notifyMatchFailure(
-            loadOrStoreOp, "could not resolve vector alignment");
+        return failure();
       }
     } else {
       if (failed(
               getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) {
-        return rewriter.notifyMatchFailure(
-            loadOrStoreOp, "could not resolve memref alignment");
+        return failure();
       }
     }
 

>From 29598e8d052a6d9fe77912a3883da0575eaa6ef4 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 18:04:05 +0000
Subject: [PATCH 06/11] Fix how patterns are added to the pass

---
 mlir/include/mlir/Conversion/Passes.td        |   5 +-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 104 +++++++++---------
 .../VectorToLLVM/use-vector-alignment.mlir    |  28 ++---
 3 files changed, 68 insertions(+), 69 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3f64e719a3fcb..4ab2f300d0ccc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1397,8 +1397,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
     Option<"useVectorAlignment", "use-vector-alignment",
            "bool", /*default=*/"false",
            "Use the preferred alignment of a vector type in load/store "
-      "operations instead of the alignment of the element type of the "
-      "memref">,
+           "operations instead of the alignment of the element type of the "
+           "memref. This flag is intended for use with hardware which requires" 
+           "vector alignment.">,
     Option<"amx", "enable-amx",
            "bool", /*default=*/"false",
            "Enables the use of AMX dialect while lowering the vector "
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index ecf6504c6e2da..6ebac30690959 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -97,6 +97,21 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
   return success();
 }
 
+/// Resolves the alignment for a vector memory access. If `useVectorAlignment`
+/// is set, this is the preferred alignment of `vectorType`; otherwise it is
+/// the preferred alignment of the element type of `memrefType`. Fails if the
+/// type whose alignment is queried cannot be converted to an LLVM type.
+LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
+                                       VectorType vectorType,
+                                       MemRefType memrefType, unsigned &align,
+                                       bool useVectorAlignment) {
+  // Return the selected helper's result directly so that every path,
+  // including success, yields a LogicalResult.
+  if (useVectorAlignment)
+    return getVectorAlignment(typeConverter, vectorType, align);
+  return getMemRefAlignment(typeConverter, memrefType, align);
+}
+
 // Check if the last stride is non-unit and has a valid memory space.
 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                            const LLVMTypeConverter &converter) {
@@ -259,17 +274,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 
     // Resolve alignment.
     unsigned align;
-    if (useVectorAlignment) {
-      if (failed(
-              getVectorAlignment(*this->getTypeConverter(), vectorTy, align))) {
-        return failure();
-      }
-    } else {
-      if (failed(
-              getMemRefAlignment(*this->getTypeConverter(), memRefTy, align))) {
-        return failure();
-      }
-    }
+    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
+                                        memRefTy, align, useVectorAlignment)))
+      return rewriter.notifyMatchFailure(loadOrStoreOp,
+                                         "could not resolve alignment");
 
     // Resolve address.
     auto vtype = cast<VectorType>(
@@ -282,6 +290,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
   }
 
 private:
+  // If true, use the preferred alignment of the vector type.
+  // If false, use the preferred alignment of the element type
+  // of the memref. This flag is intended for use with hardware
+  // backends that require alignment of vector operations.
   const bool useVectorAlignment;
 };
 
@@ -313,18 +325,9 @@ class VectorGatherOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (useVectorAlignment) {
-      if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align))) {
-        return rewriter.notifyMatchFailure(
-            gather, "could not resolve vector alignment");
-      }
-    } else {
-      if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefType,
-                                    align))) {
-        return rewriter.notifyMatchFailure(
-            gather, "could not resolve memref alignment");
-      }
-    }
+    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+                                        memRefType, align, useVectorAlignment)))
+      return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
 
     // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -342,6 +345,10 @@ class VectorGatherOpConversion
   }
 
 private:
+  // If true, use the preferred alignment of the vector type.
+  // If false, use the preferred alignment of the element type
+  // of the memref. This flag is intended for use with hardware
+  // backends that require alignment of vector operations.
   const bool useVectorAlignment;
 };
 
@@ -373,18 +380,10 @@ class VectorScatterOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (useVectorAlignment) {
-      if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align))) {
-        return rewriter.notifyMatchFailure(
-            scatter, "could not resolve vector alignment");
-      }
-    } else {
-      if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefType,
-                                    align))) {
-        return rewriter.notifyMatchFailure(
-            scatter, "could not resolve memref alignment");
-      }
-    }
+    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+                                        memRefType, align, useVectorAlignment)))
+      return rewriter.notifyMatchFailure(scatter,
+                                         "could not resolve alignment");
 
     // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -401,6 +400,10 @@ class VectorScatterOpConversion
   }
 
 private:
+  // If true, use the preferred alignment of the vector type.
+  // If false, use the preferred alignment of the element type
+  // of the memref. This flag is intended for use with hardware
+  // backends that require alignment of vector operations.
   const bool useVectorAlignment;
 };
 
@@ -1996,28 +1999,23 @@ void mlir::populateVectorToLLVMConversionPatterns(
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
   patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
+  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
+               VectorLoadStoreConversion<vector::MaskedLoadOp>,
+               VectorLoadStoreConversion<vector::StoreOp>,
+               VectorLoadStoreConversion<vector::MaskedStoreOp>,
+               VectorGatherOpConversion, VectorScatterOpConversion>(
+      converter, useVectorAlignment);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
                VectorExtractElementOpConversion, VectorExtractOpConversion,
                VectorFMAOp1DConversion, VectorInsertElementOpConversion,
                VectorInsertOpConversion, VectorPrintOpConversion,
-               VectorTypeCastOpConversion, VectorScaleOpConversion>(ctx),
-      patterns.add<VectorLoadStoreConversion<vector::LoadOp>>(
-          ctx, useVectorAlignment),
-      patterns.add<VectorLoadStoreConversion<vector::MaskedLoadOp>>(
-          ctx, useVectorAlignment),
-      patterns.add<VectorLoadStoreConversion<vector::StoreOp>>(
-          ctx, useVectorAlignment),
-      patterns.add<VectorLoadStoreConversion<vector::MaskedStoreOp>>(
-          ctx, useVectorAlignment),
-      patterns.add<VectorGatherOpConversion>(ctx, useVectorAlignment),
-      patterns.add<VectorScatterOpConversion>(ctx, useVectorAlignment),
-      patterns.add<VectorExpandLoadOpConversion,
-                   VectorCompressStoreOpConversion, VectorSplatOpLowering,
-                   VectorSplatNdOpLowering, VectorScalableInsertOpLowering,
-                   VectorScalableExtractOpLowering, MaskedReductionOpConversion,
-                   VectorInterleaveOpLowering, VectorDeinterleaveOpLowering,
-                   VectorFromElementsLowering, VectorScalableStepOpLowering>(
-          converter);
+               VectorTypeCastOpConversion, VectorScaleOpConversion,
+               VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
+               VectorSplatOpLowering, VectorSplatNdOpLowering,
+               VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
+               MaskedReductionOpConversion, VectorInterleaveOpLowering,
+               VectorDeinterleaveOpLowering, VectorFromElementsLowering,
+               VectorScalableStepOpLowering>(converter);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
index 791fe81fcc652..96187d92afb2f 100644
--- a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -1,13 +1,13 @@
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=1' | FileCheck %s --check-prefix=VEC-ALIGN
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=0' | FileCheck %s --check-prefix=MEMREF-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=0' | FileCheck %s --check-prefix=MEMREF-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=1' | FileCheck %s --check-prefix=VEC-ALIGN
 
 
 //===----------------------------------------------------------------------===//
 // vector.load
 //===----------------------------------------------------------------------===//
 
-func.func @load(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
-  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
+func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+  %0 = vector.load %base[%i, %j] : memref<200x100xf32>, vector<8xf32>
   return %0 : vector<8xf32>
 }
 
@@ -31,9 +31,9 @@ func.func @load(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector
 // vector.store
 //===----------------------------------------------------------------------===//
 
-func.func @store(%memref : memref<200x100xf32>, %i : index, %j : index) {
+func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {
   %val = arith.constant dense<11.0> : vector<4xf32>
-  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32>
+  vector.store %val, %base[%i, %j] : memref<200x100xf32>, vector<4xf32>
   return
 }
 
@@ -57,9 +57,9 @@ func.func @store(%memref : memref<200x100xf32>, %i : index, %j : index) {
 // vector.maskedload
 //===----------------------------------------------------------------------===//
 
-func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
+func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0: index
-  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %0 : vector<16xf32>
 }
 
@@ -83,9 +83,9 @@ func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector
 // vector.maskedstore
 //===----------------------------------------------------------------------===//
 
-func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
+func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
   %c0 = arith.constant 0: index
-  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  vector.maskedstore %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32>
   return
 }
 
@@ -107,9 +107,9 @@ func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vecto
 // vector.scatter
 //===----------------------------------------------------------------------===//
 
-func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) {
   %0 = arith.constant 0: index
-  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  vector.scatter %base[%0][%index], %mask, %value : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
   return
 }
 
@@ -127,9 +127,9 @@ func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi
 // vector.gather
 //===----------------------------------------------------------------------===//
 
-func.func @gather(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> {
   %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  %1 = vector.gather %base[%0][%index], %mask, %passthru : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
   return %1 : vector<3xf32>
 }
 

>From b483329ab3cc6446c3fb35e5f3c001051fdad552 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 20:51:17 +0000
Subject: [PATCH 07/11] :)

---
 mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 91e5358622b69..4c2e68866f5bf 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -17,7 +17,7 @@
 // CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
 // CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi64>
 // CMP64: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi64>
-// CMP64: return %[[T4]] : vector<11xi1>
+// CMP64: return %[[T4]] : vector<11xi1> :)
 
 func.func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
   %0 = vector.create_mask %arg0 : vector<11xi1>

>From 199e5c71cc46d7835318255790c285386c800abf Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 20:51:35 +0000
Subject: [PATCH 08/11] Fix

---
 mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 4c2e68866f5bf..91e5358622b69 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -17,7 +17,7 @@
 // CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
 // CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi64>
 // CMP64: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi64>
-// CMP64: return %[[T4]] : vector<11xi1> :)
+// CMP64: return %[[T4]] : vector<11xi1>
 
 func.func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
   %0 = vector.create_mask %arg0 : vector<11xi1>

>From 9ef0047c294b1b11e9bab946f22995b8e5a60199 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 21:21:08 +0000
Subject: [PATCH 09/11] fix vector alignment helper

---
 mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp    | 1 +
 mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6ebac30690959..839a8761ba4f4 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -110,6 +110,7 @@ LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
       return failure();
     }
   }
+  return success();
 }
 
 // Check if the last stride is non-unit and has a valid memory space.
diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
index 96187d92afb2f..bb8448eab35e5 100644
--- a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=0' | FileCheck %s --check-prefix=MEMREF-ALIGN
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=1' | FileCheck %s --check-prefix=VEC-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
 
 
 //===----------------------------------------------------------------------===//

>From d5fb52c3197e8931e654440cad4a3fea72d82395 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 21:50:19 +0000
Subject: [PATCH 10/11] fix test

---
 mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
index bb8448eab35e5..37fa0fd447131 100644
--- a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v256:256' %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v128:256-v256:256-v512:256' %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
 
 
 //===----------------------------------------------------------------------===//
@@ -115,7 +115,7 @@ func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3x
 
 // VEC-ALIGN-LABEL: func @scatter
 // VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
-// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 32 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 16 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
 
 // MEMREF-ALIGN-LABEL: func @scatter
 // MEMREF-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
@@ -135,7 +135,7 @@ func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi
 
 // VEC-ALIGN-LABEL: func @gather
 // VEC-ALIGN: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
-// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 32 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // VEC-ALIGN: return %[[G]] : vector<3xf32>
 
 // MEMREF-ALIGN-LABEL: func @gather

>From 876725398b14cc6c9691a39396e9450e91428ce5 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Wed, 30 Apr 2025 22:49:41 +0000
Subject: [PATCH 11/11] test setting f32 alignment

---
 mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
index 37fa0fd447131..c1e449de50d5c 100644
--- a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:32' %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
+// RUN: mlir-opt -set-llvm-module-datalayout='data-layout=f32:64' %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefix=MEMREF-ALIGN
 // RUN: mlir-opt -set-llvm-module-datalayout='data-layout=v128:256-v256:256-v512:256' %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefix=VEC-ALIGN
 
 



More information about the Mlir-commits mailing list