[Mlir-commits] [mlir] [mlir][spirv] Propagate alignment requirements from vector to spirv (PR #155278)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Mon Aug 25 12:16:19 PDT 2025
https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/155278
None
>From f0de6ef4edffe02f1f9047c92f24314752a9d850 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 25 Aug 2025 08:26:08 -0700
Subject: [PATCH 1/3] Add IntValidAlignment predicate
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 24 +++++++------------
mlir/include/mlir/IR/CommonAttrConstraints.td | 6 ++++++
2 files changed, 14 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bcc423a634148..fd1c0159e3aba 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1720,8 +1720,7 @@ def Vector_LoadOp : Vector_Op<"load", [
[MemRead]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
- ConfinedAttr<OptionalAttr<I64Attr>,
- [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+ OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
let builders = [
OpBuilder<(ins "VectorType":$resultType,
@@ -1837,8 +1836,7 @@ def Vector_StoreOp : Vector_Op<"store", [
[MemWrite]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
- ConfinedAttr<OptionalAttr<I64Attr>,
- [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
+ OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
let builders = [
OpBuilder<(ins "Value":$valueToStore,
@@ -1875,8 +1873,7 @@ def Vector_MaskedLoadOp :
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$pass_thru,
- ConfinedAttr<OptionalAttr<I64Attr>,
- [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
+ OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = "loads elements from memory into a vector as defined by a mask vector";
@@ -1968,8 +1965,7 @@ def Vector_MaskedStoreOp :
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$valueToStore,
- ConfinedAttr<OptionalAttr<I64Attr>,
- [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
+ OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)> {
let summary = "stores elements from a vector into memory as defined by a mask vector";
@@ -2051,8 +2047,7 @@ def Vector_GatherOp :
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$pass_thru,
- ConfinedAttr<OptionalAttr<I64Attr>,
- [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
+ OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = [{
@@ -2154,8 +2149,7 @@ def Vector_ScatterOp :
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
VectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$valueToStore,
- ConfinedAttr<OptionalAttr<I64Attr>,
- [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
+ OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)> {
let summary = [{
scatters elements from a vector into memory as defined by an index vector
@@ -2239,8 +2233,7 @@ def Vector_ExpandLoadOp :
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$pass_thru,
- ConfinedAttr<OptionalAttr<I64Attr>,
- [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)>,
+ OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
@@ -2328,8 +2321,7 @@ def Vector_CompressStoreOp :
Variadic<Index>:$indices,
FixedVectorOfNonZeroRankOf<[I1]>:$mask,
AnyVectorOfNonZeroRank:$valueToStore,
- ConfinedAttr<OptionalAttr<I64Attr>,
- [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment)> {
+ OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)> {
let summary = "writes elements selectively from a vector as defined by a mask";
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 18da85a580710..e1869c1821b11 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -800,6 +800,12 @@ def IntPowerOf2 : AttrConstraint<
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isPowerOf2()">,
"whose value is a power of two > 0">;
+// Integer attribute constraint: value is strictly positive and a power of two.
+def IntPositivePowerOf2 : AllAttrOf<[IntPositive, IntPowerOf2]>;
+
+// Confines an integer attribute to valid alignment values (powers of two > 0).
+class IntValidAlignment<Attr attr> : ConfinedAttr<attr, [IntPositivePowerOf2]>;
+
class ArrayMaxCount<int n> : AttrConstraint<
CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
"with at most " # n # " elements">;
>From a2b8f6b40c5c078a3a813c5983a774b0d3f0f432 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 25 Aug 2025 11:25:59 -0700
Subject: [PATCH 2/3] [mlir][spirv] Constrain alignment attribute
---
.../mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td | 4 +-
mlir/test/Dialect/SPIRV/IR/invalid.mlir | 43 +++++++++++++++++++
2 files changed, 45 insertions(+), 2 deletions(-)
create mode 100644 mlir/test/Dialect/SPIRV/IR/invalid.mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
index aad50175546a5..6253601a7c2b2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td
@@ -220,7 +220,7 @@ def SPIRV_LoadOp : SPIRV_Op<"Load", []> {
let arguments = (ins
SPIRV_AnyPtr:$ptr,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
- OptionalAttr<I32Attr>:$alignment
+ OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
);
let results = (outs
@@ -345,7 +345,7 @@ def SPIRV_StoreOp : SPIRV_Op<"Store", []> {
SPIRV_AnyPtr:$ptr,
SPIRV_Type:$value,
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access,
- OptionalAttr<I32Attr>:$alignment
+ OptionalAttr<IntValidAlignment<I32Attr>>:$alignment
);
let results = (outs);
diff --git a/mlir/test/Dialect/SPIRV/IR/invalid.mlir b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
new file mode 100644
index 0000000000000..72eb9883a6538
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/invalid.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+//===----------------------------------------------------------------------===//
+// spirv.LoadOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_load_non_positive() -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error @below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %1 = spirv.Load "Function" %0 ["Aligned", 0] : f32
+ return
+}
+
+// -----
+
+func.func @aligned_load_non_power_of_two() -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error @below {{'spirv.Load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ %1 = spirv.Load "Function" %0 ["Aligned", 3] : f32
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.StoreOp
+//===----------------------------------------------------------------------===//
+
+func.func @aligned_store_non_positive(%arg0 : f32) -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error @below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ spirv.Store "Function" %0, %arg0 ["Aligned", 0] : f32
+ return
+}
+
+// -----
+
+func.func @aligned_store_non_power_of_two(%arg0 : f32) -> () {
+ %0 = spirv.Variable : !spirv.ptr<f32, Function>
+ // expected-error @below {{'spirv.Store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
+ spirv.Store "Function" %0, %arg0 ["Aligned", 3] : f32
+ return
+}
>From 9fd4b1e5e470ce19b7ebd5564a02caf2a955b846 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Mon, 25 Aug 2025 11:10:19 -0700
Subject: [PATCH 3/3] [mlir] Propagate alignment attribute in VectorToSPIRV.
---
.../VectorToSPIRV/VectorToSPIRV.cpp | 45 +++++++++++++++++--
.../VectorToSPIRV/vector-to-spirv.mlir | 18 ++++++++
2 files changed, 60 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a4be7d4bb5473..5f5d3e0834dad 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -743,6 +743,25 @@ struct VectorLoadOpConverter final
auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+ auto alignment = loadOp.getAlignment();
+ // The alignment is materialized with getI32IntegerAttr (int32_t) and the
+ // spirv.Load attribute must be positive, so reject values above INT32_MAX.
+ if (alignment.has_value() &&
+ alignment > std::numeric_limits<int32_t>::max()) {
+ return rewriter.notifyMatchFailure(loadOp,
+ "invalid alignment requirement");
+ }
+
+ auto memoryAccess = spirv::MemoryAccess::None;
+ auto memoryAccessAttr = spirv::MemoryAccessAttr{};
+ IntegerAttr alignmentAttr = nullptr;
+ if (alignment.has_value()) {
+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+ memoryAccessAttr =
+ spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+ alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+ }
+
// For single element vectors, we don't need to bitcast the access chain to
// the original vector type. Both is going to be the same, a pointer
// to a scalar.
@@ -753,7 +772,8 @@ struct VectorLoadOpConverter final
accessChain);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
- castedAccessChain);
+ castedAccessChain,
+ memoryAccessAttr, alignmentAttr);
return success();
}
@@ -782,6 +802,13 @@ struct VectorStoreOpConverter final
return rewriter.notifyMatchFailure(
storeOp, "failed to get memref element pointer");
+ auto alignment = storeOp.getAlignment();
+ // The spirv.Store alignment must fit in a positive int32_t.
+ if (alignment && alignment > std::numeric_limits<int32_t>::max()) {
+ return rewriter.notifyMatchFailure(storeOp,
+ "invalid alignment requirement");
+ }
+
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = storeOp.getVectorType();
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
@@ -795,8 +822,20 @@ struct VectorStoreOpConverter final
: spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
accessChain);
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
- adaptor.getValueToStore());
+ auto memoryAccess = spirv::MemoryAccess::None;
+ auto memoryAccessAttr = spirv::MemoryAccessAttr{};
+ IntegerAttr alignmentAttr = nullptr;
+ if (alignment.has_value()) {
+ memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
+ memoryAccessAttr =
+ spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
+ alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
+ }
+
+ // Both attributes stay null when the vector.store had no alignment.
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+ storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
+ alignmentAttr);
return success();
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8918f91ef9145..4b56897821dbb 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -953,6 +953,15 @@ func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<St
return %0: vector<1xf32>
}
+// CHECK-LABEL: @vector_load_aligned
+func.func @vector_load_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
+ %idx = arith.constant 0 : index
+ // CHECK: spirv.Load
+ // CHECK-SAME: ["Aligned", 8]
+ %0 = vector.load %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return %0: vector<4xf32>
+}
+
// CHECK-LABEL: @vector_load_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
@@ -996,6 +1005,15 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
return
}
+// CHECK-LABEL: @vector_store_aligned
+func.func @vector_store_aligned(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
+ %idx = arith.constant 0 : index
+ // CHECK: spirv.Store
+ // CHECK-SAME: ["Aligned", 8]
+ vector.store %arg1, %arg0[%idx] { alignment = 8 } : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
+ return
+}
+
// CHECK-LABEL: @vector_store_single_elem
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32>
More information about the Mlir-commits
mailing list