[Mlir-commits] [mlir] [mlir][vector] Add alignment attribute to vector operations. (PR #152507)

Erick Ochoa Lopez llvmlistbot at llvm.org
Thu Aug 7 08:19:22 PDT 2025


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/152507

>From 394c64fa315069c9db524ddc358b8dbe5aae72cf Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 6 Aug 2025 16:34:02 -0700
Subject: [PATCH 1/5] Use llvm::Align

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.h  |  1 +
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 12 ++++++------
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 364c1728715e8..63410b8bea747 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -32,6 +32,7 @@
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Alignment.h"
 
 // Pull in all enum type definitions and utility function declarations.
 #include "mlir/Dialect/Vector/IR/VectorEnums.h.inc"
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index dc55704c36183..eeedf68a1df7c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1729,18 +1729,18 @@ def Vector_LoadOp : Vector_Op<"load", [
                    "Value":$base,
                    "ValueRange":$indices,
                    CArg<"bool", "false">:$nontemporal,
-                   CArg<"uint64_t", "0">:$alignment), [{
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
       return build($_builder, $_state, resultType, base, indices, nontemporal,
-                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                   alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
                                     nullptr);
     }]>,
     OpBuilder<(ins "TypeRange":$resultTypes,
                    "Value":$base,
                    "ValueRange":$indices,
                    CArg<"bool", "false">:$nontemporal,
-                   CArg<"uint64_t", "0">:$alignment), [{
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
       return build($_builder, $_state, resultTypes, base, indices, nontemporal,
-                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                   alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
                                     nullptr);
     }]>
   ];
@@ -1847,9 +1847,9 @@ def Vector_StoreOp : Vector_Op<"store", [
                    "Value":$base,
                    "ValueRange":$indices,
                    CArg<"bool", "false">:$nontemporal,
-                   CArg<"uint64_t", "0">:$alignment), [{
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
       return build($_builder, $_state, valueToStore, base, indices, nontemporal,
-                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                   alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
                                     nullptr);
     }]>
   ];

>From da405541dd57a50b38a9cf5a862b1cddc1bc7d19 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 6 Aug 2025 19:48:16 -0700
Subject: [PATCH 2/5] [mlir][vector] Add alignment to vector.gather.

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 29 ++++++++++++++++++-
 mlir/test/Dialect/Vector/invalid.mlir         | 18 ++++++++++++
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index eeedf68a1df7c..f2b55be19e192 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2003,7 +2003,9 @@ def Vector_GatherOp :
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
                VectorOfNonZeroRankOf<[I1]>:$mask,
-               AnyVectorOfNonZeroRank:$pass_thru)>,
+               AnyVectorOfNonZeroRank:$pass_thru,
+               ConfinedAttr<OptionalAttr<I64Attr>,
+                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
     Results<(outs AnyVectorOfNonZeroRank:$result)> {
 
   let summary = [{
@@ -2060,6 +2062,31 @@ def Vector_GatherOp :
     "`into` type($result)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
+
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$index_vec,
+                   "Value":$mask,
+                   "Value":$passthrough,
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+      return build($_builder, $_state, resultType, base, indices, index_vec, mask, passthrough,
+                   alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+                                    nullptr);
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$index_vec,
+                   "Value":$mask,
+                   "Value":$passthrough,
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+      return build($_builder, $_state, resultTypes, base, indices, index_vec, mask, passthrough,
+                   alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+                                    nullptr);
+    }]>
+  ];
 }
 
 def Vector_ScatterOp :
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c21de562d05e1..6a4c7a5623a43 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1430,6 +1430,24 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
 
 // -----
 
+func.func @gather_invalid_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
+                                %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
+  // expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    { alignment = -1 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func.func @gather_invalid_alignment(%base: memref<16xf32>, %indices: vector<16xi32>,
+                                %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
+  // expected-error@+2 {{'vector.gather' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    { alignment = 3 } : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
                              %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index

>From 12eead5b3dd6496afe11653be4d03e75c39f7356 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 6 Aug 2025 19:58:43 -0700
Subject: [PATCH 3/5] [mlir][vector] Add alignment to vector.scatter

---
 .../mlir/Dialect/Vector/IR/VectorOps.td        | 17 ++++++++++++++++-
 mlir/test/Dialect/Vector/invalid.mlir          | 18 ++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index f2b55be19e192..1ad527d5619d3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2095,7 +2095,9 @@ def Vector_ScatterOp :
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
                VectorOfNonZeroRankOf<[I1]>:$mask,
-               AnyVectorOfNonZeroRank:$valueToStore)> {
+               AnyVectorOfNonZeroRank:$valueToStore,
+               ConfinedAttr<OptionalAttr<I64Attr>,
+                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
 
   let summary = [{
     scatters elements from a vector into memory as defined by an index vector
@@ -2153,6 +2155,19 @@ def Vector_ScatterOp :
       "type($index_vec)  `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
+
+  let builders = [
+    OpBuilder<(ins "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$index_vec,
+                   "Value":$mask,
+                   "Value":$valueToStore,
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+      return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
+                   alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+                                    nullptr);
+    }]>
+  ];
 }
 
 def Vector_ExpandLoadOp :
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 6a4c7a5623a43..c5f5c2694b6ff 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1509,6 +1509,24 @@ func.func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi
 
 // -----
 
+func.func @scatter_invalid_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
+                                %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+  // expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  vector.scatter %base[%c0][%indices], %mask, %value { alignment = -1 }
+    : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func.func @scatter_invalid_alignment(%base: memref<?xf32>, %indices: vector<16xi32>,
+                                %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+  // expected-error@+1 {{'vector.scatter' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  vector.scatter %base[%c0][%indices], %mask, %value { alignment = 3 }
+    : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
 func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   // expected-error at +1 {{'vector.expandload' op base and result element type should match}}

>From 2b512366d25a7e417dd4b5f5e1ea46b320d481ae Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 6 Aug 2025 20:48:21 -0700
Subject: [PATCH 4/5] [mlir][vector] Add alignment to compressstore

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 15 ++++++++++++++-
 mlir/test/Dialect/Vector/invalid.mlir            | 14 ++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 1ad527d5619d3..d3c3260ae4484 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,7 +2244,9 @@ def Vector_CompressStoreOp :
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
-               AnyVectorOfNonZeroRank:$valueToStore)> {
+               AnyVectorOfNonZeroRank:$valueToStore,
+               ConfinedAttr<OptionalAttr<I64Attr>,
+                            [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
 
   let summary = "writes elements selectively from a vector as defined by a mask";
 
@@ -2303,6 +2305,17 @@ def Vector_CompressStoreOp :
       "type($base) `,` type($mask) `,` type($valueToStore)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$mask,
+                   "Value":$valueToStore,
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+      return build($_builder, $_state, base, indices, mask, valueToStore, // ODS operand order
+                   alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+                                    nullptr);
+    }]>
+  ];
 }
 
 def Vector_ShapeCastOp :
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c5f5c2694b6ff..23c908a95616f 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1599,6 +1599,20 @@ func.func @compress_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>
 
 // -----
 
+func.func @compress_invalid_alignment(%dst: memref<?xf32>, %msk: vector<16xi1>, %vals: vector<16xf32>, %idx: index) {
+  // expected-error@+1 {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  vector.compressstore %dst[%idx], %msk, %vals { alignment = -1 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
+func.func @compress_invalid_alignment(%dst: memref<?xf32>, %msk: vector<16xi1>, %vals: vector<16xf32>, %idx: index) {
+  // expected-error@+1 {{'vector.compressstore' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  vector.compressstore %dst[%idx], %msk, %vals { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
 func.func @scan_reduction_dim_constraint(%arg0: vector<2x3xi32>, %arg1: vector<3xi32>) -> vector<3xi32> {
   // expected-error at +1 {{'vector.scan' op reduction dimension 5 has to be less than 2}}
   %0:2 = vector.scan <add>, %arg0, %arg1 {inclusive = true, reduction_dim = 5} :

>From 5a47ffff57cf05ed0b86ca232f3aa23b9c168111 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 6 Aug 2025 20:57:52 -0700
Subject: [PATCH 5/5] [mlir][vector] Add alignment to expandload

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 27 ++++++++++++++++++-
 mlir/test/Dialect/Vector/invalid.mlir         | 14 ++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d3c3260ae4484..62fac6ef2ea36 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2175,7 +2175,9 @@ def Vector_ExpandLoadOp :
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
-               AnyVectorOfNonZeroRank:$pass_thru)>,
+               AnyVectorOfNonZeroRank:$pass_thru,
+               ConfinedAttr<OptionalAttr<I64Attr>,
+                            [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
     Results<(outs AnyVectorOfNonZeroRank:$result)> {
 
   let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
@@ -2237,6 +2239,29 @@ def Vector_ExpandLoadOp :
     "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
+
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$mask,
+                   "Value":$passthrough,
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+      return build($_builder, $_state, resultType, base, indices, mask, passthrough,
+                   alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+                                    nullptr);
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$mask,
+                   "Value":$passthrough,
+                   CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment), [{
+      return build($_builder, $_state, resultTypes, base, indices, mask, passthrough,
+                   alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
+                                    nullptr);
+    }]>
+  ];
 }
 
 def Vector_CompressStoreOp :
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 23c908a95616f..b95df0fe30eb2 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1567,6 +1567,20 @@ func.func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>,
 
 // -----
 
+func.func @expand_invalid_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
+  // expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = -1 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+func.func @expand_invalid_alignment(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0: index) {
+  // expected-error@+1 {{'vector.expandload' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru { alignment = 3 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   // expected-error at +1 {{'vector.compressstore' op base and valueToStore element type should match}}



More information about the Mlir-commits mailing list