[Mlir-commits] [mlir] ffbe9cf - [mlir][spirv] Propagate alignment requirements from vector to spirv (#155278)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 28 10:06:19 PDT 2025
Author: Erick Ochoa Lopez
Date: 2025-08-28T13:06:16-04:00
New Revision: ffbe9cf99b72326ba703940759218cbffc360f1f
URL: https://github.com/llvm/llvm-project/commit/ffbe9cf99b72326ba703940759218cbffc360f1f
DIFF: https://github.com/llvm/llvm-project/commit/ffbe9cf99b72326ba703940759218cbffc360f1f.diff
LOG: [mlir][spirv] Propagate alignment requirements from vector to spirv (#155278)
Propagates the alignment attribute from `vector.{load,store}` to
`spirv.{load,store}`.
---------
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Added:
mlir/test/Dialect/SPIRV/IR/invalid.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index aad50175546a5..6253601a7c2b2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -220,7 +220,7 @@ def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
let arguments = (ins
SPIRV_AnyPtr:$ptr,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
- OptionalAttr<I32Attr>:$alignment
+ OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
);
let results = (outs
@@ -345,7 +345,7 @@ def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
SPIRV_AnyPtr:$ptr,
SPIRV_Type:$value,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
- OptionalAttr<I32Attr>:$alignment
+ OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
);
let results = (outs);
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a4be7d4bb5473..036cbad0bcfe8 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -743,6 +743,22 @@ struct VectorLoadOpConverter final
auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+ std::optional<uint64_t> alignment = loadOp.getAlignment();
+ if (alignment > std::numeric_limits<uint32_t>::max()) {
+ return rewriter.notifyMatchFailure(loadOp,
+ "invalid alignment requirement");
+ }
+
+ auto memoryAccess = spirv::MemoryAccess::None;
+ spirv::MemoryAccessAttr memoryAccessAttr;
+ IntegerAttr alignmentAttr;
+ if (alignment.has_value()) {
+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+ memoryAccessAttr =
+ spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+ alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+ }
+
// For single element vectors, we don't need to bitcast the access chain to
// the original vector type. Both is going to be the same, a pointer
// to a scalar.
@@ -753,7 +769,8 @@ struct VectorLoadOpConverter final
accessChain);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
- castedAccessChain);
+ castedAccessChain,
+ memoryAccessAttr, alignmentAttr);
return success();
}
@@ -782,6 +799,12 @@ struct VectorStoreOpConverter final
return rewriter.notifyMatchFailure(
storeOp, "failed to get memref element pointer");
+ std::optional<uint64_t> alignment = storeOp.getAlignment();
+ if (alignment > std::numeric_limits<uint32_t>::max()) {
+ return rewriter.notifyMatchFailure(storeOp,
+ "invalid alignment requirement");
+ }
+
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = storeOp.getVectorType();
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
@@ -795,8 +818,19 @@ struct VectorStoreOpConverter final
: spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
accessChain);
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
- adaptor.getValueToStore());
+ auto memoryAccess = spirv::MemoryAccess::None;
+ spirv::MemoryAccessAttr memoryAccessAttr;
+ IntegerAttr alignmentAttr;
+ if (alignment.has_value()) {
+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+ memoryAccessAttr =
+ spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+ alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+ }
+
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+ storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
+ alignmentAttr);
return success();
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8918f91ef9145..4b56897821dbb 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -953,6 +953,14 @@ func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<St
return %0: vector<1xf32>
}
+// CHECK-LABEL: @vector_load_aligned
+func.func @vector_load_aligned(%buf : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  // CHECK: spirv.Load
+  // CHECK-SAME: ["Aligned", 8]
+  %vec = vector.load %buf[%c0] {alignment = 8} : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return %vec : vector<4xf32>
+}
// CHECK-LABEL: @vector_load_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
@@ -996,6 +1004,15 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
return
}
+// CHECK-LABEL: @vector_store_aligned
+func.func @vector_store_aligned(%buf : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %val : vector<4xf32>) {
+  %c0 = arith.constant 0 : index
+  // CHECK: spirv.Store
+  // CHECK-SAME: ["Aligned", 8]
+  vector.store %val, %buf[%c0] {alignment = 8} : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+  return
+}
+
// CHECK-LABEL: @vector_store_single_elem
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32>
diff --git a/mlir/test/Dialect/SPIRV/IR/invalid.mlir b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
new file mode 100644
index 0000000000000..e0100748a0d68
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s
+
+//===----------------------------------------------------------------------===//
+// spirv.LoadOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_load_non_positive() -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %1 = spirv.Load "Function" %0 ["Aligned", 0] : f32
+  return
+}
+
+// -----
+
+func.func @aligned_load_non_power_of_two() -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  %1 = spirv.Load "Function" %0 ["Aligned", 3] : f32
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.StoreOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_store_non_positive(%arg0 : f32) -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  spirv.Store "Function" %0, %arg0 ["Aligned", 0] : f32
+  return
+}
+
+// -----
+
+func.func @aligned_store_non_power_of_two(%arg0 : f32) -> () {
+  %0 = spirv.Variable : !spirv.ptr<f32, Function>
+  // expected-error@below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+  spirv.Store "Function" %0, %arg0 ["Aligned", 3] : f32
+  return
+}
More information about the Mlir-commits mailing list