[Mlir-commits] [mlir] ffbe9cf - [mlir][spirv] Propagate alignment requirements from vector to spirv (#155278)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 28 10:06:19 PDT 2025


Author: Erick Ochoa Lopez
Date: 2025-08-28T13:06:16-04:00
New Revision: ffbe9cf99b72326ba703940759218cbffc360f1f

URL: https://github.com/llvm/llvm-project/commit/ffbe9cf99b72326ba703940759218cbffc360f1f
DIFF: https://github.com/llvm/llvm-project/commit/ffbe9cf99b72326ba703940759218cbffc360f1f.diff

LOG: [mlir][spirv] Propagate alignment requirements from vector to spirv (#155278)

Propagates the alignment attribute from `vector.{load,store}` to
`spirv.{load,store}`.

---------

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    mlir/test/Dialect/SPIRV/IR/invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index aad50175546a5..6253601a7c2b2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -220,7 +220,7 @@ def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
   let arguments = (ins
     SPIRV_AnyPtr:$ptr,
     OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
-    OptionalAttr<I32Attr>:$alignment
+    OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
   );
 
   let results = (outs
@@ -345,7 +345,7 @@ def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
     SPIRV_AnyPtr:$ptr,
     SPIRV_Type:$value,
     OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
-    OptionalAttr<I32Attr>:$alignment
+    OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
   );
 
   let results = (outs);

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a4be7d4bb5473..036cbad0bcfe8 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -743,6 +743,22 @@ struct VectorLoadOpConverter final
 
     auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
 
+    std::optional<uint64_t> alignment = loadOp.getAlignment();
+    if (alignment > std::numeric_limits<uint32_t>::max()) {
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "invalid alignment requirement");
+    }
+
+    auto memoryAccess = spirv::MemoryAccess::None;
+    spirv::MemoryAccessAttr memoryAccessAttr;
+    IntegerAttr alignmentAttr;
+    if (alignment.has_value()) {
+      memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+      memoryAccessAttr =
+          spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+      alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+    }
+
     // For single element vectors, we don't need to bitcast the access chain to
     // the original vector type. Both is going to be the same, a pointer
     // to a scalar.
@@ -753,7 +769,8 @@ struct VectorLoadOpConverter final
                                        accessChain);
 
     rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
-                                               castedAccessChain);
+                                               castedAccessChain,
+                                               memoryAccessAttr, alignmentAttr);
 
     return success();
   }
@@ -782,6 +799,12 @@ struct VectorStoreOpConverter final
       return rewriter.notifyMatchFailure(
           storeOp, "failed to get memref element pointer");
 
+    std::optional<uint64_t> alignment = storeOp.getAlignment();
+    if (alignment > std::numeric_limits<uint32_t>::max()) {
+      return rewriter.notifyMatchFailure(storeOp,
+                                         "invalid alignment requirement");
+    }
+
     spirv::StorageClass storageClass = attr.getValue();
     auto vectorType = storeOp.getVectorType();
     auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
@@ -795,8 +818,19 @@ struct VectorStoreOpConverter final
             : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
                                        accessChain);
 
-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
-                                                adaptor.getValueToStore());
+    auto memoryAccess = spirv::MemoryAccess::None;
+    spirv::MemoryAccessAttr memoryAccessAttr;
+    IntegerAttr alignmentAttr;
+    if (alignment.has_value()) {
+      memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+      memoryAccessAttr =
+          spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+      alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+    }
+
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+        storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
+        alignmentAttr);
 
     return success();
   }

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8918f91ef9145..4b56897821dbb 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -953,6 +953,14 @@ func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<St
   return %0: vector<1xf32>
 }
 
+// CHECK-LABEL: @vector_load_aligned
+func.func @vector_load_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %idx = arith.constant 0 : index
+  // CHECK: spirv.Load
+  // CHECK-SAME: ["Aligned", 8]
+  %0 = vector.load %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return %0: vector<4xf32>
+}
 
 // CHECK-LABEL: @vector_load_2d
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
@@ -996,6 +1004,15 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
   return
 }
 
+// CHECK-LABEL: @vector_store_aligned
+func.func @vector_store_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+  %idx = arith.constant 0 : index
+  // CHECK: spirv.Store
+  // CHECK-SAME: ["Aligned", 8]
+  vector.store %arg1, %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_store_single_elem
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
 //  CHECK-SAME:  %[[ARG1:.*]]: vector<1xf32>

diff  --git a/mlir/test/Dialect/SPIRV/IR/invalid.mlir b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
new file mode 100644
index 0000000000000..e0100748a0d68
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s
+
+//===----------------------------------------------------------------------===//
+// spirv.LoadOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_load_non_positive() -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %1 = spirv.Load "Function" %0 ["Aligned", 0] : f32
+  return
+}
+
+// -----
+
+func.func @aligned_load_non_power_of_two() -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %1 = spirv.Load "Function" %0 ["Aligned", 3] : f32
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.StoreOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_store_non_positive(%arg0 : f32) -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  spirv.Store "Function" %0, %arg0 ["Aligned", 0] : f32
+  return
+}
+
+// -----
+
+func.func @aligned_store_non_power_of_two(%arg0 : f32) -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  spirv.Store "Function" %0, %arg0 ["Aligned", 3] : f32
+  return
+}


        


More information about the Mlir-commits mailing list