[Mlir-commits] [mlir] [mlir][spirv] Make `CooperativeMatrixType` a `ShapedType` (PR #142784)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 4 08:16:57 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Igor Wodiany (IgWod-IMG)
<details>
<summary>Changes</summary>
This change makes `CooperativeMatrixType` usable with `DenseElementsAttr`, so that a `spirv.Constant` can easily be built from an `OpConstantComposite`. For example:
```mlir
%cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc>
```
Additional constraints are added to the arithmetic operations because `SameOperandsAndResultType` can no longer fully verify cooperative matrices: for shaped types its verifier only compares element types and shapes, whereas for any other type it requires an exact match. Without the extra check, cooperative matrices that differ only in their use (e.g. `MatrixA` vs. `MatrixB`) would be accepted as matching.
This patch does not enable the actual deserialization; that is done in #142786.
---
Full diff: https://github.com/llvm/llvm-project/pull/142784.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+12-2)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+19-1)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+14-8)
- (modified) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+2-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 22d5afcd77381..48f525e048e60 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -18,12 +18,21 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+class SPIRV_SameCoopMatrix<string lhs, string rhs> : PredOpTrait<
+ "cooperative matrix types match",
+ CPred<"(::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # lhs # ".getType()) "
+ "&& ::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # rhs # ".getType()))"
+ "? $" # lhs # ".getType() == $" # rhs # ".getType() : true">
+>;
+
class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
list<Trait> traits = []> :
// Operands type same as result type.
SPIRV_BinaryOp<mnemonic, type, type,
!listconcat(traits,
- [Pure, SameOperandsAndResultType])> {
+ [Pure, SameOperandsAndResultType,
+ SPIRV_SameCoopMatrix<"operand1", "operand2">,
+ SPIRV_SameCoopMatrix<"operand2", "result">])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
@@ -42,7 +51,8 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
// Operand type same as result type.
SPIRV_UnaryOp<mnemonic, type, type,
!listconcat(traits,
- [Pure, SameOperandsAndResultType])> {
+ [Pure, SameOperandsAndResultType,
+ SPIRV_SameCoopMatrix<"operand", "result">])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 2e29e9afaabf4..a7b6569245dd5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -394,7 +394,8 @@ hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
// SPIR-V KHR cooperative matrix type
class CooperativeMatrixType
: public Type::TypeBase<CooperativeMatrixType, CompositeType,
- detail::CooperativeMatrixTypeStorage> {
+ detail::CooperativeMatrixTypeStorage,
+ ShapedType::Trait> {
public:
using Base::Base;
@@ -418,6 +419,23 @@ class CooperativeMatrixType
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
+
+ operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+ ArrayRef<int64_t> getShape() const;
+
+ bool hasRank() const { return true; }
+
+ CooperativeMatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ if (shape == std::nullopt)
+ return get(elementType, getRows(), getColumns(), getScope(), getUse());
+ else {
+ assert(shape.value().size() == 2);
+ return get(elementType, shape.value()[0], shape.value()[1], getScope(),
+ getUse());
+ }
+ }
};
// SPIR-V matrix type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 337df3a5a65f0..de2034680cd5f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -195,7 +195,7 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
using KeyTy =
- std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
+ std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
static CooperativeMatrixTypeStorage *
construct(TypeStorageAllocator &allocator, const KeyTy &key) {
@@ -204,17 +204,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
}
bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, rows, columns, scope, use);
+ return key == KeyTy(elementType, shape[0], shape[1], scope, use);
}
CooperativeMatrixTypeStorage(const KeyTy &key)
- : elementType(std::get<0>(key)), rows(std::get<1>(key)),
- columns(std::get<2>(key)), scope(std::get<3>(key)),
+ : elementType(std::get<0>(key)),
+ shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
use(std::get<4>(key)) {}
Type elementType;
- uint32_t rows;
- uint32_t columns;
+ // [#rows, #columns]
+ SmallVector<int64_t, 2> shape;
Scope scope;
CooperativeMatrixUseKHR use;
};
@@ -231,10 +231,16 @@ Type CooperativeMatrixType::getElementType() const {
return getImpl()->elementType;
}
-uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
+uint32_t CooperativeMatrixType::getRows() const {
+ return static_cast<uint32_t>(getImpl()->shape[0]);
+}
uint32_t CooperativeMatrixType::getColumns() const {
- return getImpl()->columns;
+ return static_cast<uint32_t>(getImpl()->shape[1]);
+}
+
+ArrayRef<int64_t> CooperativeMatrixType::getShape() const {
+ return getImpl()->shape;
}
Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index d3e1dbc229ef9..4ae8b70bf43ca 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
%b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
- // expected-error @+1 {{op requires the same type for all operands and results}}
+ // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
%q = "spirv.IAdd"(%a, %b) :
(!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
-> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
@@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
%b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
- // expected-error @+1 {{op requires the same type for all operands and results}}
+ // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
%q = "spirv.FAdd"(%a, %b) :
(!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
-> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>
``````````
</details>
https://github.com/llvm/llvm-project/pull/142784
More information about the Mlir-commits
mailing list