[Mlir-commits] [mlir] Add `alignment` attribute to MemRef/Vector memory access ops (PR #144344)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 16 05:48:43 PDT 2025


https://github.com/tyb0807 created https://github.com/llvm/llvm-project/pull/144344

None

>From a474d1958190716f17cf5c9a84f54de8044b5be4 Mon Sep 17 00:00:00 2001
From: tyb0807 <sontuan.vu119 at gmail.com>
Date: Mon, 16 Jun 2025 14:45:40 +0200
Subject: [PATCH] Add `alignment` attribute to MemRef/Vector memory access ops

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       | 61 +++++++++++-
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 50 +++++++++-
 .../Dialect/MemRef/load-store-alignment.mlir  | 27 ++++++
 .../Dialect/Vector/load-store-alignment.mlir  | 27 ++++++
 mlir/unittests/Dialect/CMakeLists.txt         |  1 +
 mlir/unittests/Dialect/MemRef/CMakeLists.txt  |  1 +
 .../Dialect/MemRef/LoadStoreAlignment.cpp     | 88 +++++++++++++++++
 mlir/unittests/Dialect/Vector/CMakeLists.txt  |  7 ++
 .../Dialect/Vector/LoadStoreAlignment.cpp     | 95 +++++++++++++++++++
 9 files changed, 352 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Dialect/MemRef/load-store-alignment.mlir
 create mode 100644 mlir/test/Dialect/Vector/load-store-alignment.mlir
 create mode 100644 mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp
 create mode 100644 mlir/unittests/Dialect/Vector/CMakeLists.txt
 create mode 100644 mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 77e3074661abf..160b04e452c5a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1227,7 +1227,45 @@ def LoadOp : MemRef_Op<"load",
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
                            [MemRead]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I32Attr>,
+                                    [IntPositive]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, memref, indices, false, alignment);
+    }]>,
+    OpBuilder<(ins "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "Type":$resultType,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultType, memref, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Type":$resultType,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, resultType, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultTypes, memref, indices, false,
+                   alignment);
+    }]>
+  ];
+
   let results = (outs AnyType:$result);
 
   let extraClassDeclaration = [{
@@ -1924,13 +1962,30 @@ def MemRef_StoreOp : MemRef_Op<"store",
                        Arg<AnyMemRef, "the reference to store to",
                            [MemWrite]>:$memref,
                        Variadic<Index>:$indices,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+                       ConfinedAttr<OptionalAttr<I32Attr>,
+                                    [IntPositive]>:$alignment);
 
   let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, valueToStore, memref, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$memref,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, valueToStore, memref, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
     OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
       $_state.addOperands(valueToStore);
       $_state.addOperands(memref);
-    }]>];
+    }]>
+  ];
 
   let extraClassDeclaration = [{
       Value getValueToStore() { return getOperand(0); }
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..3cd71491bcc04 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1739,7 +1739,34 @@ def Vector_LoadOp : Vector_Op<"load"> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I32Attr>,
+                   [IntPositive]>:$alignment);
+
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultType, base, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "VectorType":$resultType,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, resultType, base, indices, nontemporal,
+                   IntegerAttr());
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, resultTypes, base, indices, false,
+                   alignment);
+    }]>
+  ];
+
   let results = (outs AnyVectorOfAnyRank:$result);
 
   let extraClassDeclaration = [{
@@ -1825,9 +1852,28 @@ def Vector_StoreOp : Vector_Op<"store"> {
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$base,
       Variadic<Index>:$indices,
-      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
+      DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
+      ConfinedAttr<OptionalAttr<I32Attr>,
+                   [IntPositive]>:$alignment
   );
 
+  let builders = [
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
+      return build($_builder, $_state, valueToStore, base, indices, false,
+                   alignment);
+    }]>,
+    OpBuilder<(ins "Value":$valueToStore,
+                   "Value":$base,
+                   "ValueRange":$indices,
+                   "bool":$nontemporal), [{
+      return build($_builder, $_state, valueToStore, base, indices, nontemporal,
+                   IntegerAttr());
+    }]>
+  ];
+
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
       return ::llvm::cast<MemRefType>(getBase().getType());
diff --git a/mlir/test/Dialect/MemRef/load-store-alignment.mlir b/mlir/test/Dialect/MemRef/load-store-alignment.mlir
new file mode 100644
index 0000000000000..4f5a5461e0ac0
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/load-store-alignment.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func @test_load_store_alignment
+// CHECK: memref.load {{.*}} {alignment = 16 : i32}
+// CHECK: memref.store {{.*}} {alignment = 16 : i32}
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = memref.load %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>
+  memref.store %val, %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_load_alignment(%memref: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'memref.load' 'memref.load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  %val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_store_alignment(%memref: memref<4xi32>, %val: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'memref.store' 'memref.store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  memref.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/load-store-alignment.mlir b/mlir/test/Dialect/Vector/load-store-alignment.mlir
new file mode 100644
index 0000000000000..4f54d989dd190
--- /dev/null
+++ b/mlir/test/Dialect/Vector/load-store-alignment.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: func @test_load_store_alignment
+// CHECK: vector.load {{.*}} {alignment = 16 : i32}
+// CHECK: vector.store {{.*}} {alignment = 16 : i32}
+func.func @test_load_store_alignment(%memref: memref<4xi32>) {
+  %c0 = arith.constant 0 : index
+  %val = vector.load %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
+  vector.store %val, %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_load_alignment(%memref: memref<4xi32>) {
+  // expected-error @+1 {{custom op 'vector.load' 'vector.load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+  return
+}
+
+// -----
+
+func.func @test_invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
+  // expected-error @+1 {{custom op 'vector.store' 'vector.store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
+  vector.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
+  return
+}
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index aea247547473d..34c9fb7317443 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -18,3 +18,4 @@ add_subdirectory(SPIRV)
 add_subdirectory(SMT)
 add_subdirectory(Transform)
 add_subdirectory(Utils)
+add_subdirectory(Vector)
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index dede3ba0a885c..87d33854fadcd 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRMemRefTests
   InferShapeTest.cpp
+  LoadStoreAlignment.cpp
 )
 mlir_target_link_libraries(MLIRMemRefTests
   PRIVATE
diff --git a/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp b/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp
new file mode 100644
index 0000000000000..f0b8e93c2d0e1
--- /dev/null
+++ b/mlir/unittests/Dialect/MemRef/LoadStoreAlignment.cpp
@@ -0,0 +1,88 @@
+//===- LoadStoreAlignment.cpp - unit tests for load/store alignment -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Verifier.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+TEST(LoadStoreAlignmentTest, ValidAlignment) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+
+  // Create a dummy memref
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  // Create load with valid alignment
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), 16);
+  auto loadOp =
+      b.create<LoadOp>(b.getUnknownLoc(), memref, ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  auto alignmentAttr = loadOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+
+  // Create store with valid alignment
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  alignmentAttr = storeOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+}
+
+TEST(LoadStoreAlignmentTest, InvalidAlignmentFailsVerification) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), -1);
+
+  auto loadOp =
+      b.create<LoadOp>(b.getUnknownLoc(), memref, ValueRange{zero}, alignment);
+
+  // Capture diagnostics
+  std::string errorMessage;
+  ScopedDiagnosticHandler handler(
+      &ctx, [&](Diagnostic &diag) { errorMessage = diag.str(); });
+
+  // Trigger verification
+  auto result = mlir::verify(loadOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'memref.load' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+  result = mlir::verify(storeOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'memref.store' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+}
diff --git a/mlir/unittests/Dialect/Vector/CMakeLists.txt b/mlir/unittests/Dialect/Vector/CMakeLists.txt
new file mode 100644
index 0000000000000..b23d9c2df3870
--- /dev/null
+++ b/mlir/unittests/Dialect/Vector/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRVectorTests
+  LoadStoreAlignment.cpp
+)
+mlir_target_link_libraries(MLIRVectorTests
+  PRIVATE
+  MLIRVectorDialect
+  )
diff --git a/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp b/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp
new file mode 100644
index 0000000000000..745dd8632fe4d
--- /dev/null
+++ b/mlir/unittests/Dialect/Vector/LoadStoreAlignment.cpp
@@ -0,0 +1,95 @@
+//===- LoadStoreAlignment.cpp - unit tests for load/store alignment -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Verifier.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+TEST(LoadStoreAlignmentTest, ValidAlignment) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+  ctx.loadDialect<vector::VectorDialect>();
+
+  // Create a dummy memref
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  VectorType elemVecTy = VectorType::get({2}, elementType);
+
+  // Create load with valid alignment
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), 16);
+  auto loadOp = b.create<LoadOp>(b.getUnknownLoc(), elemVecTy, memref,
+                                 ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  auto alignmentAttr = loadOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+
+  // Create store with valid alignment
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+
+  // Verify the attribute exists
+  alignmentAttr = storeOp->getAttrOfType<IntegerAttr>("alignment");
+  EXPECT_TRUE(alignmentAttr != nullptr);
+  EXPECT_EQ(alignmentAttr.getInt(), 16);
+}
+
+TEST(LoadStoreAlignmentTest, InvalidAlignmentFailsVerification) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+  ctx.loadDialect<memref::MemRefDialect>();
+  ctx.loadDialect<vector::VectorDialect>();
+
+  Type elementType = b.getI32Type();
+  auto memrefType = MemRefType::get({4}, elementType);
+  Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
+
+  VectorType elemVecTy = VectorType::get({2}, elementType);
+
+  Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
+  IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), -1);
+
+  auto loadOp = b.create<LoadOp>(b.getUnknownLoc(), elemVecTy, memref,
+                                 ValueRange{zero}, alignment);
+
+  // Capture diagnostics
+  std::string errorMessage;
+  ScopedDiagnosticHandler handler(
+      &ctx, [&](Diagnostic &diag) { errorMessage = diag.str(); });
+
+  // Trigger verification
+  auto result = mlir::verify(loadOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'vector.load' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+
+  auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
+                                   ValueRange{zero}, alignment);
+  result = mlir::verify(storeOp);
+
+  // Check results
+  EXPECT_TRUE(failed(result));
+  EXPECT_EQ(
+      errorMessage,
+      "'vector.store' op attribute 'alignment' failed to satisfy constraint: "
+      "32-bit signless integer attribute whose value is positive");
+}



More information about the Mlir-commits mailing list