[Mlir-commits] [mlir] [mlir][gpu] Add GPU subgroup MMA extract and insert operations (PR #139048)
Hsiangkai Wang
llvmlistbot at llvm.org
Fri May 23 01:52:06 PDT 2025
https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/139048
>From 9b814e2bf72004ac5b829a6a18e4f30057e0d2e4 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Fri, 2 May 2025 14:32:50 +0100
Subject: [PATCH 1/4] [mlir][gpu] Add GPU subgroup MMA extract and insert
operations
- Introduced `gpu.subgroup_mma_extract` operation to extract values from `!gpu.mma_matrix` by invocation and indices.
- Introduced `gpu.subgroup_mma_insert` operation to insert values into `!gpu.mma_matrix` by invocation and indices.
- Updated the conversion patterns to SPIR-V for both extract and insert operations.
- Added test cases to validate the new operations in the GPU to SPIR-V conversion.
---
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 73 +++++++++++++++++++
.../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | 63 ++++++++++++++++
.../wmma-ops-to-spirv-khr-coop-matrix.mlir | 27 +++++++
3 files changed, 163 insertions(+)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 68095b7bf5c59..cb363b501851b 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1919,6 +1919,79 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
}];
}
+def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
+ [Pure,
+ TypesMatchWith<"value type matches element type of mma_matrix",
+ "matrix", "res",
+ "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]>{
+
+ let summary = "Extract a value from GPU warp by invocation and indices";
+
+ let description = [{
+ The `gpu.subgroup_mma_extract` operation extracts a value from `!gpu.mma_matrix`
+ by the invocation in a subgroup.
+
+ This operation takes `!gpu.mma_matrix` as its first operand. It is the source
+ matrix across a subgroup. The op returns a scalar value stored in the invocation
+ in the subgroup. If there are multiple values packed in an invocation, use
+ `indices` to specify the element to extract.
+
+ Example:
+
+ ```mlir
+ %c0 = arith.constant 0 : index
+ %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+ ```
+ }];
+
+ let arguments = (ins GPU_MMAMatrix:$matrix, Variadic<Index>:$indices);
+
+ let results = (outs AnyIntegerOrFloat:$res);
+
+ let assemblyFormat = [{
+ $matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res)
+ }];
+}
+
+def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
+ [Pure,
+ TypesMatchWith<"value type matches element type of mma_matrix",
+ "matrix", "value",
+ "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()"> ]>{
+
+ let summary = "Insert a value into GPU warp by invocation and indices";
+
+ let description = [{
+ The `gpu.subgroup_mma_insert` operation inserts a value to `!gpu.mma_matrix`
+ by the invocation in a subgroup.
+
+ This operation takes scalar value as its first operand and `!gpu.mma_matrix`
+ as its second operand. It is the matrix across a subgroup. The op inserts the
+ scalar value stored in the invocation in the subgroup to the matrix. If there
+ are multiple values packed in an invocation, use `indices` to specify the
+ location to insert in the packing.
+
+ The op returns `!gpu.mma_matrix` with the updated value.
+
+ Example:
+
+ ```mlir
+ %c0 = arith.constant 0 : index
+ %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
+ -> !gpu.mma_matrix<16x16xf16, "COp">
+ ```
+ }];
+
+ let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix,
+ Variadic<Index>:$indices);
+
+ let results = (outs GPU_MMAMatrix:$res);
+
+ let assemblyFormat = [{
+ $value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res)
+ }];
+}
+
def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">;
def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">;
def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index df2da138d3b52..be76262f526d6 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -111,6 +111,68 @@ struct WmmaConstantOpToSPIRVLowering final
}
};
+/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaExtractOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaExtractOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value matrix = adaptor.getMatrix();
+ auto coopType =
+ getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
+ matrix.getType());
+ if (!coopType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ SmallVector<int32_t> intValues;
+ for (Value val : op.getIndices()) {
+ if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+ intValues.push_back(static_cast<int32_t>(constOp.value()));
+ } else {
+ return rewriter.notifyMatchFailure(op, "indices must be constants");
+ }
+ }
+
+ Type elementType = coopType.getElementType();
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
+ return success();
+ }
+};
+
+/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaInsertOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaInsertOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value value = adaptor.getValue();
+ Value matrix = adaptor.getMatrix();
+ auto coopType = getTypeConverter()->convertType(matrix.getType());
+ if (!coopType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ SmallVector<int32_t> intValues;
+ for (Value val : op.getIndices()) {
+ if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+ intValues.push_back(static_cast<int32_t>(constOp.value()));
+ } else {
+ return rewriter.notifyMatchFailure(op, "indices must be constants");
+ }
+ }
+
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
+ return success();
+ }
+};
+
/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
@@ -296,6 +358,7 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
MLIRContext *context = patterns.getContext();
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+ WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
// Give the following patterns higher benefit to prevail over the default one.
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 477f344b1ae5f..3e8a3b21e7e94 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -93,6 +93,33 @@ module attributes {
gpu.return
}
+ // CHECK-LABEL: spirv.func @gpu_wmma_extract_op
+ // CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+ gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
+ %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+ // CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+ %c0 = arith.constant 0 : index
+ %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+ memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: spirv.func @gpu_wmma_insert_op
+ // CHECK-SAME: %[[ARG0:.+]]: f16
+ // CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ gpu.func @gpu_wmma_insert_op(%val: f16,
+ %m: !gpu.mma_matrix<16x16xf16, "COp">,
+ %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+ // CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ %c0 = arith.constant 0 : index
+ %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+ gpu.subgroup_mma_store_matrix %s0, %ptr[%c0,%c0] {leadDimension = 16 : index} :
+ !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+ gpu.return
+ }
+
// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
>From b6910549d818f8c83afad22304635972b0d7dec7 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Tue, 20 May 2025 11:06:11 +0100
Subject: [PATCH 2/4] Add a parsing/printing test
---
mlir/test/Dialect/GPU/ops.mlir | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 99915c493ea46..0364fc47b9308 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -430,6 +430,20 @@ module attributes {gpu.container_module} {
gpu.wait [%token16]
return
}
+
+ // CHECK-LABEL: func @extract_insert_mma
+ func.func @extract_insert_mma(%src : !gpu.mma_matrix<16x16xf32, "COp">,
+ %ptr: memref<16x16xf32>) {
+ %zero = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ // CHECK: gpu.subgroup_mma_extract
+ %val = gpu.subgroup_mma_extract %src[%c0] : !gpu.mma_matrix<16x16xf32, "COp"> -> f32
+ %m = gpu.subgroup_mma_constant_matrix %zero : !gpu.mma_matrix<16x16xf32, "COp">
+ // CHECK: gpu.subgroup_mma_insert
+ %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f32, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+ gpu.subgroup_mma_store_matrix %s0, %ptr[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
+ return
+ }
}
// Just check that this doesn't crash.
>From 96fcae522342fcd1dd43be631479d90e5c1528e3 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Thu, 22 May 2025 10:28:02 +0100
Subject: [PATCH 3/4] add more description about how mma matrix is stored
---
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index cb363b501851b..4c107a487aa4d 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1933,8 +1933,9 @@ def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
This operation takes `!gpu.mma_matrix` as its first operand. It is the source
matrix across a subgroup. The op returns a scalar value stored in the invocation
- in the subgroup. If there are multiple values packed in an invocation, use
- `indices` to specify the element to extract.
+ in the subgroup. The values of !gpu.mma_matrix are stored across multiple
+ threads in the subgroup. If there are multiple values packed in a thread, use
+ `indices` to specify the element in the local thread to extract.
Example:
@@ -1967,7 +1968,8 @@ def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
This operation takes scalar value as its first operand and `!gpu.mma_matrix`
as its second operand. It is the matrix across a subgroup. The op inserts the
- scalar value stored in the invocation in the subgroup to the matrix. If there
+ scalar value stored in the invocation in the subgroup to the matrix. The values
+ of !gpu.mma_matrix are stored across multiple threads in the subgroup. If there
are multiple values packed in an invocation, use `indices` to specify the
location to insert in the packing.
>From c7269d3948b604075927f0bc87e57d9bfff707cf Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Fri, 23 May 2025 07:24:43 +0100
Subject: [PATCH 4/4] refine descriptions and rename operations
---
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 46 ++++++++++++-------
.../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | 8 ++--
.../wmma-ops-to-spirv-khr-coop-matrix.mlir | 12 ++---
mlir/test/Dialect/GPU/ops.mlir | 8 ++--
4 files changed, 44 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 4c107a487aa4d..fb27630ed3b48 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1919,7 +1919,7 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
}];
}
-def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
+def GPU_SubgroupMmaExtractThreadLocalOp : GPU_Op<"subgroup_mma_extract_thread_local",
[Pure,
TypesMatchWith<"value type matches element type of mma_matrix",
"matrix", "res",
@@ -1928,20 +1928,28 @@ def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
let summary = "Extract a value from GPU warp by invocation and indices";
let description = [{
- The `gpu.subgroup_mma_extract` operation extracts a value from `!gpu.mma_matrix`
- by the invocation in a subgroup.
+ The `gpu.subgroup_mma_extract_thread_local` operation extracts a value from `!gpu.mma_matrix`
+ that is stored at subgroup level.
This operation takes `!gpu.mma_matrix` as its first operand. It is the source
matrix across a subgroup. The op returns a scalar value stored in the invocation
- in the subgroup. The values of !gpu.mma_matrix are stored across multiple
- threads in the subgroup. If there are multiple values packed in a thread, use
- `indices` to specify the element in the local thread to extract.
+ in the subgroup.
+
+  Since `matrix` is packed into the threads within a subgroup, `indices` are
+ the indices into the values stored by each thread. That is, an index of 0 (or [0, 0])
+ does not necessarily refer to the first element of the matrix, but the first element
+ that a particular thread holds.
+
+ The mapping of matrix elements to threads is not defined by this operation and may
+ not be defined by some lowerings (such as the lowering to SPIR-V). However, if the
+ size of the subgroup is S, then `subgroup_mma_extract_thread_local` at each index in
+ `[0, (M * N) / S)` will have the entire matrix extracted across the subgroup.
Example:
```mlir
%c0 = arith.constant 0 : index
- %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+ %val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
```
}];
@@ -1954,7 +1962,7 @@ def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
}];
}
-def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
+def GPU_SubgroupMmaInsertThreadLocalOp : GPU_Op<"subgroup_mma_insert_thread_local",
[Pure,
TypesMatchWith<"value type matches element type of mma_matrix",
"matrix", "value",
@@ -1963,15 +1971,21 @@ def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
let summary = "Insert a value into GPU warp by invocation and indices";
let description = [{
- The `gpu.subgroup_mma_insert` operation inserts a value to `!gpu.mma_matrix`
- by the invocation in a subgroup.
+ The `gpu.subgroup_mma_insert_thread_local` operation inserts a value to `!gpu.mma_matrix`
+ that is stored at subgroup level.
This operation takes scalar value as its first operand and `!gpu.mma_matrix`
- as its second operand. It is the matrix across a subgroup. The op inserts the
- scalar value stored in the invocation in the subgroup to the matrix. The values
- of !gpu.mma_matrix are stored across multiple threads in the subgroup. If there
- are multiple values packed in an invocation, use `indices` to specify the
- location to insert in the packing.
+ as its second operand. The op inserts the scalar value to the matrix.
+
+  Since `matrix` is packed into the threads within a subgroup, `indices` are
+ the indices into the values stored by each thread. That is, an index of 0 (or [0, 0])
+ does not necessarily refer to the first element of the matrix, but the first element
+ that a particular thread holds.
+
+ The mapping of matrix elements to threads is not defined by this operation and may
+ not be defined by some lowerings (such as the lowering to SPIR-V). However, if the
+ size of the subgroup is S, then `subgroup_mma_insert_thread_local` at each index in
+ `[0, (M * N) / S)` will have the entire matrix inserted across the subgroup.
The op returns `!gpu.mma_matrix` with the updated value.
@@ -1979,7 +1993,7 @@ def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
```mlir
%c0 = arith.constant 0 : index
- %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
+ %s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
-> !gpu.mma_matrix<16x16xf16, "COp">
```
}];
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index be76262f526d6..d2f5e35853550 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -114,11 +114,11 @@ struct WmmaConstantOpToSPIRVLowering final
/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
/// matrix ops.
struct WmmaExtractOpToSPIRVLowering final
- : OpConversionPattern<gpu::SubgroupMmaExtractOp> {
+ : OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor,
+ matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value matrix = adaptor.getMatrix();
auto coopType =
@@ -146,11 +146,11 @@ struct WmmaExtractOpToSPIRVLowering final
/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
/// matrix ops.
struct WmmaInsertOpToSPIRVLowering final
- : OpConversionPattern<gpu::SubgroupMmaInsertOp> {
+ : OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor,
+ matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value value = adaptor.getValue();
Value matrix = adaptor.getMatrix();
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 3e8a3b21e7e94..7ef3711ebe28b 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -93,28 +93,28 @@ module attributes {
gpu.return
}
- // CHECK-LABEL: spirv.func @gpu_wmma_extract_op
+ // CHECK-LABEL: spirv.func @gpu_wmma_extract_thread_local_op
// CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
- gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
+ gpu.func @gpu_wmma_extract_thread_local_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
%ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
%c0 = arith.constant 0 : index
- %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+ %val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
gpu.return
}
- // CHECK-LABEL: spirv.func @gpu_wmma_insert_op
+ // CHECK-LABEL: spirv.func @gpu_wmma_insert_thread_local_op
// CHECK-SAME: %[[ARG0:.+]]: f16
// CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
- gpu.func @gpu_wmma_insert_op(%val: f16,
+ gpu.func @gpu_wmma_insert_thread_local_op(%val: f16,
%m: !gpu.mma_matrix<16x16xf16, "COp">,
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%c0 = arith.constant 0 : index
- %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+ %s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.subgroup_mma_store_matrix %s0, %ptr[%c0,%c0] {leadDimension = 16 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
gpu.return
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 0364fc47b9308..9dbe16774f517 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -436,11 +436,11 @@ module attributes {gpu.container_module} {
%ptr: memref<16x16xf32>) {
%zero = arith.constant 0.0 : f32
%c0 = arith.constant 0 : index
- // CHECK: gpu.subgroup_mma_extract
- %val = gpu.subgroup_mma_extract %src[%c0] : !gpu.mma_matrix<16x16xf32, "COp"> -> f32
+ // CHECK: gpu.subgroup_mma_extract_thread_local
+ %val = gpu.subgroup_mma_extract_thread_local %src[%c0] : !gpu.mma_matrix<16x16xf32, "COp"> -> f32
%m = gpu.subgroup_mma_constant_matrix %zero : !gpu.mma_matrix<16x16xf32, "COp">
- // CHECK: gpu.subgroup_mma_insert
- %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f32, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+ // CHECK: gpu.subgroup_mma_insert_thread_local
+ %s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f32, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "COp">
gpu.subgroup_mma_store_matrix %s0, %ptr[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
return
}
More information about the Mlir-commits
mailing list