[Mlir-commits] [mlir] [mlir][vector][memref] Add `alignment` attribute to memory access ops (PR #144344)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 4 11:33:20 PDT 2025


https://github.com/tyb0807 updated https://github.com/llvm/llvm-project/pull/144344

>From 5adae03488c64aec35024f928ea2b8caf530368e Mon Sep 17 00:00:00 2001
From: tyb0807 <sontuan.vu119 at gmail.com>
Date: Mon, 16 Jun 2025 14:45:40 +0200
Subject: [PATCH] [mlir][memref][vector] Add alignment attribute to
 memref/vector.load/store

Alignment information is important to allow LLVM backends such as AMDGPU
to select wide memory accesses (e.g., dwordx4 or b128). Since this info
is not always inferable, it's better to inform LLVM backends explicitly
about it. Furthermore, alignment is not necessarily a property of the
element type, but of each individual memory access op (we can have
overaligned and underaligned accesses compared to the natural/preferred
alignment of the element type).

This patch introduces an `alignment` attribute to the memref/vector.load/store
ops.

Follow-up PRs will

1. Introduce `alignment` attribute to other vector memory access ops:
    vector.gather + vector.scatter
    vector.transfer_read + vector.transfer_write
    vector.compressstore + vector.expandload
    vector.maskedload + vector.maskedstore

2. Propagate these attributes to LLVM/SPIR-V.

3. Replace `--convert-vector-to-llvm='use-vector-alignment=1'` with a
   simple pass to populate alignment attributes based on the vector
   types.

4. Retire `memref.assume_alignment` op.
---
 mlir/docs/DefiningDialects/Operations.md      |  2 +
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 60 ++++++++++++++++++-
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 55 ++++++++++++++++-
 mlir/include/mlir/IR/CommonAttrConstraints.td |  4 ++
 mlir/test/Dialect/MemRef/invalid.mlir         | 18 ++++++
 mlir/test/Dialect/MemRef/ops.mlir             | 12 ++++
 mlir/test/Dialect/Vector/invalid.mlir         | 18 ++++++
 mlir/test/Dialect/Vector/ops.mlir             | 10 ++++
 8 files changed, 173 insertions(+), 6 deletions(-)

diff --git a/mlir/docs/DefiningDialects/Operations.md b/mlir/docs/DefiningDialects/Operations.md
index b3bde055f04f0..2225329ff830b 100644
--- a/mlir/docs/DefiningDialects/Operations.md
+++ b/mlir/docs/DefiningDialects/Operations.md
@@ -306,6 +306,8 @@ Right now, the following primitive constraints are supported:
 *   `IntPositive`: Specifying an integer attribute whose value is positive
 *   `IntNonNegative`: Specifying an integer attribute whose value is
     non-negative
+*   `IntPowerOf2`: Specifying an integer attribute whose value is a power of
+    two > 0
 *   `ArrayMinCount<N>`: Specifying an array attribute to have at least `N`
     elements
 *   `ArrayMaxCount<N>`: Specifying an array attribute to have at most `N`
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index b0fb5b0785142..2da85ff508c0a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1216,6 +1216,11 @@ def LoadOp : MemRef_Op<"load",
     be reused in the cache. For details, refer to the
     [https://llvm.org/docs/LangRef.html#load-instruction](LLVM load instruction).
 
+    An optional `alignment` attribute specifies the byte alignment of the
+    load operation. It must be a positive power of 2. The operation must access
+    memory at an address aligned to this boundary. Violations may lead to
+    architecture-specific faults or performance penalties.
+    In the C++ builders, an `alignment` of 0 leaves the attribute unset.
     Example:
 
     ```mlir
@@ -1226,7 +1231,39 @@ def LoadOp : MemRef_Op<"load",
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
                            [MemRead]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I64Attr>,
+                                    [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, memref, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>,
+    OpBuilder<(ins "Type":$resultType,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, resultType, memref, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, resultTypes, memref, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>
+  ];
+
   let results = (outs AnyType:$result);
 
   let extraClassDeclaration = [{
@@ -1912,6 +1949,11 @@ def MemRef_StoreOp : MemRef_Op<"store",
     be reused in the cache. For details, refer to the
     [https://llvm.org/docs/LangRef.html#store-instruction](LLVM store instruction).
 
+    An optional `alignment` attribute specifies the byte alignment of the
+    store operation. It must be a positive power of 2. The operation must access
+    memory at an address aligned to this boundary. Violations may lead to
+    architecture-specific faults or performance penalties.
+    In the C++ builders, an `alignment` of 0 leaves the attribute unset.
     Example:
 
     ```mlir
@@ -1923,13 +1965,25 @@ def MemRef_StoreOp : MemRef_Op<"store",
                        Arg<AnyMemRef, "the reference to store to",
                            [MemWrite]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I64Attr>,
+                                    [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
 
   let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, valueToStore, memref, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>,
     OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
       $_state.addOperands(valueToStore);
       $_state.addOperands(memref);
-    }]>];
+    }]>
+  ];
 
   let extraClassDeclaration = [{
       Value getValueToStore() { return getOperand(0); }
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ec2c87ca1cf44..a6fe52af575c9 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1809,12 +1809,42 @@ def Vector_LoadOp : Vector_Op<"load", [
     ```mlir
     %result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
     ```
+
+    An optional `alignment` attribute specifies the byte alignment of the
+    load operation. It must be a positive power of 2. The operation must access
+    memory at an address aligned to this boundary. Violations may lead to
+    architecture-specific faults or performance penalties.
+    In the C++ builders, an `alignment` of 0 leaves the attribute unset.
   }];
 
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I64Attr>,
+                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, resultType, base, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, resultTypes, base, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>
+  ];
+
   let results = (outs AnyVectorOfAnyRank:$result);
 
   let extraClassDeclaration = [{
@@ -1895,6 +1925,12 @@ def Vector_StoreOp : Vector_Op<"store", [
     ```mlir
     vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
     ```
+
+    An optional `alignment` attribute specifies the byte alignment of the
+    store operation. It must be a positive power of 2. The operation must access
+    memory at an address aligned to this boundary. Violations may lead to
+    architecture-specific faults or performance penalties.
+    In the C++ builders, an `alignment` of 0 leaves the attribute unset.
   }];
 
   let arguments = (ins
@@ -1902,8 +1938,21 @@ def Vector_StoreOp : Vector_Op<"store", [
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
-  );
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I64Attr>,
+                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, valueToStore, base, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>
+  ];
 
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index e91a13fea5c7f..18da85a580710 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -796,6 +796,10 @@ def IntPositive : AttrConstraint<
     CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isStrictlyPositive()">,
     "whose value is positive">;
 
+def IntPowerOf2 : AttrConstraint<
+    CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isPowerOf2()">,
+    "whose value is a power of two > 0">;
+
 class ArrayMaxCount<int n> : AttrConstraint<
     CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
     "with at most " # n # " elements">;
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 704cdaf838f45..34e53c3963251 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1139,3 +1139,21 @@ func.func @expand_shape_invalid_output_shape(
       into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>
   return
 }
+
+// -----
+
+func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error @below {{'memref.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_non_power_of_2_store_alignment(%memref: memref<4xi32>, %val: i32) {
+  %c0 = arith.constant 0 : index
+  // expected-error @below {{'memref.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  memref.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>
+  return
+}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index e11de7bec2d0a..abdba0eb1ae35 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -613,3 +613,15 @@ func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affin
   %dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
   return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
 }
+
+// -----
+
+// CHECK-LABEL: func @test_load_store_alignment
+// CHECK: memref.load {{.*}} {alignment = 16 : i64}
+// CHECK: memref.store {{.*}} {alignment = 16 : i64}
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = memref.load %memref[%c0] { alignment = 16 } : memref<4xi32>
+  memref.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5038646e1f026..c98db5edf17d5 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2005,3 +2005,21 @@ func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
   vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8>
   return
 }
+
+// -----
+
+func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_non_power_of_2_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 10bf0f1620568..8c9a2a28fcc34 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1218,3 +1218,13 @@ func.func @step() {
   %1 = vector.step : vector<[4]xindex>
   return
 }
+
+// CHECK-LABEL: func @test_load_store_alignment
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  // CHECK: vector.load {{.*}} {alignment = 16 : i64}
+  %val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
+  // CHECK: vector.store {{.*}} {alignment = 16 : i64}
+  vector.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
+  return
+}

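For readers skimming the patch above, the following standalone sketch restates the rule that the combined `IntPositive` and `IntPowerOf2` constraints enforce on `alignment`, together with the builders' convention that a C++ argument of 0 means "attach no attribute". It does not use MLIR. The helper names `isValidAlignment` and `shouldAttachAlignment` are hypothetical and are not part of the patch; the bit trick is an assumed equivalent of the `APInt` predicates for 64-bit values.

```cpp
#include <cstdint>

// Hypothetical helper, not part of the patch: approximates the combined
// IntPositive + IntPowerOf2 check for a signed 64-bit attribute value.
// A positive value is a power of two exactly when it has one bit set.
bool isValidAlignment(int64_t value) {
  return value > 0 && (value & (value - 1)) == 0;
}

// Hypothetical helper, not part of the patch: mirrors the builders'
// `alignment != 0 ? getI64IntegerAttr(alignment) : nullptr` convention.
// Returns true when an attribute would be attached to the op.
bool shouldAttachAlignment(uint64_t alignment) {
  return alignment != 0;
}
```

Under this sketch, 16 is accepted, while 3, -1, and 0 are rejected as attribute values, which matches the cases exercised by the `invalid.mlir` tests in the patch. A builder argument of 0 never reaches the verifier because no attribute is created for it.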

