[Mlir-commits] [mlir] [mlir][spirv] Propagate alignment requirements from vector to spirv (PR #155278)

Erick Ochoa Lopez llvmlistbot at llvm.org
Mon Aug 25 12:16:19 PDT 2025


https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/155278

None

>From f0de6ef4edffe02f1f9047c92f24314752a9d850 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 25 Aug 2025 08:26:08 -0700
Subject: [PATCH 1/3] Add IntValidAlignment predicate

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 24 +++++++------------
 mlir/include/mlir/IR/CommonAttrConstraints.td |  7 +++++++
 2 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bcc423a634148..fd1c0159e3aba 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1720,8 +1720,7 @@ def Vector_LoadOp : Vector_Op<"load", [
       [MemRead]>:$base,
       Variadic<Index>:$indices,
       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
-      ConfinedAttr<OptionalAttr<I64Attr>,
-                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+      OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);

   let builders = [
     OpBuilder<(ins "VectorType":$resultType,
@@ -1837,8 +1836,7 @@ def Vector_StoreOp : Vector_Op<"store", [
       [MemWrite]>:$base,
       Variadic<Index>:$indices,
       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
-      ConfinedAttr<OptionalAttr<I64Attr>,
-                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+      OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);

   let builders = [
     OpBuilder<(ins "Value":$valueToStore,
@@ -1875,8 +1873,7 @@ def Vector_MaskedLoadOp :
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$pass_thru,
-               ConfinedAttr<OptionalAttr<I64Attr>,
-                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
+               OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
     Results<(outs AnyVectorOfNonZeroRank:$result)> {

   let summary = "loads elements from memory into a vector as defined by a mask vector";
@@ -1968,8 +1965,7 @@ def Vector_MaskedStoreOp :
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$valueToStore,
-               ConfinedAttr<OptionalAttr<I64Attr>,
-                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
+               OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)> {

   let summary = "stores elements from a vector into memory as defined by a mask vector";

@@ -2051,8 +2047,7 @@ def Vector_GatherOp :
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
                VectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$pass_thru,
-               ConfinedAttr<OptionalAttr<I64Attr>,
-                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
+               OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
     Results<(outs AnyVectorOfNonZeroRank:$result)> {

   let summary = [{
@@ -2154,8 +2149,7 @@ def Vector_ScatterOp :
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
                VectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$valueToStore,
-               ConfinedAttr<OptionalAttr<I64Attr>,
-                   [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
+               OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)> {

   let summary = [{
     scatters elements from a vector into memory as defined by an index vector
@@ -2239,8 +2233,7 @@ def Vector_ExpandLoadOp :
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$pass_thru,
-               ConfinedAttr<OptionalAttr<I64Attr>,
-                            [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
+               OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
     Results<(outs AnyVectorOfNonZeroRank:$result)> {

   let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
@@ -2328,8 +2321,7 @@ def Vector_CompressStoreOp :
                Variadic<Index>:$indices,
                FixedVectorOfNonZeroRankOf<[I1]>:$mask,
                AnyVectorOfNonZeroRank:$valueToStore,
-               ConfinedAttr<OptionalAttr<I64Attr>,
-                            [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
+               OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)> {

   let summary = "writes elements selectively from a vector as defined by a mask";

diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 18da85a580710..e1869c1821b11 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -800,6 +800,13 @@ def IntPowerOf2 : AttrConstraint<
     CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isPowerOf2()">,
     "whose value is a power of two > 0">;
 
+// Constrains an integer attribute to be strictly positive and a power of two.
+def IntPositivePowerOf2 : AllAttrOf<[IntPositive, IntPowerOf2]>;
+
+// An integer attribute of type `attr` whose value is a valid alignment, i.e.
+// strictly positive and a power of two.
+class IntValidAlignment<Attr attr> : ConfinedAttr<attr, [IntPositivePowerOf2]>;
+
 class ArrayMaxCount<int n> : AttrConstraint<
     CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
     "with at most " # n # " elements">;

>From a2b8f6b40c5c078a3a813c5983a774b0d3f0f432 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 25 Aug 2025 11:25:59 -0700
Subject: [PATCH 2/3] [mlir][spirv] Constrain alignment attribute

---
 .../mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td   |  4 +-
 mlir/test/Dialect/SPIRV/IR/invalid.mlir       | 43 +++++++++++++++++++
 2 files changed, 45 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/SPIRV/IR/invalid.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index aad50175546a5..6253601a7c2b2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -220,7 +220,7 @@ def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
   let arguments = (ins
     SPIRV_AnyPtr:$ptr,
     OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
-    OptionalAttr<I32Attr>:$alignment
+    OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
   );

   let results = (outs
@@ -345,7 +345,7 @@ def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
     SPIRV_AnyPtr:$ptr,
     SPIRV_Type:$value,
     OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
-    OptionalAttr<I32Attr>:$alignment
+    OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
   );

   let results = (outs);
diff --git a/mlir/test/Dialect/SPIRV/IR/invalid.mlir b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
new file mode 100644
index 0000000000000..72eb9883a6538
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+//===----------------------------------------------------------------------===//
+// spirv.LoadOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_load_non_positive() -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error at below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %1 = spirv.Load "Function" %0 ["Aligned", 0] : f32
+  return
+}
+
+// -----
+
+func.func @aligned_load_non_power_of_two() -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error at below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %1 = spirv.Load "Function" %0 ["Aligned", 3] : f32
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.StoreOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_store_non_positive(%arg0 : f32) -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error at below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  spirv.Store  "Function" %0, %arg0 ["Aligned", 0] : f32
+  return
+}
+
+// -----
+
+func.func @aligned_store_non_power_of_two(%arg0 : f32) -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error at below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  spirv.Store  "Function" %0, %arg0 ["Aligned", 3] : f32
+  return
+}

>From 9fd4b1e5e470ce19b7ebd5564a02caf2a955b846 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 25 Aug 2025 11:10:19 -0700
Subject: [PATCH 3/3] [mlir] Propagate alignment attribute in VectorToSPIRV.

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 46 +++++++++++++++++--
 .../VectorToSPIRV/vector-to-spirv.mlir        | 17 ++++++++
 2 files changed, 60 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a4be7d4bb5473..5f5d3e0834dad 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -743,6 +743,25 @@ struct VectorLoadOpConverter final

     auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);

+    // The spirv.Load alignment attribute is a positive signed i32, so values
+    // above INT32_MAX cannot be represented and must be rejected here.
+    auto alignment = loadOp.getAlignment();
+    if (alignment.has_value() &&
+        alignment > std::numeric_limits<int32_t>::max()) {
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "invalid alignment requirement");
+    }
+
+    auto memoryAccess = spirv::MemoryAccess::None;
+    auto memoryAccessAttr = spirv::MemoryAccessAttr{};
+    IntegerAttr alignmentAttr = nullptr;
+    if (alignment.has_value()) {
+      memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+      memoryAccessAttr =
+          spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+      alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+    }
+
     // For single element vectors, we don't need to bitcast the access chain to
     // the original vector type. Both is going to be the same, a pointer
     // to a scalar.
@@ -753,7 +772,8 @@ struct VectorLoadOpConverter final
                                        accessChain);

     rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
-                                               castedAccessChain);
+                                               castedAccessChain,
+                                               memoryAccessAttr, alignmentAttr);

     return success();
   }
@@ -782,6 +802,14 @@ struct VectorStoreOpConverter final
       return rewriter.notifyMatchFailure(
           storeOp, "failed to get memref element pointer");

+    // The spirv.Store alignment attribute is a positive signed i32, so values
+    // above INT32_MAX cannot be represented and must be rejected here.
+    auto alignment = storeOp.getAlignment();
+    if (alignment && alignment > std::numeric_limits<int32_t>::max()) {
+      return rewriter.notifyMatchFailure(storeOp,
+                                         "invalid alignment requirement");
+    }
+
     spirv::StorageClass storageClass = attr.getValue();
     auto vectorType = storeOp.getVectorType();
     auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
@@ -795,8 +823,20 @@ struct VectorStoreOpConverter final
             : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
                                        accessChain);

-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
-                                                adaptor.getValueToStore());
+    auto memoryAccess = spirv::MemoryAccess::None;
+    auto memoryAccessAttr = spirv::MemoryAccessAttr{};
+    IntegerAttr alignmentAttr = nullptr;
+    if (alignment.has_value()) {
+      memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+      memoryAccessAttr =
+          spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+      alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+    }
+
+    // Emit the store with the optional `Aligned` memory access operand.
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+        storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
+        alignmentAttr);

     return success();
   }
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8918f91ef9145..4b56897821dbb 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -953,6 +953,14 @@ func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<St
   return %0: vector<1xf32>
 }

+// CHECK-LABEL: @vector_load_aligned
+func.func @vector_load_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx = arith.constant 0 : index
+  // CHECK: spirv.Load
+  // CHECK-SAME: ["Aligned", 8]
+  %0 = vector.load %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return %0: vector<4xf32>
+}

 // CHECK-LABEL: @vector_load_2d
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
@@ -996,6 +1004,15 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
   return
 }

+// CHECK-LABEL: @vector_store_aligned
+func.func @vector_store_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx = arith.constant 0 : index
+  // CHECK: spirv.Store
+  // CHECK-SAME: ["Aligned", 8]
+  vector.store %arg1, %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_store_single_elem
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
 //  CHECK-SAME:  %[[ARG1:.*]]: vector<1xf32>



More information about the Mlir-commits mailing list