[Mlir-commits] [mlir] [mlir][vector][memref] Add `alignment` attribute to memory access ops (PR #144344)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 23 04:59:52 PDT 2025


https://github.com/tyb0807 updated https://github.com/llvm/llvm-project/pull/144344

>From 0019a9e47c1b5015701fda6bed0c0129f3e5decb Mon Sep 17 00:00:00 2001
From: tyb0807 <sontuan.vu119 at gmail.com>
Date: Mon, 16 Jun 2025 14:45:40 +0200
Subject: [PATCH 1/3] [mlir][memref][vector] Add alignment attribute to memory
 access ops

---
 mlir/docs/DefiningDialects/Operations.md      |  2 +
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 60 ++++++++++++++++++-
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 55 ++++++++++++++++-
 mlir/include/mlir/IR/CommonAttrConstraints.td |  4 ++
 .../Dialect/MemRef/load-store-alignment.mlir  | 27 +++++++++
 .../Dialect/Vector/load-store-alignment.mlir  | 27 +++++++++
 6 files changed, 169 insertions(+), 6 deletions(-)
 create mode 100644 mlir/test/Dialect/MemRef/load-store-alignment.mlir
 create mode 100644 mlir/test/Dialect/Vector/load-store-alignment.mlir

diff --git a/mlir/docs/DefiningDialects/Operations.md b/mlir/docs/DefiningDialects/Operations.md
index b3bde055f04f0..2225329ff830b 100644
--- a/mlir/docs/DefiningDialects/Operations.md
+++ b/mlir/docs/DefiningDialects/Operations.md
@@ -306,6 +306,8 @@ Right now, the following primitive constraints are supported:
 *   `IntPositive`: Specifying an integer attribute whose value is positive
 *   `IntNonNegative`: Specifying an integer attribute whose value is
     non-negative
+*   `IntPowerOf2`: Specifying an integer attribute whose value is a power of
+    two > 0
 *   `ArrayMinCount<N>`: Specifying an array attribute to have at least `N`
     elements
 *   `ArrayMaxCount<N>`: Specifying an array attribute to have at most `N`
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 77e3074661abf..dad4b7bff40cd 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1217,6 +1217,11 @@ def LoadOp : MemRef_Op<"load",
     be reused in the cache. For details, refer to the
     [https://llvm.org/docs/LangRef.html#load-instruction](LLVM load instruction).
 
+    An optional `alignment` attribute allows to specify the byte alignment of the
+    load operation. It must be a positive power of 2. The operation must access
+    memory at an address aligned to this boundary. Violations may lead to
+    architecture-specific faults or performance penalties.
+    When omitted, no specific alignment is required; 0 is not a valid value.
     Example:
 
     ```mlir
@@ -1227,7 +1232,39 @@ def LoadOp : MemRef_Op<"load",
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
                            [MemRead]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I64Attr>,
+                                    [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, memref, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>,
+    OpBuilder<(ins "Type":$resultType,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, resultType, memref, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, resultTypes, memref, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>
+  ];
+
   let results = (outs AnyType:$result);
 
   let extraClassDeclaration = [{
@@ -1913,6 +1950,11 @@ def MemRef_StoreOp : MemRef_Op<"store",
     be reused in the cache. For details, refer to the
     [https://llvm.org/docs/LangRef.html#store-instruction](LLVM store instruction).
 
+    An optional `alignment` attribute allows to specify the byte alignment of the
+    store operation. It must be a positive power of 2. The operation must access
+    memory at an address aligned to this boundary. Violations may lead to
+    architecture-specific faults or performance penalties.
+    When omitted, no specific alignment is required; 0 is not a valid value.
     Example:
 
     ```mlir
@@ -1924,13 +1966,25 @@ def MemRef_StoreOp : MemRef_Op<"store",
                        Arg<AnyMemRef, "the reference to store to",
                            [MemWrite]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I64Attr>,
+                                    [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
 
   let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, valueToStore, memref, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>,
     OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
       $_state.addOperands(valueToStore);
       $_state.addOperands(memref);
-    }]>];
+    }]>
+  ];
 
   let extraClassDeclaration = [{
       Value getValueToStore() { return getOperand(0); }
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..60a173333bc7c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1734,12 +1734,42 @@ def Vector_LoadOp : Vector_Op<"load"> {
     ```mlir
     %result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
     ```
+
+    An optional `alignment` attribute allows to specify the byte alignment of the
+    load operation. It must be a positive power of 2. The operation must access
+    memory at an address aligned to this boundary. Violations may lead to
+    architecture-specific faults or performance penalties.
+    When omitted, no specific alignment is required; 0 is not a valid value.
   }];
 
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I64Attr>,
+                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, resultType, base, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, resultTypes, base, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>
+  ];
+
   let results = (outs AnyVectorOfAnyRank:$result);
 
   let extraClassDeclaration = [{
@@ -1818,6 +1848,12 @@ def Vector_StoreOp : Vector_Op<"store"> {
     ```mlir
     vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
     ```
+
+    An optional `alignment` attribute allows to specify the byte alignment of the
+    store operation. It must be a positive power of 2. The operation must access
+    memory at an address aligned to this boundary. Violations may lead to
+    architecture-specific faults or performance penalties.
+    When omitted, no specific alignment is required; 0 is not a valid value.
   }];
 
   let arguments = (ins
@@ -1825,8 +1861,21 @@ def Vector_StoreOp : Vector_Op<"store"> {
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
-  );
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I64Attr>,
+                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"bool", "false">:$nontemporal,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, valueToStore, base, indices, nontemporal,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>
+  ];
 
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index e91a13fea5c7f..18da85a580710 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -796,6 +796,10 @@ def IntPositive : AttrConstraint<
     CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isStrictlyPositive()">,
     "whose value is positive">;
 
+def IntPowerOf2 : AttrConstraint<
+    CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isPowerOf2()">,
+    "whose value is a power of two > 0">;
+
 class ArrayMaxCount<int n> : AttrConstraint<
     CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
     "with at most " # n # " elements">;
diff --git a/mlir/test/Dialect/MemRef/load-store-alignment.mlir b/mlir/test/Dialect/MemRef/load-store-alignment.mlir
new file mode 100644
index 0000000000000..afd314a9b498a
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/load-store-alignment.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func @test_load_store_alignment
+// CHECK: memref.load {{.*}} {alignment = 16 : i64}
+// CHECK: memref.store {{.*}} {alignment = 16 : i64}
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = memref.load %memref[%c0] { alignment = 16 } : memref<4xi32>
+  memref.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'memref.load' 'memref.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_non_power_of_2_store_alignment(%memref: memref<4xi32>, %val: i32) {
+  // expected-error @+1 {{custom op 'memref.store' 'memref.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  memref.store %val, %memref[%c0] { alignment = 1 } : memref<4xi32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/load-store-alignment.mlir b/mlir/test/Dialect/Vector/load-store-alignment.mlir
new file mode 100644
index 0000000000000..38b946f89e7c3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/load-store-alignment.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func @test_load_store_alignment
+// CHECK: vector.load {{.*}} {alignment = 16 : i64}
+// CHECK: vector.store {{.*}} {alignment = 16 : i64}
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
+  vector.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'vector.load' 'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_non_power_of_2_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
+  // expected-error @+1 {{custom op 'vector.store' 'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  vector.store %val, %memref[%c0] { alignment = 1 } : memref<4xi32>, vector<4xi32>
+  return
+}

>From 8a614dfc533274a957d595fc021fba3da4584a9f Mon Sep 17 00:00:00 2001
From: tyb0807 <sontuan.vu119 at gmail.com>
Date: Fri, 20 Jun 2025 15:26:04 +0200
Subject: [PATCH 2/3] Lowering alignment attribute from memref.load/store to
 LLVM

---
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  |  7 ++++---
 .../convert-static-memref-ops.mlir            | 19 +++++++++++++++++++
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 8ccf1bfc292d5..267a711a71914 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -841,8 +841,8 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
                                          adaptor.getMemref(),
                                          adaptor.getIndices(), kNoWrapFlags);
     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
-        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
-        false, loadOp.getNontemporal());
+        loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
+        loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
     return success();
   }
 };
@@ -864,7 +864,8 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
         getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
                              adaptor.getIndices(), kNoWrapFlags);
     rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
-                                               0, false, op.getNontemporal());
+                                               op.getAlignment().value_or(0),
+                                               false, op.getNontemporal());
     return success();
   }
 };
diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
index 040a27e160557..874acbc9b6c3c 100644
--- a/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir
@@ -148,6 +148,15 @@ func.func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
 
 // -----
 
+// CHECK-LABEL: func @aligned_load(
+func.func @aligned_load(%static : memref<10x42xf32>, %i : index, %j : index) {
+// CHECK:  llvm.load %{{.*}} {alignment = 16 : i64} : !llvm.ptr -> f32
+  %0 = memref.load %static[%i, %j] { alignment = 16 } : memref<10x42xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @zero_d_store
 func.func @zero_d_store(%arg0: memref<f32>, %arg1: f32) {
 // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64)>
@@ -177,6 +186,16 @@ func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %va
 
 // -----
 
+// CHECK-LABEL: func @aligned_store
+func.func @aligned_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) {
+// CHECK: llvm.store %{{.*}}, %{{.*}} {alignment = 16 : i64} : f32, !llvm.ptr
+
+  memref.store %val, %static[%i, %j] { alignment = 16 } : memref<10x42xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @static_memref_dim
 func.func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) {
 // CHECK:  llvm.mlir.constant(42 : index) : i64

>From d361423010b99c59404447edcc96a79fb4213a9e Mon Sep 17 00:00:00 2001
From: tyb0807 <sontuan.vu119 at gmail.com>
Date: Fri, 20 Jun 2025 17:21:35 +0200
Subject: [PATCH 3/3] Add alignment attribute to vector.maskedload/store ops

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 53 ++++++++++++++++++-
 .../Dialect/Vector/load-store-alignment.mlir  | 16 +++++-
 2 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 60a173333bc7c..f9eadb431be51 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1899,7 +1899,9 @@ def Vector_MaskedLoadOp :
     Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
-               AnyVectorOfNonZeroRank:$pass_thru)>,
+               AnyVectorOfNonZeroRank:$pass_thru,
+               ConfinedAttr<OptionalAttr<I64Attr>,
+                            [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
     Results<(outs AnyVectorOfNonZeroRank:$result)> {
 
   let summary = "loads elements from memory into a vector as defined by a mask vector";
@@ -1920,6 +1922,12 @@ def Vector_MaskedLoadOp :
     comes from the pass-through vector regardless of the index, and the index is
     allowed to be out-of-bounds.
 
+    An optional `alignment` attribute allows to specify the byte alignment of the
+    load operation. It must be a positive power of 2. The operation must access
+    memory at an address aligned to this boundary. Violations may lead to
+    architecture-specific faults or performance penalties.
+    When omitted, no specific alignment is required; 0 is not a valid value.
+
     The masked load can be used directly where applicable, or can be used
     during progressively lowering to bring other memory operations closer to
     hardware ISA support for a masked load. The semantics of the operation
@@ -1936,6 +1944,18 @@ def Vector_MaskedLoadOp :
        : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
     ```
   }];
+  let builders = [
+    OpBuilder<(ins "Type":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$mask,
+                   "Value":$pass_thru,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, resultType, base, indices, mask, pass_thru,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>
+  ];
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
@@ -1962,7 +1982,9 @@ def Vector_MaskedStoreOp :
     Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
-               AnyVectorOfNonZeroRank:$valueToStore)> {
+               AnyVectorOfNonZeroRank:$valueToStore,
+               ConfinedAttr<OptionalAttr<I64Attr>,
+                            [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
 
   let summary = "stores elements from a vector into memory as defined by a mask vector";
 
@@ -1982,6 +2004,12 @@ def Vector_MaskedStoreOp :
     is stored regardless of the index, and the index is allowed to be
     out-of-bounds.
 
+    An optional `alignment` attribute allows to specify the byte alignment of the
+    store operation. It must be a positive power of 2. The operation must access
+    memory at an address aligned to this boundary. Violations may lead to
+    architecture-specific faults or performance penalties.
+    When omitted, no specific alignment is required; 0 is not a valid value.
+
     The masked store can be used directly where applicable, or can be used
     during progressively lowering to bring other memory operations closer to
     hardware ISA support for a masked store. The semantics of the operation
@@ -1998,6 +2026,27 @@ def Vector_MaskedStoreOp :
       : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
     ```
   }];
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$mask,
+                   "Value":$valueToStore,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, resultTypes, base, indices, mask, valueToStore,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>,
+    OpBuilder<(ins "Value":$base,
+                   "ValueRange":$indices,
+                   "Value":$mask,
+                   "Value":$valueToStore,
+                   CArg<"uint64_t", "0">:$alignment), [{
+      return build($_builder, $_state, base, indices, mask, valueToStore,
+                   alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
+                                    nullptr);
+    }]>
+  ];
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
diff --git a/mlir/test/Dialect/Vector/load-store-alignment.mlir b/mlir/test/Dialect/Vector/load-store-alignment.mlir
index 38b946f89e7c3..e7beb8a3a9bd0 100644
--- a/mlir/test/Dialect/Vector/load-store-alignment.mlir
+++ b/mlir/test/Dialect/Vector/load-store-alignment.mlir
@@ -1,5 +1,17 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+// CHECK-LABEL: func @test_masked_load_store_alignment
+// CHECK: vector.maskedload {{.*}} {alignment = 16 : i64}
+// CHECK: vector.maskedstore {{.*}} {alignment = 16 : i64}
+func.func @test_masked_load_store_alignment(%memref: memref<4xi32>, %mask: vector<4xi1>, %passthru: vector<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = vector.maskedload %memref[%c0], %mask, %passthru { alignment = 16 } : memref<4xi32>, vector<4xi1>, vector<4xi32> into vector<4xi32>
+  vector.maskedstore %memref[%c0], %mask,  %val { alignment = 16 } : memref<4xi32>, vector<4xi1>, vector<4xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @test_load_store_alignment
 // CHECK: vector.load {{.*}} {alignment = 16 : i64}
 // CHECK: vector.store {{.*}} {alignment = 16 : i64}
@@ -13,6 +25,7 @@ func.func @test_load_store_alignment(%memref: memref<4xi32>) {
 // -----
 
 func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
   // expected-error @+1 {{custom op 'vector.load' 'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
   %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
   return
@@ -21,7 +34,8 @@ func.func @test_invalid_negative_load_alignment(%memref: memref<4xi32>) {
 // -----
 
 func.func @test_invalid_non_power_of_2_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
+  %c0 = arith.constant 0 : index
   // expected-error @+1 {{custom op 'vector.store' 'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
-  vector.store %val, %memref[%c0] { alignment = 1 } : memref<4xi32>, vector<4xi32>
+  vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
   return
 }



More information about the Mlir-commits mailing list