[Mlir-commits] [mlir] [mlir][spirv] Add support for Constant Matrices (PR #123334)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 17 05:15:52 PST 2025
https://github.com/mishaobu created https://github.com/llvm/llvm-project/pull/123334
None
>From def92f0833bdd32dbfc92022370b065ffd74b332 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 14:10:17 +0100
Subject: [PATCH 1/4] [mlir][spirv] Add tests for matrix constants and matrix CompositeConstruct
---
mlir/test/Dialect/SPIRV/IR/composite-ops.mlir | 29 +++++++++++++++++++
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 28 +++++++++++++++++-
mlir/test/Target/SPIRV/composite-op.mlir | 5 ++++
mlir/test/Target/SPIRV/constant.mlir | 11 +++++++
4 files changed, 72 insertions(+), 1 deletion(-)
Review notes:
  This patch only adds or updates test files; it changes no non-test code.
  - composite-ops.mlir: a valid spirv.CompositeConstruct producing a
    !spirv.matrix<3 x vector<3xf32>>, plus negative cases for too few
    columns, a wrong column size (vector<4xf32>), and a wrong element
    type (vector<3xi32>).
  - structure-ops.mlir: a matrix spirv.Constant in @const and in
    @matrix_constant, the reworded "spirv.array or spirv.matrix"
    diagnostic, a wrong column count, and a non-dense column attribute.
  - Target/SPIRV/composite-op.mlir and constant.mlir: a matrix
    CompositeConstruct and a stored identity-matrix constant; presumably
    these run as serialization round trips (RUN lines are not shown
    here, so confirm).
  NOTE(review): the matrix-constant verifier's "has incorrect size" and
  "has incorrect element type" diagnostics have no negative test in the
  tests shown here.

diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 3fc8dfb2767d1e..5c835d2e08de91 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -11,6 +11,13 @@ func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> ve
return %0: vector<3xf32>
}
+// CHECK-LABEL: func @composite_construct_matrix
+func.func @composite_construct_matrix(%v1: vector<3xf32>, %v2: vector<3xf32>, %v3: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
+ // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+ %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+ return %0: !spirv.matrix<3 x vector<3xf32>>
+}
+
// CHECK-LABEL: func @composite_construct_struct
func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
// CHECK: spirv.CompositeConstruct
@@ -89,9 +96,31 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
%0 = spirv.CompositeConstruct %arg0, %arg2 : (f32, vector<2xf32>) -> vector<4xf32>
return %0: vector<4xf32>
}
+// -----
+
+func.func @composite_construct_matrix_wrong_column_count(%v1: vector<3xf32>, %v2: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
+ // expected-error @+1 {{'spirv.CompositeConstruct' op expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
+ %0 = spirv.CompositeConstruct %v1, %v2 : (vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+ return %0: !spirv.matrix<3 x vector<3xf32>>
+}
+
+// -----
+
+func.func @composite_construct_matrix_wrong_row_count(%v1: vector<4xf32>, %v2: vector<4xf32>, %v3: vector<4xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
+ // expected-error @+1 {{operand type mismatch: expected operand type 'vector<3xf32>', but provided 'vector<4xf32>'}}
+ %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+ return %0: !spirv.matrix<3 x vector<3xf32>>
+}
// -----
+func.func @composite_construct_matrix_wrong_element_type(%v1: vector<3xi32>, %v2: vector<3xi32>, %v3: vector<3xi32>) -> !spirv.matrix<3 x vector<3xf32>> {
+ // expected-error @+1 {{operand type mismatch: expected operand type 'vector<3xf32>', but provided 'vector<3xi32>'}}
+ %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xi32>, vector<3xi32>, vector<3xi32>) -> !spirv.matrix<3 x vector<3xf32>>
+ return %0: !spirv.matrix<3 x vector<3xf32>>
+}
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.CompositeExtractOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 5e98b9fdb3c546..6003d2a3576b12 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -62,6 +62,7 @@ func.func @const() -> () {
// CHECK: spirv.Constant dense<1.000000e+00> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
// CHECK: spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
// CHECK: spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+ // CHECK: spirv.Constant [dense<1.000000e+00> : vector<3xf32>, dense<2.000000e+00> : vector<3xf32>, dense<3.000000e+00> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
%0 = spirv.Constant true
%1 = spirv.Constant 42 : i32
@@ -73,6 +74,7 @@ func.func @const() -> () {
%7 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
%8 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
%9 = spirv.Constant [[dense<3.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1xvector<2xf32>>>
+ %10 = spirv.Constant [dense<1.0> : vector<3xf32>, dense<2.0> : vector<3xf32>, dense<3.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
return
}
@@ -95,7 +97,7 @@ func.func @array_constant() -> () {
// -----
func.func @array_constant() -> () {
- // expected-error @+1 {{must have spirv.array result type for array value}}
+ // expected-error @+1 {{'spirv.Constant' op must have spirv.array or spirv.matrix result type for array value}}
%0 = spirv.Constant [dense<3.0> : vector<2xf32>] : !spirv.rtarray<vector<2xf32>>
return
}
@@ -132,6 +134,30 @@ func.func @value_result_num_elements_mismatch() -> () {
// -----
+func.func @matrix_constant() -> () {
+ // CHECK: spirv.Constant [dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : vector<3xf32>, dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : vector<3xf32>, dense<[7.000000e+00, 8.000000e+00, 9.000000e+00]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+ %0 = spirv.Constant [dense<[1.0, 2.0, 3.0]> : vector<3xf32>, dense<[4.0, 5.0, 6.0]> : vector<3xf32>, dense<[7.0, 8.0, 9.0]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+ return
+}
+
+// -----
+
+func.func @matrix_constant_wrong_column_count() -> () {
+ // expected-error @+1 {{expected 3 columns in matrix constant, but got 2}}
+ %0 = spirv.Constant [dense<1.0> : vector<3xf32>, dense<2.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+ return
+}
+
+// -----
+
+func.func @matrix_constant_non_dense_column() -> () {
+ // expected-error @+1 {{matrix column #1 must be a DenseElementsAttr}}
+ %0 = spirv.Constant [dense<1.0> : vector<3xf32>, "wrong", dense<3.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.EntryPoint
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/composite-op.mlir b/mlir/test/Target/SPIRV/composite-op.mlir
index 5f302fd0d38f8b..bafdb3340d0e79 100644
--- a/mlir/test/Target/SPIRV/composite-op.mlir
+++ b/mlir/test/Target/SPIRV/composite-op.mlir
@@ -11,6 +11,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
spirv.ReturnValue %0: vector<3xf32>
}
+ spirv.func @composite_construct_matrix(%v1: vector<3xf32>, %v2: vector<3xf32>, %v3: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> "None" {
+ // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+ %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+ spirv.ReturnValue %0: !spirv.matrix<3 x vector<3xf32>>
+ }
spirv.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 "None" {
// CHECK: spirv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
%0 = spirv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index f3950214a7f055..0fa70c7e5cdbb3 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -198,6 +198,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.Return
}
+ // CHECK-LABEL: @matrix_const
+ spirv.func @matrix_const() -> () "None" {
+ // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<!spirv.matrix<3 x vector<3xf32>>, Function>
+ %0 = spirv.Variable : !spirv.ptr<!spirv.matrix<3 x vector<3xf32>>, Function>
+ // CHECK: %[[CST:.*]] = spirv.Constant [dense<[1.000000e+00, 0.000000e+00, 0.000000e+00]> : vector<3xf32>, dense<[0.000000e+00, 1.000000e+00, 0.000000e+00]> : vector<3xf32>, dense<[0.000000e+00, 0.000000e+00, 1.000000e+00]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+ %1 = spirv.Constant [dense<[1., 0., 0.]> : vector<3xf32>, dense<[0., 1., 0.]> : vector<3xf32>, dense<[0., 0., 1.]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+ // CHECK: spirv.Store "Function" %[[VAR]], %[[CST]] : !spirv.matrix<3 x vector<3xf32>>
+ spirv.Store "Function" %0, %1 : !spirv.matrix<3 x vector<3xf32>>
+ spirv.Return
+ }
+
// CHECK-LABEL: @ui64_array_const
spirv.func @ui64_array_const() -> (!spirv.array<3xui64>) "None" {
// CHECK: spirv.Constant [5, 6, 7] : !spirv.array<3 x i64>
>From 9cae5b69f94aaa70f23a80734104eb385794a5aa Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 14:10:46 +0100
Subject: [PATCH 2/4] [mlir][spirv] Deserialize OpConstantComposite with a matrix result type
---
 mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 04469f1933819b..ecc822e553aefc 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1442,6 +1442,8 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
-  } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
+  } else if (isa<spirv::ArrayType, spirv::MatrixType>(resultType)) {
+    // Arrays and matrices are both recorded as an ArrayAttr of their
+    // constituent constants, keyed by the composite's result type.
     auto attr = opBuilder.getArrayAttr(elements);
     constantMap.try_emplace(resultID, attr, resultType);
   } else {
     return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
            << resultType;
>From 5a9850e51a8373453323d94d1e1ec8d5d45a86d1 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 14:11:18 +0100
Subject: [PATCH 3/4] [mlir][spirv] Print and verify spirv.Constant with a matrix result type
---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 63 ++++++++++++++++++++++++++++-------
 1 file changed, 51 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 26559c1321db5e..ee7c7860b05c4e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -579,7 +579,7 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
 
 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
   printer << ' ' << getValue();
-  if (llvm::isa<spirv::ArrayType>(getType()))
+  if (llvm::isa<spirv::ArrayType, spirv::MatrixType>(getType()))
     printer << " : " << getType();
 }
 
@@ -626,18 +626,57 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
     }
     return success();
   }
   if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
-    auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
-    if (!arrayType)
-      return op.emitOpError(
-          "must have spirv.array result type for array value");
-    Type elemType = arrayType.getElementType();
-    for (Attribute element : arrayAttr.getValue()) {
-      // Verify array elements recursively.
-      if (failed(verifyConstantType(op, element, elemType)))
-        return failure();
-    }
-    return success();
+    // Matrix result type: the array holds one dense vector per column.
+    if (auto matrixType = llvm::dyn_cast<spirv::MatrixType>(opType)) {
+      unsigned numColumns = matrixType.getNumColumns();
+      unsigned numRows = matrixType.getNumRows();
+      if (arrayAttr.size() != numColumns)
+        return op.emitOpError("expected ")
+               << numColumns << " columns in matrix constant, but got "
+               << arrayAttr.size();
+
+      Type columnTy = matrixType.getColumnType();
+      Type elementTy = matrixType.getElementType();
+      for (auto [colIndex, colAttr] : llvm::enumerate(arrayAttr)) {
+        auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(colAttr);
+        if (!denseAttr)
+          return op.emitOpError("matrix column #")
+                 << colIndex << " must be a DenseElementsAttr";
+
+        auto shapedTy = llvm::cast<ShapedType>(denseAttr.getType());
+        if (shapedTy.getNumElements() != numRows)
+          return op.emitOpError("matrix column #")
+                 << colIndex << " has incorrect size: expected " << numRows
+                 << " elements";
+
+        if (shapedTy.getElementType() != elementTy)
+          return op.emitOpError("matrix column #")
+                 << colIndex << " has incorrect element type: expected "
+                 << elementTy << ", got " << shapedTy.getElementType();
+
+        // Size and element type alone would still accept e.g. a
+        // tensor<3xf32> or vector<1x3xf32> column, so also require the
+        // column to have exactly the matrix column type.
+        if (shapedTy != columnTy)
+          return op.emitOpError("matrix column #")
+                 << colIndex << " has type " << shapedTy << ", expected "
+                 << columnTy;
+      }
+      return success();
+    }
+    // Array result type: verify each element recursively against the
+    // array element type.
+    if (auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType)) {
+      Type elemType = arrayType.getElementType();
+      for (Attribute element : arrayAttr.getValue()) {
+        if (failed(verifyConstantType(op, element, elemType)))
+          return failure();
+      }
+      return success();
+    }
+    return op.emitOpError(
+        "must have spirv.array or spirv.matrix result type for array value");
   }
   return op.emitOpError("cannot have attribute: ") << value;
 }
>From 087941b536d9ecdb4bd96beb6459da6243340bc8 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 14:11:38 +0100
Subject: [PATCH 4/4] [mlir][spirv] Serialize constant matrices as OpConstantComposite
---
 .../Target/SPIRV/Serialization/Serializer.cpp | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 1f4f5d7f764db3..b5e3cd381ef822 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -782,6 +782,23 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
     SmallVector<uint64_t, 4> index(rank);
     resultID = prepareDenseElementsConstant(loc, constType, attr,
                                             /*dim=*/0, index);
+  } else if (auto matrixType = dyn_cast<spirv::MatrixType>(constType)) {
+    // A constant matrix is an ArrayAttr holding one constant per column:
+    // emit each column, then an OpConstantComposite over the column IDs.
+    if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+      Type columnType = matrixType.getColumnType();
+      resultID = getNextID();
+      SmallVector<uint32_t, 4> operands = {typeID, resultID};
+      operands.reserve(arrayAttr.size() + 2);
+      for (Attribute columnAttr : arrayAttr) {
+        uint32_t columnID = prepareConstant(loc, columnType, columnAttr);
+        if (!columnID)
+          return 0;
+        operands.push_back(columnID);
+      }
+      encodeInstructionInto(typesGlobalValues,
+                            spirv::Opcode::OpConstantComposite, operands);
+    }
   } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
     resultID = prepareArrayConstant(loc, constType, arrayAttr);
   }
More information about the Mlir-commits
mailing list