[Mlir-commits] [mlir] 3715de9 - [mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass (#137389)

llvmlistbot at llvm.org
Fri May 2 12:54:29 PDT 2025


Author: Lily Orth-Smith
Date: 2025-05-02T12:54:26-07:00
New Revision: 3715de976e41654db83a0aef5be50790024639ea

URL: https://github.com/llvm/llvm-project/commit/3715de976e41654db83a0aef5be50790024639ea
DIFF: https://github.com/llvm/llvm-project/commit/3715de976e41654db83a0aef5be50790024639ea.diff

LOG: [mlir] Add use-vector-alignment flag to ConvertVectorToLLVMPass (#137389)

In ConvertVectorToLLVM, the only way to set the alignment of
`vector.gather`, `vector.scatter`, and the `vector.load/store` ops was
to derive it from the element type of the memref. However, this is
insufficient for hardware backends that require alignment of vector
types. This PR introduces the `use-vector-alignment` option to the
`ConvertVectorToLLVMPass`, which makes the pass use the preferred
alignment of the operation's vector type instead of the alignment of
the memref's element type.
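
For example, with the option enabled, a `vector.load` of `vector<8xf32>`
from a `memref<200x100xf32>` lowers to an `llvm.load` carrying the 32-byte
preferred alignment of the vector type rather than the 4-byte alignment of
`f32` (see the new test below). On the command line, the option is passed
as:

    mlir-opt --convert-vector-to-llvm='use-vector-alignment=1' input.mlir

(`input.mlir` is a placeholder for the file being lowered.)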

---------

Co-authored-by: Lily Orth-Smith <lorthsmith at microsoft.com>

Added: 
    mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index ed88ada9778d2..f7b0112b32a33 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1411,6 +1411,13 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"true",
            "Allows compiler to assume vector indices fit in 32-bit if that "
      "yields faster code">,
+    Option<"useVectorAlignment", "use-vector-alignment",
+           "bool", /*default=*/"false",
+           "Use the preferred alignment of a vector type in load/store "
+           "operations instead of the alignment of the element type of the "
+           "memref. This flag is intended for use with hardware which requires" 
+           "vector alignment, or in application contexts where it is known all "
+           "vector access are naturally aligned. ">,
     Option<"amx", "enable-amx",
            "bool", /*default=*/"false",
            "Enables the use of AMX dialect while lowering the vector "

diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index 1e29bfeb9c392..f6b09deb4e44c 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -22,7 +22,8 @@ void populateVectorToLLVMMatrixConversionPatterns(
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions = false, bool force32BitVectorIndices = false);
+    bool reassociateFPReductions = false, bool force32BitVectorIndices = false,
+    bool useVectorAlignment = false);
 
 namespace vector {
 void registerConvertVectorToLLVMInterface(DialectRegistry &registry);

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 14cbbac99d9ae..299f198e4ab9c 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -32,7 +32,8 @@ def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
 
   let arguments = (ins
       DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
-      DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
+      DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices,
+      DefaultValuedAttr<BoolAttr, "false">:$use_vector_alignment);
   let assemblyFormat = "attr-dict";
 }
 

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 076e5512f375b..5296013189b9e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -67,6 +67,21 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
   return rewriter.create<LLVM::ExtractValueOp>(loc, val, pos);
 }
 
+// Helper that returns data layout alignment of a vector.
+LogicalResult getVectorAlignment(const LLVMTypeConverter &typeConverter,
+                                 VectorType vectorType, unsigned &align) {
+  Type convertedVectorTy = typeConverter.convertType(vectorType);
+  if (!convertedVectorTy)
+    return failure();
+
+  llvm::LLVMContext llvmContext;
+  align = LLVM::TypeToLLVMIRTranslator(llvmContext)
+              .getPreferredAlignment(convertedVectorTy,
+                                     typeConverter.getDataLayout());
+
+  return success();
+}
+
 // Helper that returns data layout alignment of a memref.
 LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
                                  MemRefType memrefType, unsigned &align) {
@@ -82,6 +97,28 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
   return success();
 }
 
+// Helper to resolve the alignment for vector load/store, gather and scatter
+// ops. If useVectorAlignment is true, use the preferred alignment of the
+// vector type in the operation; this option is intended for hardware backends
+// that require alignment of vector types. Otherwise, use the preferred
+// alignment of the memref's element type. Note that if useVectorAlignment is
+// set, the shape of the vector type must be resolved before this pass runs.
+LogicalResult getVectorToLLVMAlignment(const LLVMTypeConverter &typeConverter,
+                                       VectorType vectorType,
+                                       MemRefType memrefType, unsigned &align,
+                                       bool useVectorAlignment) {
+  if (useVectorAlignment) {
+    if (failed(getVectorAlignment(typeConverter, vectorType, align))) {
+      return failure();
+    }
+  } else {
+    if (failed(getMemRefAlignment(typeConverter, memrefType, align))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
 // Check if the last stride is non-unit and has a valid memory space.
 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                            const LLVMTypeConverter &converter) {
@@ -224,6 +261,10 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp,
 template <class LoadOrStoreOp>
 class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 public:
+  explicit VectorLoadStoreConversion(const LLVMTypeConverter &typeConv,
+                                     bool useVectorAlign)
+      : ConvertOpToLLVMPattern<LoadOrStoreOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
   using ConvertOpToLLVMPattern<LoadOrStoreOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
@@ -240,8 +281,10 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*this->getTypeConverter(), memRefTy, align)))
-      return failure();
+    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
+                                        memRefTy, align, useVectorAlignment)))
+      return rewriter.notifyMatchFailure(loadOrStoreOp,
+                                         "could not resolve alignment");
 
     // Resolve address.
     auto vtype = cast<VectorType>(
@@ -252,12 +295,23 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
                          rewriter);
     return success();
   }
+
+private:
+  // If true, use the preferred alignment of the vector type.
+  // If false, use the preferred alignment of the element type
+  // of the memref. This flag is intended for use with hardware
+  // backends that require alignment of vector operations.
+  const bool useVectorAlignment;
 };
 
 /// Conversion pattern for a vector.gather.
 class VectorGatherOpConversion
     : public ConvertOpToLLVMPattern<vector::GatherOp> {
 public:
+  explicit VectorGatherOpConversion(const LLVMTypeConverter &typeConv,
+                                    bool useVectorAlign)
+      : ConvertOpToLLVMPattern<vector::GatherOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
   using ConvertOpToLLVMPattern<vector::GatherOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
@@ -278,10 +332,9 @@ class VectorGatherOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
-      return rewriter.notifyMatchFailure(gather,
-                                         "could not resolve memref alignment");
-    }
+    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+                                        memRefType, align, useVectorAlignment)))
+      return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
 
     // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -297,12 +350,24 @@ class VectorGatherOpConversion
         adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
     return success();
   }
+
+private:
+  // If true, use the preferred alignment of the vector type.
+  // If false, use the preferred alignment of the element type
+  // of the memref. This flag is intended for use with hardware
+  // backends that require alignment of vector operations.
+  const bool useVectorAlignment;
 };
 
 /// Conversion pattern for a vector.scatter.
 class VectorScatterOpConversion
     : public ConvertOpToLLVMPattern<vector::ScatterOp> {
 public:
+  explicit VectorScatterOpConversion(const LLVMTypeConverter &typeConv,
+                                     bool useVectorAlign)
+      : ConvertOpToLLVMPattern<vector::ScatterOp>(typeConv),
+        useVectorAlignment(useVectorAlign) {}
+
   using ConvertOpToLLVMPattern<vector::ScatterOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
@@ -322,10 +387,10 @@ class VectorScatterOpConversion
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align))) {
+    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+                                        memRefType, align, useVectorAlignment)))
       return rewriter.notifyMatchFailure(scatter,
-                                         "could not resolve memref alignment");
-    }
+                                         "could not resolve alignment");
 
     // Resolve address.
     Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
@@ -340,6 +405,13 @@ class VectorScatterOpConversion
         rewriter.getI32IntegerAttr(align));
     return success();
   }
+
+private:
+  // If true, use the preferred alignment of the vector type.
+  // If false, use the preferred alignment of the element type
+  // of the memref. This flag is intended for use with hardware
+  // backends that require alignment of vector operations.
+  const bool useVectorAlignment;
 };
 
 /// Conversion pattern for a vector.expandload.
@@ -1928,21 +2000,23 @@ void mlir::vector::populateVectorRankReducingFMAPattern(
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool reassociateFPReductions, bool force32BitVectorIndices) {
+    bool reassociateFPReductions, bool force32BitVectorIndices,
+    bool useVectorAlignment) {
   // This function populates only ConversionPatterns, not RewritePatterns.
   MLIRContext *ctx = converter.getDialect()->getContext();
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
   patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
+  patterns.add<VectorLoadStoreConversion<vector::LoadOp>,
+               VectorLoadStoreConversion<vector::MaskedLoadOp>,
+               VectorLoadStoreConversion<vector::StoreOp>,
+               VectorLoadStoreConversion<vector::MaskedStoreOp>,
+               VectorGatherOpConversion, VectorScatterOpConversion>(
+      converter, useVectorAlignment);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
                VectorExtractElementOpConversion, VectorExtractOpConversion,
                VectorFMAOp1DConversion, VectorInsertElementOpConversion,
                VectorInsertOpConversion, VectorPrintOpConversion,
                VectorTypeCastOpConversion, VectorScaleOpConversion,
-               VectorLoadStoreConversion<vector::LoadOp>,
-               VectorLoadStoreConversion<vector::MaskedLoadOp>,
-               VectorLoadStoreConversion<vector::StoreOp>,
-               VectorLoadStoreConversion<vector::MaskedStoreOp>,
-               VectorGatherOpConversion, VectorScatterOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
                VectorSplatOpLowering, VectorSplatNdOpLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..0ee6dce9ee94b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -92,7 +92,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
-      converter, patterns, reassociateFPReductions, force32BitVectorIndices);
+      converter, patterns, reassociateFPReductions, force32BitVectorIndices,
+      useVectorAlignment);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
 
   // Architecture specific augmentations.

diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 1be436dd7bf41..125c3d918284c 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -34,7 +34,8 @@ void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns) {
   populateVectorToLLVMConversionPatterns(
       static_cast<LLVMTypeConverter &>(typeConverter), patterns,
-      getReassociateFpReductions(), getForce_32bitVectorIndices());
+      getReassociateFpReductions(), getForce_32bitVectorIndices(),
+      getUseVectorAlignment());
 }
 
 LogicalResult

diff --git a/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
new file mode 100644
index 0000000000000..3fa248656cf3a
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=0' --split-input-file | FileCheck %s --check-prefixes=ALL,MEMREF-ALIGN
+// RUN: mlir-opt %s --convert-vector-to-llvm='use-vector-alignment=1' --split-input-file | FileCheck %s --check-prefixes=ALL,VEC-ALIGN
+
+
+//===----------------------------------------------------------------------===//
+// vector.load
+//===----------------------------------------------------------------------===//
+
+func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
+  %0 = vector.load %base[%i, %j] : memref<200x100xf32>, vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// ALL-LABEL: func @load
+
+// VEC-ALIGN: llvm.load %{{.*}} {alignment = 32 : i64} : !llvm.ptr -> vector<8xf32>
+// MEMREF-ALIGN: llvm.load %{{.*}} {alignment = 4 : i64} : !llvm.ptr -> vector<8xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.store
+//===----------------------------------------------------------------------===//
+
+func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {
+  %val = arith.constant dense<11.0> : vector<4xf32>
+  vector.store %val, %base[%i, %j] : memref<200x100xf32>, vector<4xf32>
+  return
+}
+
+// ALL-LABEL: func @store
+
+// VEC-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 16 : i64} :  vector<4xf32>, !llvm.ptr
+// MEMREF-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 4 : i64} :  vector<4xf32>, !llvm.ptr
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.maskedload
+//===----------------------------------------------------------------------===//
+
+func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0: index
+  %0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// ALL-LABEL: func @masked_load
+
+// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.maskedstore
+//===----------------------------------------------------------------------===//
+
+func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
+  %c0 = arith.constant 0: index
+  vector.maskedstore %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+// ALL-LABEL: func @masked_store
+
+// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 64 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) {
+  %0 = arith.constant 0: index
+  vector.scatter %base[%0][%index], %mask, %value : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  return
+}
+
+// ALL-LABEL: func @scatter
+
+// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.gather
+//===----------------------------------------------------------------------===//
+
+func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %base[%0][%index], %mask, %passthru : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// ALL-LABEL: func @gather
+
+// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
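
A note on the alignments checked above (an informal derivation, assuming the
default LLVM data layout used by the type converter): the VEC-ALIGN values
are the natural alignment of each vector type, i.e. its bit width rounded up
to a power of two:

    vector<4xf32>  -> 128 bits -> 16 bytes
    vector<8xf32>  -> 256 bits -> 32 bytes
    vector<16xf32> -> 512 bits -> 64 bytes
    vector<3xf32>  ->  96 bits -> rounded up to 128 bits -> 16 bytes

The MEMREF-ALIGN values all fall back to the 4-byte alignment of the `f32`
element type of the memref.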
