[Mlir-commits] [mlir] [mlir][spirv] Add support for Constant Matrices (PR #123334)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 17 09:06:35 PST 2025


https://github.com/mishaobu updated https://github.com/llvm/llvm-project/pull/123334

>From def92f0833bdd32dbfc92022370b065ffd74b332 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 14:10:17 +0100
Subject: [PATCH 1/4] Add tests for SPIR-V constant matrices and matrix CompositeConstruct

---
 mlir/test/Dialect/SPIRV/IR/composite-ops.mlir | 29 +++++++++++++++++++
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 28 +++++++++++++++++-
 mlir/test/Target/SPIRV/composite-op.mlir      |  5 ++++
 mlir/test/Target/SPIRV/constant.mlir          | 11 +++++++
 4 files changed, 72 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 3fc8dfb2767d1e..5c835d2e08de91 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -11,6 +11,13 @@ func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> ve
   return %0: vector<3xf32>
 }
 
+// CHECK-LABEL: func @composite_construct_matrix
+func.func @composite_construct_matrix(%v1: vector<3xf32>, %v2: vector<3xf32>, %v3: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
+  // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+  %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+  return %0: !spirv.matrix<3 x vector<3xf32>>
+}
+
 // CHECK-LABEL: func @composite_construct_struct
 func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
   // CHECK: spirv.CompositeConstruct
@@ -89,9 +96,31 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
   %0 = spirv.CompositeConstruct %arg0, %arg2 : (f32, vector<2xf32>) -> vector<4xf32>
   return %0: vector<4xf32>
 }
+// -----
+
+func.func @composite_construct_matrix_wrong_column_count(%v1: vector<3xf32>, %v2: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
+  // expected-error @+1 {{'spirv.CompositeConstruct' op expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
+  %0 = spirv.CompositeConstruct %v1, %v2 : (vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+  return %0: !spirv.matrix<3 x vector<3xf32>>
+}
+
+// -----
+
+func.func @composite_construct_matrix_wrong_row_count(%v1: vector<4xf32>, %v2: vector<4xf32>, %v3: vector<4xf32>) -> !spirv.matrix<3 x vector<3xf32>> {
+  // expected-error @+1 {{operand type mismatch: expected operand type 'vector<3xf32>', but provided 'vector<4xf32>'}}
+  %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+  return %0: !spirv.matrix<3 x vector<3xf32>>
+}
 
 // -----
 
+func.func @composite_construct_matrix_wrong_element_type(%v1: vector<3xi32>, %v2: vector<3xi32>, %v3: vector<3xi32>) -> !spirv.matrix<3 x vector<3xf32>> {
+  // expected-error @+1 {{operand type mismatch: expected operand type 'vector<3xf32>', but provided 'vector<3xi32>'}}
+  %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xi32>, vector<3xi32>, vector<3xi32>) -> !spirv.matrix<3 x vector<3xf32>>
+  return %0: !spirv.matrix<3 x vector<3xf32>>
+}
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.CompositeExtractOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 5e98b9fdb3c546..6003d2a3576b12 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -62,6 +62,7 @@ func.func @const() -> () {
   // CHECK: spirv.Constant dense<1.000000e+00> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
   // CHECK: spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
   // CHECK: spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+  // CHECK: spirv.Constant [dense<1.000000e+00> : vector<3xf32>, dense<2.000000e+00> : vector<3xf32>, dense<3.000000e+00> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
 
   %0 = spirv.Constant true
   %1 = spirv.Constant 42 : i32
@@ -73,6 +74,7 @@ func.func @const() -> () {
   %7 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
   %8 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
   %9 = spirv.Constant [[dense<3.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1xvector<2xf32>>>
+  %10 = spirv.Constant [dense<1.0> : vector<3xf32>, dense<2.0> : vector<3xf32>, dense<3.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
   return
 }
 
@@ -95,7 +97,7 @@ func.func @array_constant() -> () {
 // -----
 
 func.func @array_constant() -> () {
-  // expected-error @+1 {{must have spirv.array result type for array value}}
+  // expected-error @+1 {{'spirv.Constant' op must have spirv.array or spirv.matrix result type for array value}}
   %0 = spirv.Constant [dense<3.0> : vector<2xf32>] : !spirv.rtarray<vector<2xf32>>
   return
 }
@@ -132,6 +134,30 @@ func.func @value_result_num_elements_mismatch() -> () {
 
 // -----
 
+func.func @matrix_constant() -> () {
+  // CHECK: spirv.Constant [dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : vector<3xf32>, dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : vector<3xf32>, dense<[7.000000e+00, 8.000000e+00, 9.000000e+00]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+  %0 = spirv.Constant [dense<[1.0, 2.0, 3.0]> : vector<3xf32>, dense<[4.0, 5.0, 6.0]> : vector<3xf32>, dense<[7.0, 8.0, 9.0]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+  return
+}
+
+// -----
+
+func.func @matrix_constant_wrong_column_count() -> () {
+  // expected-error @+1 {{expected 3 columns in matrix constant, but got 2}}
+  %0 = spirv.Constant [dense<1.0> : vector<3xf32>, dense<2.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+  return
+}
+
+// -----
+
+func.func @matrix_constant_non_dense_column() -> () {
+  // expected-error @+1 {{matrix column #1 must be a DenseElementsAttr}}
+  %0 = spirv.Constant [dense<1.0> : vector<3xf32>, "wrong", dense<3.0> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.EntryPoint
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/composite-op.mlir b/mlir/test/Target/SPIRV/composite-op.mlir
index 5f302fd0d38f8b..bafdb3340d0e79 100644
--- a/mlir/test/Target/SPIRV/composite-op.mlir
+++ b/mlir/test/Target/SPIRV/composite-op.mlir
@@ -11,6 +11,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
     spirv.ReturnValue %0: vector<3xf32>
   }
+  spirv.func @composite_construct_matrix(%v1: vector<3xf32>, %v2: vector<3xf32>, %v3: vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>> "None" {
+    // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+    %0 = spirv.CompositeConstruct %v1, %v2, %v3 : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> !spirv.matrix<3 x vector<3xf32>>
+    spirv.ReturnValue %0: !spirv.matrix<3 x vector<3xf32>>
+  }
   spirv.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 "None" {
     // CHECK: spirv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
     %0 = spirv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index f3950214a7f055..0fa70c7e5cdbb3 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -198,6 +198,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.Return
   }
 
+  // CHECK-LABEL: @matrix_const
+  spirv.func @matrix_const() -> () "None" {
+    // CHECK: %[[VAR:.*]] = spirv.Variable : !spirv.ptr<!spirv.matrix<3 x vector<3xf32>>, Function>
+    %0 = spirv.Variable : !spirv.ptr<!spirv.matrix<3 x vector<3xf32>>, Function>
+    // CHECK: %[[CST:.*]] = spirv.Constant [dense<[1.000000e+00, 0.000000e+00, 0.000000e+00]> : vector<3xf32>, dense<[0.000000e+00, 1.000000e+00, 0.000000e+00]> : vector<3xf32>, dense<[0.000000e+00, 0.000000e+00, 1.000000e+00]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+    %1 = spirv.Constant [dense<[1., 0., 0.]> : vector<3xf32>, dense<[0., 1., 0.]> : vector<3xf32>, dense<[0., 0., 1.]> : vector<3xf32>] : !spirv.matrix<3 x vector<3xf32>>
+    // CHECK: spirv.Store "Function" %[[VAR]], %[[CST]] : !spirv.matrix<3 x vector<3xf32>>
+    spirv.Store "Function" %0, %1 : !spirv.matrix<3 x vector<3xf32>>
+    spirv.Return
+  }
+
   // CHECK-LABEL: @ui64_array_const
   spirv.func @ui64_array_const() -> (!spirv.array<3xui64>) "None" {
     // CHECK: spirv.Constant [5, 6, 7] : !spirv.array<3 x i64>

>From 9cae5b69f94aaa70f23a80734104eb385794a5aa Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 14:10:46 +0100
Subject: [PATCH 2/4] Allow deserialization of constant matrix types

---
 mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 04469f1933819b..ecc822e553aefc 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1442,6 +1442,9 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
     auto attr = opBuilder.getArrayAttr(elements);
     constantMap.try_emplace(resultID, attr, resultType);
+  } else if (auto matrixType = dyn_cast<spirv::MatrixType>(resultType)) {
+    auto attr = opBuilder.getArrayAttr(elements);
+    constantMap.try_emplace(resultID, attr, resultType);
   } else {
     return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
            << resultType;

>From 5a9850e51a8373453323d94d1e1ec8d5d45a86d1 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 14:11:18 +0100
Subject: [PATCH 3/4] Print and verify constant matrices

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 55 ++++++++++++++++++++------
 1 file changed, 43 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 26559c1321db5e..ee7c7860b05c4e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -579,7 +579,7 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
 
 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
   printer << ' ' << getValue();
-  if (llvm::isa<spirv::ArrayType>(getType()))
+  if (llvm::isa<spirv::ArrayType, spirv::MatrixType>(getType()))
     printer << " : " << getType();
 }
 
@@ -626,18 +626,49 @@ static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
     }
     return success();
   }
-  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
-    auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
-    if (!arrayType)
-      return op.emitOpError(
-          "must have spirv.array result type for array value");
-    Type elemType = arrayType.getElementType();
-    for (Attribute element : arrayAttr.getValue()) {
-      // Verify array elements recursively.
-      if (failed(verifyConstantType(op, element, elemType)))
-        return failure();
+  if (auto arrayAttr = mlir::dyn_cast<ArrayAttr>(value)) {
+    // Case for Matrix result type
+    if (auto matrixType = mlir::dyn_cast<spirv::MatrixType>(opType)) {
+      unsigned numColumns = matrixType.getNumColumns();
+      unsigned numRows    = matrixType.getNumRows();
+      if (arrayAttr.size() != numColumns)
+        return op.emitOpError("expected ")
+              << numColumns << " columns in matrix constant, but got "
+              << arrayAttr.size();
+
+      Type elementTy = matrixType.getElementType();
+      for (auto [colIndex, colAttr] : llvm::enumerate(arrayAttr)) {
+        // Ensure each column is a dense array of the right shape/type
+        auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(colAttr);
+        if (!denseAttr)
+          return op.emitOpError("matrix column #")
+                << colIndex << " must be a DenseElementsAttr";
+
+        auto shapedTy = mlir::dyn_cast<ShapedType>(denseAttr.getType());
+        if (!shapedTy || shapedTy.getNumElements() != numRows)
+          return op.emitOpError("matrix column #")
+                << colIndex << " has incorrect size: expected "
+                << numRows << " elements";
+
+        if (shapedTy.getElementType() != elementTy)
+          return op.emitOpError("matrix column #")
+                << colIndex << " has incorrect element type: expected "
+                << elementTy << ", got " << shapedTy.getElementType();
+      }
+      return success();
     }
-    return success();
+    // Case for Array result type
+    if (auto arrayType = mlir::dyn_cast<spirv::ArrayType>(opType)) {
+      Type elemType = arrayType.getElementType();
+      for (Attribute element : arrayAttr.getValue()) {
+        // Verify array elements recursively.
+        if (failed(verifyConstantType(op, element, elemType)))
+          return failure();
+      }
+      return success();
+    }
+    return op.emitOpError(
+        "must have spirv.array or spirv.matrix result type for array value");
   }
   return op.emitOpError("cannot have attribute: ") << value;
 }

>From 087941b536d9ecdb4bd96beb6459da6243340bc8 Mon Sep 17 00:00:00 2001
From: "Misha (M3 MBP)" <obukhov.michael+m3mbp at gmail.com>
Date: Fri, 17 Jan 2025 14:11:38 +0100
Subject: [PATCH 4/4] Handle serialization of constant matrices

---
 .../Target/SPIRV/Serialization/Serializer.cpp    | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 1f4f5d7f764db3..b5e3cd381ef822 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -782,6 +782,22 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
     SmallVector<uint64_t, 4> index(rank);
     resultID = prepareDenseElementsConstant(loc, constType, attr,
                                             /*dim=*/0, index);
+  } else if (isa<spirv::MatrixType>(constType)) {
+    if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
+      resultID = getNextID();
+      SmallVector<uint32_t, 4> operands = {typeID, resultID};
+      operands.reserve(arrayAttr.size() + 2);
+      for (Attribute elementAttr : arrayAttr) {
+        if (auto elementID = prepareConstant(loc, 
+            cast<spirv::MatrixType>(constType).getColumnType(), elementAttr)) {
+          operands.push_back(elementID);
+        } else {
+          return 0;
+        }
+      }
+      spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
+      encodeInstructionInto(typesGlobalValues, opcode, operands);
+    }
   } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
     resultID = prepareArrayConstant(loc, constType, arrayAttr);
   }



More information about the Mlir-commits mailing list