[Mlir-commits] [mlir] [mlir][vector] Propagate alignment from vector to llvm dialects. (PR #153482)

llvmlistbot at llvm.org
Mon Aug 25 06:57:16 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes:

Propagates the `alignment` attribute from the vector dialect ops `vector.gather`, `vector.scatter`, `vector.expandload`, and `vector.compressstore` to the corresponding LLVM dialect intrinsics during the VectorToLLVM conversion. When the attribute is absent, the existing behavior is kept: gather/scatter fall back to `getVectorToLLVMAlignment`, and expandload/compressstore use the LangRef default alignment of 1.
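For illustration, a minimal sketch of the intended lowering for the gather case, mirroring the test added in this PR (the exact pass invocation is omitted; the point is that an explicit attribute takes precedence over the alignment that `getVectorToLLVMAlignment` would otherwise compute):

```mlir
// Input (vector dialect): an explicit 8-byte alignment on the gather.
func.func @gather_with_alignment(%base: memref<?xf32>, %indices: vector<3xi32>,
                                 %mask: vector<3xi1>, %pass_thru: vector<3xf32>,
                                 %i: index) -> vector<3xf32> {
  %0 = vector.gather %base[%i][%indices], %mask, %pass_thru {alignment = 8}
      : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
  return %0 : vector<3xf32>
}

// Expected output (LLVM dialect): the same alignment appears on the intrinsic.
//   llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32}
//       : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
```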

---
Full diff: https://github.com/llvm/llvm-project/pull/153482.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+19-6) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+39) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index afc3d1b12ac0d..7d29750ddcf39 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -299,8 +299,9 @@ class VectorGatherOpConversion
     }
 
     // Resolve alignment.
-    unsigned align;
-    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+    unsigned align = gather.getAlignment().value_or(0);
+    if (!align &&
+        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                         memRefType, align, useVectorAlignment)))
       return rewriter.notifyMatchFailure(gather, "could not resolve alignment");
 
@@ -354,8 +355,9 @@ class VectorScatterOpConversion
     }
 
     // Resolve alignment.
-    unsigned align;
-    if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+    unsigned align = scatter.getAlignment().value_or(0);
+    if (!align &&
+        failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
                                         memRefType, align, useVectorAlignment)))
       return rewriter.notifyMatchFailure(scatter,
                                          "could not resolve alignment");
@@ -399,8 +401,14 @@ class VectorExpandLoadOpConversion
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                      adaptor.getBase(), adaptor.getIndices());
 
+    // From:
+    // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+    //   The pointer alignment defaults to 1.
+    uint64_t alignment = expand.getAlignment().value_or(1);
+
     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
-        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
+        expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
+        alignment);
     return success();
   }
 };
@@ -421,8 +429,13 @@ class VectorCompressStoreOpConversion
     Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
                                      adaptor.getBase(), adaptor.getIndices());
 
+    // From:
+    // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+    //   The pointer alignment defaults to 1.
+    uint64_t alignment = compress.getAlignment().value_or(1);
+
     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
-        compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
+        compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 9b57b1b6fb4c7..5973c2ba2cbd0 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2042,6 +2042,16 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x
 
 // -----
 
+func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) -> vector<3xf32> {
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @gather_with_alignment
+// CHECK: llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.scatter
 //===----------------------------------------------------------------------===//
@@ -2118,6 +2128,17 @@ func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]
 // CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
 // CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into vector<[4]x!llvm.ptr>
 
+// -----
+
+func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) {
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3 { alignment = 8 } : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func @scatter_with_alignment
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+
+
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -2149,6 +2170,15 @@ func.func @expand_load_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %a
 
 // -----
 
+func.func @expand_load_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) -> vector<11xindex> {
+  %0 = vector.expandload %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex> into vector<11xindex>
+  return %0 : vector<11xindex>
+}
+// CHECK-LABEL: func @expand_load_op_with_alignment
+// CHECK: %{{.*}} = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<11xi1>, vector<11xi64>) -> vector<11xi64>
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.compressstore
 //===----------------------------------------------------------------------===//
@@ -2177,6 +2207,15 @@ func.func @compress_store_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>,
 
 // -----
 
+func.func @compress_store_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) {
+  vector.compressstore %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex>
+  return
+}
+// CHECK-LABEL: func @compress_store_op_with_alignment
+// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<11xi64>, !llvm.ptr, vector<11xi1>) -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.splat
 //===----------------------------------------------------------------------===//

``````````
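The expandload/compressstore lowerings differ slightly: the LLVM intrinsics carry the alignment as a parameter attribute (`llvm.align`) on the pointer operand, and, per the LangRef, the pointer alignment defaults to 1 when the vector op has no `alignment` attribute. A minimal sketch of the compressstore case, mirroring the new test above:

```mlir
// Input (vector dialect): an explicit 8-byte alignment on the compressing store.
func.func @compress_store_op_with_alignment(%base: memref<?xindex>, %mask: vector<11xi1>,
                                            %vals: vector<11xindex>, %i: index) {
  vector.compressstore %base[%i], %mask, %vals { alignment = 8 }
      : memref<?xindex>, vector<11xi1>, vector<11xindex>
  return
}

// Expected output (LLVM dialect): the alignment surfaces as a parameter attribute
// on the pointer argument of the intrinsic.
//   "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}})
//       <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}>
//       : (vector<11xi64>, !llvm.ptr, vector<11xi1>) -> ()
```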



https://github.com/llvm/llvm-project/pull/153482

