[Mlir-commits] [mlir] [mlir][SPIRV] Do not rewrite CompositeInsert for coopmatrix (PR #137837)
Hsiangkai Wang
llvmlistbot at llvm.org
Wed Apr 30 01:14:07 PDT 2025
https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/137837
>From 41edb090679440bec1cc69865d179fc4c703a80d Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Tue, 29 Apr 2025 17:12:12 +0100
Subject: [PATCH 1/2] [mlir][SPIRV] Do not rewrite CompositeInsert for
coopmatrix
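When rewriting a chain of CompositeInserts into a single CompositeConstruct,
we need to know the number of elements of the result type. However, we
cannot query the number of elements for cooperative matrix types, so
collectInsertionChain now bails out when the composite is a cooperative
matrix and such CompositeInsert ops are left unchanged.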
---
mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index f38282f57a2c3..bc3d0429efd19 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -84,6 +84,9 @@ void RewriteInsertsPass::runOnOperation() {
LogicalResult RewriteInsertsPass::collectInsertionChain(
spirv::CompositeInsertOp op,
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
+ if (llvm::isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
+ return failure();
+
auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
// TODO: handle nested composite object.
if (indicesArrayAttr.size() == 1) {
>From 9c16ef6890a640bf64e1f47383884c96cb862394 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Wed, 30 Apr 2025 09:10:31 +0100
Subject: [PATCH 2/2] Address review comments: use unqualified isa<> and add a coopmatrix test
---
.../Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp | 2 +-
.../Dialect/SPIRV/Transforms/rewrite-inserts.mlir | 12 ++++++++++++
2 files changed, 13 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index bc3d0429efd19..2e31172ab940b 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -84,7 +84,7 @@ void RewriteInsertsPass::runOnOperation() {
LogicalResult RewriteInsertsPass::collectInsertionChain(
spirv::CompositeInsertOp op,
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
- if (llvm::isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
+ if (isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
return failure();
auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
diff --git a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
index 6d755be4f3987..a83c3f7d34693 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
@@ -29,3 +29,15 @@ spirv.module Logical GLSL450 {
spirv.ReturnValue %3 : vector<3xf32>
}
}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @insertCoopMatrix(%value : f32) -> !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> "None" {
+ %0 = spirv.Undef : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+ // CHECK: spirv.CompositeInsert {{%.*}}, {{%.*}} : f32 into !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+ %1 = spirv.CompositeInsert %value, %0[0 : i32] : f32 into !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+
+ spirv.ReturnValue %1 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+ }
+}
More information about the Mlir-commits
mailing list