[Mlir-commits] [mlir] [rfc][mlir][gpu] Add operations to extract/insert/rotate within subgroup (PR #139048)
llvmlistbot at llvm.org
Thu May 8 01:26:11 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir
Author: Hsiangkai Wang (Hsiangkai)
<details>
<summary>Changes</summary>
Add gpu.rotate, gpu.subgroup_mma_extract, and gpu.subgroup_mma_insert operations.
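For a quick sense of the rotate op in use, here is a minimal sketch (illustrative only; `%v` stands for any `f32` value defined earlier):

```mlir
// With offset 1 and width 16, each of the first 16 lanes reads the value
// held by lane (lane_id + 1) % 16, wrapping around within the subgroup.
%offset = arith.constant 1 : i32
%width = arith.constant 16 : i32
%shifted = gpu.rotate %v, %offset, %width : f32
```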
---
Full diff: https://github.com/llvm/llvm-project/pull/139048.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+102)
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (+40-1)
- (modified) mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (+63)
- (added) mlir/test/Conversion/GPUToSPIRV/rotate.mlir (+25)
- (modified) mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir (+27)
``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 68095b7bf5c59..612ae3fac2d77 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1364,6 +1364,35 @@ def GPU_ShuffleOp : GPU_Op<
];
}
+def GPU_RotateOp : GPU_Op<
+ "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
+ Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
+ Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult)> {
+ let summary = "Rotate values within a subgroup.";
+ let description = [{
+ The "rotate" op moves values to a across lanes circularly (a.k.a.,
+ invocations, work items) within the same subgroup. The `width` argument
+ specifies the number of lanes that participate in the rotation, and must
+ be uniform across all lanes. Further, the first `width` lanes of the
+ subgroup must be active.
+
+ Example:
+
+ ```mlir
+ %cst1 = arith.constant 1 : i32
+ %width = arith.constant 16 : i32
+ %1 = gpu.rotate %0, %cst1, %width : f32
+ ```
+
+ For lane 0 < `k` < 16, return the value from lane `(k - 1) % width`.
+ For lane k == 0, return the value from lane 15.
+ }];
+
+ let assemblyFormat = [{
+ $value `,` $offset `,` $width attr-dict `:` type($value)
+ }];
+}
+
def GPU_BarrierOp : GPU_Op<"barrier"> {
let summary = "Synchronizes all work items of a workgroup.";
let description = [{
@@ -1919,6 +1948,79 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
}];
}
+def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
+ [Pure,
+ TypesMatchWith<"value type matches element type of mma_matrix",
+ "matrix", "res",
+ "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]>{
+
+ let summary = "Extract a value from GPU warp by invocation and indices";
+
+ let description = [{
+ The `gpu.subgroup_mma_extract` operation extracts a value from an
+ `!gpu.mma_matrix` that is distributed across the invocations of a subgroup.
+
+ This operation takes an `!gpu.mma_matrix` as its first operand; it is the
+ source matrix distributed across the subgroup. The op returns the scalar
+ value stored in the current invocation. If an invocation packs multiple
+ values, use `indices` to specify which element to extract.
+
+ Example:
+
+ ```mlir
+ %c0 = arith.constant 0 : index
+ %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+ ```
+ }];
+
+ let arguments = (ins GPU_MMAMatrix:$matrix, Variadic<Index>:$indices);
+
+ let results = (outs AnyIntegerOrFloat:$res);
+
+ let assemblyFormat = [{
+ $matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res)
+ }];
+}
+
+def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
+ [Pure,
+ TypesMatchWith<"value type matches element type of mma_matrix",
+ "matrix", "value",
+ "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()"> ]>{
+
+ let summary = "Insert a value into GPU warp by invocation and indices";
+
+ let description = [{
+ The `gpu.subgroup_mma_insert` operation inserts a value into an
+ `!gpu.mma_matrix` that is distributed across the invocations of a subgroup.
+
+ This operation takes a scalar value as its first operand and an
+ `!gpu.mma_matrix`, distributed across the subgroup, as its second operand.
+ The op inserts the scalar value held by the current invocation into the
+ matrix. If an invocation packs multiple values, use `indices` to specify
+ the location within the packing at which to insert.
+
+ The op returns a new `!gpu.mma_matrix` with the updated value.
+
+ Example:
+
+ ```mlir
+ %c0 = arith.constant 0 : index
+ %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
+ -> !gpu.mma_matrix<16x16xf16, "COp">
+ ```
+ }];
+
+ let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix,
+ Variadic<Index>:$indices);
+
+ let results = (outs GPU_MMAMatrix:$res);
+
+ let assemblyFormat = [{
+ $value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res)
+ }];
+}
+
def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">;
def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">;
def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 3cc64b82950b5..e96709e4b4a35 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -122,6 +122,16 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHR op.
+class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -458,6 +468,35 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
return success();
}
+//===----------------------------------------------------------------------===//
+// Rotate
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPURotateConversion::matchAndRewrite(
+ gpu::RotateOp rotateOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Require the rotate width to be the same as the target's subgroup size,
+ // given that for SPIR-V non-uniform subgroup ops, we cannot select
+ // participating invocations.
+ auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
+ unsigned subgroupSize =
+ targetEnv.getAttr().getResourceLimits().getSubgroupSize();
+ IntegerAttr widthAttr;
+ if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
+ widthAttr.getValue().getZExtValue() != subgroupSize)
+ return rewriter.notifyMatchFailure(
+ rotateOp, "rotate width and target subgroup size mismatch");
+
+ Location loc = rotateOp.getLoc();
+ auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+
+ Value result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
+ loc, scope, adaptor.getValue(), adaptor.getOffset(), rotateOp.getWidth());
+
+ rewriter.replaceOp(rotateOp, result);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//
@@ -733,7 +772,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
- GPUReturnOpConversion, GPUShuffleConversion,
+ GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index df2da138d3b52..78d266693fc2a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -111,6 +111,68 @@ struct WmmaConstantOpToSPIRVLowering final
}
};
+/// Converts the GPU MMA ExtractOp to a SPIR-V CompositeExtract op on
+/// KHR/NV cooperative matrix values.
+struct WmmaExtractOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaExtractOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value matrix = adaptor.getMatrix();
+ auto coopType =
+ getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
+ matrix.getType());
+ if (!coopType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ SmallVector<int32_t> intValues;
+ for (Value val : op.getIndices()) {
+ if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+ intValues.push_back(static_cast<int32_t>(constOp.value()));
+ } else {
+ return rewriter.notifyMatchFailure(op, "Indices must be constants");
+ }
+ }
+
+ Type elementType = coopType.getElementType();
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
+ return success();
+ }
+};
+
+/// Converts the GPU MMA InsertOp to a SPIR-V CompositeInsert op on
+/// KHR/NV cooperative matrix values.
+struct WmmaInsertOpToSPIRVLowering final
+ : OpConversionPattern<gpu::SubgroupMmaInsertOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Value value = adaptor.getValue();
+ Value matrix = adaptor.getMatrix();
+ auto coopType = getTypeConverter()->convertType(matrix.getType());
+ if (!coopType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ SmallVector<int32_t> intValues;
+ for (Value val : op.getIndices()) {
+ if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+ intValues.push_back(static_cast<int32_t>(constOp.value()));
+ } else {
+ return rewriter.notifyMatchFailure(op, "Indices must be constants");
+ }
+ }
+
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
+ return success();
+ }
+};
+
/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
@@ -296,6 +358,7 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
MLIRContext *context = patterns.getContext();
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+ WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
// Give the following patterns higher benefit to prevail over the default one.
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
new file mode 100644
index 0000000000000..102c2fb01edb6
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>, #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+ // CHECK-LABEL: spirv.func @rotate()
+ gpu.func @rotate() kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [4, 4, 1]>} {
+ // CHECK: %[[CST8_I32:.*]] = spirv.Constant 8 : i32
+ // CHECK: %[[CST16_I32:.*]] = spirv.Constant 16 : i32
+ // CHECK: %[[CST_F32:.*]] = spirv.Constant 4.200000e+01 : f32
+ %offset = arith.constant 8 : i32
+ %width = arith.constant 16 : i32
+ %val = arith.constant 42.0 : f32
+
+ // CHECK: spirv.GroupNonUniformRotateKHR <Subgroup>, %[[CST_F32]], %[[CST8_I32]], cluster_size(%[[CST16_I32]])
+ %result = gpu.rotate %val, %offset, %width : f32
+ gpu.return
+ }
+}
+
+}
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 477f344b1ae5f..3e8a3b21e7e94 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -93,6 +93,33 @@ module attributes {
gpu.return
}
+ // CHECK-LABEL: spirv.func @gpu_wmma_extract_op
+ // CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+ gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
+ %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+ // CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+ %c0 = arith.constant 0 : index
+ %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+ memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: spirv.func @gpu_wmma_insert_op
+ // CHECK-SAME: %[[ARG0:.+]]: f16
+ // CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ gpu.func @gpu_wmma_insert_op(%val: f16,
+ %m: !gpu.mma_matrix<16x16xf16, "COp">,
+ %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+ // CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+ %c0 = arith.constant 0 : index
+ %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+ gpu.subgroup_mma_store_matrix %s0, %ptr[%c0,%c0] {leadDimension = 16 : index} :
+ !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+ gpu.return
+ }
+
// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
``````````
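To illustrate the per-element access these ops enable, a minimal extract/modify/insert sketch (not part of the patch; `%acc` and `%alpha` are hypothetical values, and the indices are compile-time constants as the SPIR-V lowering above requires):

```mlir
// Scale the first element packed in this invocation's slice of the
// accumulator, then write it back, yielding an updated matrix value.
%c0 = arith.constant 0 : index
%e0 = gpu.subgroup_mma_extract %acc[%c0]
    : !gpu.mma_matrix<16x16xf16, "COp"> -> f16
%scaled = arith.mulf %e0, %alpha : f16
%acc1 = gpu.subgroup_mma_insert %scaled, %acc[%c0]
    : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
```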
</details>
https://github.com/llvm/llvm-project/pull/139048