[Mlir-commits] [mlir] 9f1da90 - [mlir][SPIRV] Do not rewrite CompositeInsert for coopmatrix (#137837)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 21 00:12:27 PDT 2025
Author: Hsiangkai Wang
Date: 2025-05-21T08:12:24+01:00
New Revision: 9f1da90d6f996f2d4606b2d1e31b494d72a82b0e
URL: https://github.com/llvm/llvm-project/commit/9f1da90d6f996f2d4606b2d1e31b494d72a82b0e
DIFF: https://github.com/llvm/llvm-project/commit/9f1da90d6f996f2d4606b2d1e31b494d72a82b0e.diff
LOG: [mlir][SPIRV] Do not rewrite CompositeInsert for coopmatrix (#137837)
When rewriting a chain of CompositeInsert ops into a single CompositeConstruct,
the pass needs to know the number of elements in the result type. However,
the number of elements cannot be queried for cooperative matrix types, so
CompositeInsert ops into a cooperative matrix are now left unchanged.
Added:
Modified:
mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index f38282f57a2c3..2e31172ab940b 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -84,6 +84,9 @@ void RewriteInsertsPass::runOnOperation() {
LogicalResult RewriteInsertsPass::collectInsertionChain(
spirv::CompositeInsertOp op,
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
+ if (isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
+ return failure();
+
auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
// TODO: handle nested composite object.
if (indicesArrayAttr.size() == 1) {
diff --git a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
index 6d755be4f3987..a83c3f7d34693 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
@@ -29,3 +29,15 @@ spirv.module Logical GLSL450 {
spirv.ReturnValue %3 : vector<3xf32>
}
}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @insertCoopMatrix(%value : f32) -> !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> "None" {
+ %0 = spirv.Undef : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+ // CHECK: spirv.CompositeInsert {{%.*}}, {{%.*}} : f32 into !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+ %1 = spirv.CompositeInsert %value, %0[0 : i32] : f32 into !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+
+ spirv.ReturnValue %1 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+ }
+}
More information about the Mlir-commits mailing list