[Mlir-commits] [mlir] [mlir] Fix MemRefType alignment in ConvertVectorToLLVM (PR #137389)

Lily Orth-Smith llvmlistbot at llvm.org
Tue Apr 29 16:02:50 PDT 2025


https://github.com/electriclilies updated https://github.com/llvm/llvm-project/pull/137389

>From de7365648aabbd0834b4d50667938be211139f2f Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Fri, 25 Apr 2025 19:51:30 +0000
Subject: [PATCH 1/2] Fix how we get alignment for memrefs

For memrefs with a static shape, getMemRefAlignment now returns the
preferred alignment of a 1-D vector type with the memref's total number of
elements and its element type. Dynamically shaped memrefs keep using the
preferred alignment of the element type. (PATCH 2/2 reverts this heuristic
again in favor of an opt-in pass option.)
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 076e5512f375b..d5b312d6929be 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -70,15 +70,26 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
 // Helper that returns data layout alignment of a memref.
 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                  MemRefType memrefType, unsigned &align) {
-  Type elementTy = typeConverter.convertType(memrefType.getElementType());
-  if (!elementTy)
+  // If shape is statically known, assign MemRefTypes to the alignment of a
+  // VectorType with the same size and dtype. Otherwise, fall back to the
+  // alignment of the element type.
+  Type convertedType;
+  if (memrefType.hasStaticShape()) {
+    convertedType = typeConverter.convertType(VectorType::get(
+        memrefType.getNumElements(), memrefType.getElementType()));
+  } else {
+    convertedType = typeConverter.convertType(memrefType.getElementType());
+  }
+
+  if (!convertedType)
     return failure();
 
   // TODO: this should use the MLIR data layout when it becomes available and
   // stop depending on translation.
   llvm::LLVMContext llvmContext;
-  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
-              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
+  align =
+      LLVM::TypeToLLVMIRTranslator(llvmContext)
+          .getPreferredAlignment(convertedType, typeConverter.getDataLayout());
   return success();
 }
 

>From cd537f36bdcb4bed031cf2e2423e6ea78803a179 Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lorthsmith at microsoft.com>
Date: Tue, 29 Apr 2025 22:51:13 +0000
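Subject: [PATCH 2/2] Use flag to determine whether we use vector alignment or
 memref alignment

Add a `use-vector-alignment` option (default false) to
convert-vector-to-llvm. When it is set, the vector load/store, gather and
scatter lowerings take their alignment from the preferred alignment of the
accessed vector type; otherwise they keep using the preferred alignment of
the memref element type. This also reverts the static-shape heuristic
introduced in PATCH 1/2.

Note: populateVectorToLLVMConversionPatterns gains a trailing
`useVectorAlignment` parameter, but the header that declares it is not part
of this diff. Its declaration must be updated to match; defaulting the new
parameter to false keeps existing callers source-compatible.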
---
 mlir/include/mlir/Conversion/Passes.td        |   5 +
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 128 +++++++++++++-----
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |   3 +-
 3 files changed, 99 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..451928b92cfa7 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1390,6 +1390,11 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
     Option<"reassociateFPReductions", "reassociate-fp-reductions",
            "bool", /*default=*/"false",
            "Allows llvm to reassociate floating-point reductions for speed">,
+    Option<"useVectorAlignment", "use-vector-alignment",
+           "bool", /*default=*/"false",
+           "Use the preferred alignment of a vector type in load/store "
+           "operations instead of the alignment of the element type of the "
+           "memref">,
     Option<"force32BitVectorIndices", "force-32bit-vector-indices",
            "bool", /*default=*/"true",
            "Allows compiler to assume vector indices fit in 32-bit if that "
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d5b312d6929be..86590c2bc3666 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -67,29 +67,38 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
   return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
 }
 
+// Helper that returns the data layout alignment of a vector type: the
+// preferred alignment of its LLVM-converted counterpart under the converter's
+// data layout. Fails if `typeConverter` cannot convert `vectorType`.
+static LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
+                                        VectorType vectorType,
+                                        unsigned &align) {
+  Type convertedVectorTy = typeConverter.convertType(vectorType);
+  if (!convertedVectorTy)
+    return failure();
+
+  // TODO: this should use the MLIR data layout when it becomes available and
+  // stop depending on translation.
+  llvm::LLVMContext llvmContext;
+  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+              .getPreferredAlignment(convertedVectorTy,
+                                     typeConverter.getDataLayout());
+
+  return success();
+}
+
 // Helper that returns data layout alignment of a memref.
 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                  MemRefType memrefType, unsigned &align) {
-  // If shape is statically known, assign MemRefTypes to the alignment of a
-  // VectorType with the same size and dtype. Otherwise, fall back to the
-  // alignment of the element type.
-  Type convertedType;
-  if (memrefType.hasStaticShape()) {
-    convertedType = typeConverter.convertType(VectorType::get(
-        memrefType.getNumElements(), memrefType.getElementType()));
-  } else {
-    convertedType = typeConverter.convertType(memrefType.getElementType());
-  }
-
-  if (!convertedType)
+  Type elementTy = typeConverter.convertType(memrefType.getElementType());
+  if (!elementTy)
     return failure();
 
   // TODO: this should use the MLIR data layout when it becomes available and
   // stop depending on translation.
   llvm::LLVMContext llvmContext;
-  align =
-      LLVM::TypeToLLVMIRTranslator(llvmContext)
-          .getPreferredAlignment(convertedType, typeConverter.getDataLayout());
+  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
   return success();
 }
 
@@ -235,6 +239,11 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
 template <class LoadOrStoreOp>
 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 public:
-  using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
+  /// `useVectorAlign` selects whether the alignment is taken from the vector
+  /// type of the op (true) or from the memref element type (false).
+  explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
+                                     bool useVectorAlign)
+      : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
 
   LogicalResult
@@ -251,8 +260,17 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
-      return failure();
+    if (useVectorAlignment) {
+      if (failed(
+              getVectorAlignment(*this->getTypeConverter(), vectorTy, align)))
+        return rewriter.notifyMatchFailure(
+            loadOrStoreOp, "could not resolve vector alignment");
+    } else {
+      if (failed(
+              getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
+        return rewriter.notifyMatchFailure(
+            loadOrStoreOp, "could not resolve memref alignment");
+    }
 
     // Resolve address.
     auto vtype = cast<VectorType>(
@@ -263,12 +281,22 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
                          rewriter);
     return success();
   }
+
+private:
+  /// Whether to use the alignment of the vector type instead of the alignment
+  /// of the memref element type.
+  const bool useVectorAlignment;
 };
 
 /// Conversion pattern for a vector.gather.
 class VectorGatherOpConversion
     : public ConvertOpToLLVMPattern<vector::GatherOp> {
 public:
-  using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
+  /// `useVectorAlign` selects whether the alignment is taken from the vector
+  /// type of the op (true) or from the memref element type (false).
+  explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
+                                    bool useVectorAlign)
+      : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
 
   LogicalResult
@@ -289,9 +317,18 @@ class VectorGatherOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
-      return rewriter.notifyMatchFailure(gather,
-                                         "could not resolve memref alignment");
+    if (useVectorAlignment) {
+      // NOTE(review): llvm.masked.gather applies this alignment to every
+      // element address, so vector alignment is only valid if all gathered
+      // addresses are vector-aligned; confirm that is the intended contract.
+      if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+        return rewriter.notifyMatchFailure(
+            gather, "could not resolve vector alignment");
+    } else {
+      if (failed(
+              getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+        return rewriter.notifyMatchFailure(
+            gather, "could not resolve memref alignment");
     }
 
     // Resolve address.
@@ -308,12 +345,22 @@ class VectorGatherOpConversion
         adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
     return success();
   }
+
+private:
+  /// Whether to use the alignment of the vector type instead of the alignment
+  /// of the memref element type.
+  const bool useVectorAlignment;
 };
 
 /// Conversion pattern for a vector.scatter.
 class VectorScatterOpConversion
     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
 public:
-  using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
+  /// `useVectorAlign` selects whether the alignment is taken from the vector
+  /// type of the op (true) or from the memref element type (false).
+  explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
+                                     bool useVectorAlign)
+      : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
 
   LogicalResult
@@ -333,9 +380,18 @@ class VectorScatterOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
-      return rewriter.notifyMatchFailure(scatter,
-                                         "could not resolve memref alignment");
+    if (useVectorAlignment) {
+      // NOTE(review): llvm.masked.scatter applies this alignment to every
+      // element address, so vector alignment is only valid if all scattered
+      // addresses are vector-aligned; confirm that is the intended contract.
+      if (failed(getVectorAlignment(*this->getTypeConverter(), vType, align)))
+        return rewriter.notifyMatchFailure(
+            scatter, "could not resolve vector alignment");
+    } else {
+      if (failed(
+              getMemRefAlignment(*this->getTypeConverter(), memRefType, align)))
+        return rewriter.notifyMatchFailure(
+            scatter, "could not resolve memref alignment");
     }
 
     // Resolve address.
@@ -351,6 +407,11 @@ class VectorScatterOpConversion
         rewriter.getI32IntegerAttr(align));
     return success();
   }
+
+private:
+  /// Whether to use the alignment of the vector type instead of the alignment
+  /// of the memref element type.
+  const bool useVectorAlignment;
 };
 
 /// Conversion pattern for a vector.expandload.
@@ -1939,7 +1989,11 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
 /// Populate the given list with patterns that convert from Vector to LLVM.
+/// When `useVectorAlignment` is set, vector load/store, gather and scatter ops
+/// are lowered with the preferred alignment of their vector type instead of
+/// the alignment of the memref element type.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions, bool force32BitVectorIndices) {
+    bool reassociateFPReductions, bool force32BitVectorIndices,
+    bool useVectorAlignment) {
   // This function populates only ConversionPatterns, not RewritePatterns.
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
@@ -1948,18 +2002,20 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorExtractElementOpConversion, VectorExtractOpConversion,
                VectorFMAOp1DConversion, VectorInsertElementOpConversion,
                VectorInsertOpConversion, VectorPrintOpConversion,
                VectorTypeCastOpConversion, VectorScaleOpConversion,
-               VectorLoadStoreConversion<vector::LoadOp>,
-               VectorLoadStoreConversion<vector::MaskedLoadOp>,
-               VectorLoadStoreConversion<vector::StoreOp>,
-               VectorLoadStoreConversion<vector::MaskedStoreOp>,
-               VectorGatherOpConversion, VectorScatterOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
                VectorSplatOpLowering, VectorSplatNdOpLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
                VectorScalableStepOpLowering>(converter);
+  // Memory-access patterns additionally take the `useVectorAlignment` flag.
+  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
+               VectorLoadStoreConversion<vector::MaskedLoadOp>,
+               VectorLoadStoreConversion<vector::StoreOp>,
+               VectorLoadStoreConversion<vector::MaskedStoreOp>,
+               VectorGatherOpConversion, VectorScatterOpConversion>(
+      converter, useVectorAlignment);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..0ee6dce9ee94b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -92,7 +92,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
+  // `useVectorAlignment` mirrors the `use-vector-alignment` pass option.
   populateVectorToLLVMConversionPatterns(
-      converter, patterns, reassociateFPReductions, force32BitVectorIndices);
+      converter, patterns, reassociateFPReductions, force32BitVectorIndices,
+      useVectorAlignment);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
 
   // Architecture specific augmentations.


More information about the Mlir-commits mailing list