[Mlir-commits] [mlir] [rfc][mlir][gpu] Add operations to extract/insert/rotate within subgroup (PR #139048)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 8 01:26:11 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Hsiangkai Wang (Hsiangkai)

<details>
<summary>Changes</summary>

Add gpu.rotate, gpu.subgroup_mma_extract, and gpu.subgroup_mma_insert operations.

---
Full diff: https://github.com/llvm/llvm-project/pull/139048.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+102) 
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (+40-1) 
- (modified) mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (+63) 
- (added) mlir/test/Conversion/GPUToSPIRV/rotate.mlir (+25) 
- (modified) mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir (+27) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 68095b7bf5c59..612ae3fac2d77 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1364,6 +1364,35 @@ def GPU_ShuffleOp : GPU_Op<
   ];
 }
 
+def GPU_RotateOp : GPU_Op<
+    "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
+    Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
+    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult)> {
+  let summary = "Rotate values within a subgroup.";
+  let description = [{
+    The "rotate" op moves values circularly across lanes (a.k.a.,
+    invocations, work items) within the same subgroup. The `width` argument
+    specifies the number of lanes that participate in the rotation, and must
+    be uniform across all lanes. Further, the first `width` lanes of the
+    subgroup must be active.
+
+    Example:
+
+    ```mlir
+    %cst1 = arith.constant 1 : i32
+    %width = arith.constant 16 : i32
+    %1 = gpu.rotate %0, %cst1, %width : f32
+    ```
+
+    For lane `k` (0 <= `k` < 16), return the value from lane `(k + 1) % width`.
+    That is, lane 15 wraps around and returns the value from lane 0.
+  }];
+
+  let assemblyFormat = [{
+    $value `,` $offset `,` $width attr-dict `:` type($value)
+  }];
+}
+
 def GPU_BarrierOp : GPU_Op<"barrier"> {
   let summary = "Synchronizes all work items of a workgroup.";
   let description = [{
@@ -1919,6 +1948,79 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
   }];
 }
 
+def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
+    [Pure,
+     TypesMatchWith<"value type matches element type of mma_matrix",
+                    "matrix", "res",
+                    "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]>{
+
+  let summary = "Extract a value from GPU warp by invocation and indices";
+
+  let description = [{
+    The `gpu.subgroup_mma_extract` operation extracts a value from `!gpu.mma_matrix`
+    by the invocation in a subgroup.
+
+    This operation takes `!gpu.mma_matrix` as its first operand. It is the source
+    matrix across a subgroup. The op returns a scalar value stored in the invocation
+    in the subgroup. If there are multiple values packed in an invocation, use
+    `indices` to specify the element to extract.
+
+    Example:
+
+    ```mlir
+    %c0 = arith.constant 0 : index
+    %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+    ```
+  }];
+
+  let arguments = (ins GPU_MMAMatrix:$matrix, Variadic<Index>:$indices);
+
+  let results = (outs AnyIntegerOrFloat:$res);
+
+  let assemblyFormat = [{
+    $matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res)
+  }];
+}
+
+def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
+    [Pure,
+     TypesMatchWith<"value type matches element type of mma_matrix",
+                    "matrix", "value",
+                    "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()"> ]>{
+
+  let summary = "Insert a value into GPU warp by invocation and indices";
+
+  let description = [{
+    The `gpu.subgroup_mma_insert` operation inserts a value to `!gpu.mma_matrix`
+    by the invocation in a subgroup.
+
+    This operation takes a scalar value as its first operand and
+    `!gpu.mma_matrix` as its second operand. It is the matrix across a subgroup.
+    The op inserts the scalar value stored in the invocation in the subgroup
+    into the matrix. If there are multiple values packed in an invocation, use
+    `indices` to specify the location to insert in the packing.
+
+    The op returns `!gpu.mma_matrix` with the updated value.
+
+    Example:
+
+    ```mlir
+    %c0 = arith.constant 0 : index
+    %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
+            -> !gpu.mma_matrix<16x16xf16, "COp">
+    ```
+  }];
+
+  let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix,
+                       Variadic<Index>:$indices);
+
+  let results = (outs GPU_MMAMatrix:$res);
+
+  let assemblyFormat = [{
+    $value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res)
+  }];
+}
+
 def GPU_ElementwiseOpAddF  : I32EnumAttrCase<"ADDF", 0, "addf">;
 def GPU_ElementwiseOpMulF  : I32EnumAttrCase<"MULF", 1, "mulf">;
 def GPU_ElementwiseOpSUBF  : I32EnumAttrCase<"SUBF", 2, "subf">;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 3cc64b82950b5..e96709e4b4a35 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -122,6 +122,16 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHR op.
+class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -458,6 +468,35 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Rotate
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPURotateConversion::matchAndRewrite(
+    gpu::RotateOp rotateOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Require the rotate width to be a constant equal to the target's subgroup
+  // size: for SPIR-V non-uniform subgroup ops, we cannot select an arbitrary
+  // subset of participating invocations. Bail out (match failure) otherwise.
+  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
+  unsigned subgroupSize =
+      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
+  IntegerAttr widthAttr;
+  if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
+      widthAttr.getValue().getZExtValue() != subgroupSize)
+    return rewriter.notifyMatchFailure(
+        rotateOp, "rotate width and target subgroup size mismatch");
+
+  Location loc = rotateOp.getLoc();
+  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+
+  Value result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
+      loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
+
+  rewriter.replaceOp(rotateOp, result);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Group ops
 //===----------------------------------------------------------------------===//
@@ -733,7 +772,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
   patterns.add<
       GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
-      GPUReturnOpConversion, GPUShuffleConversion,
+      GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index df2da138d3b52..78d266693fc2a 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -111,6 +111,68 @@ struct WmmaConstantOpToSPIRVLowering final
   }
 };
 
+/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaExtractOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value matrix = adaptor.getMatrix();
+    auto coopType =
+        getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
+            matrix.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    SmallVector<int32_t> intValues;
+    for (Value val : op.getIndices()) {
+      if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+        intValues.push_back(static_cast<int32_t>(constOp.value()));
+      } else {
+        return rewriter.notifyMatchFailure(op, "Indices must be constants");
+      }
+    }
+
+    Type elementType = coopType.getElementType();
+    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+        op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
+    return success();
+  }
+};
+
+/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaInsertOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaInsertOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value value = adaptor.getValue();
+    Value matrix = adaptor.getMatrix();
+    auto coopType = getTypeConverter()->convertType(matrix.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    SmallVector<int32_t> intValues;
+    for (Value val : op.getIndices()) {
+      if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+        intValues.push_back(static_cast<int32_t>(constOp.value()));
+      } else {
+        return rewriter.notifyMatchFailure(op, "Indices must be constants");
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+        op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
+    return success();
+  }
+};
+
 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
 /// the default case.
 struct WmmaElementwiseOpToSPIRVDefaultLowering final
@@ -296,6 +358,7 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
   MLIRContext *context = patterns.getContext();
   patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
                khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
                WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
   // Give the following patterns higher benefit to prevail over the default one.
   patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
diff --git a/mlir/test/Conversion/GPUToSPIRV/rotate.mlir b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
new file mode 100644
index 0000000000000..102c2fb01edb6
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/rotate.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>, #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL:  spirv.func @rotate()
+  gpu.func @rotate() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [4, 4, 1]>} {
+    // CHECK: %[[CST8_I32:.*]] = spirv.Constant 8 : i32
+    // CHECK: %[[CST16_I32:.*]] = spirv.Constant 16 : i32
+    // CHECK: %[[CST_F32:.*]] = spirv.Constant 4.200000e+01 : f32
+    %offset = arith.constant 8 : i32
+    %width = arith.constant 16 : i32
+    %val = arith.constant 42.0 : f32
+
+    // CHECK: spirv.GroupNonUniformRotateKHR <Subgroup>, %[[CST_F32]], %[[CST8_I32]], cluster_size(%[[CST16_I32]])
+    %result = gpu.rotate %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 477f344b1ae5f..3e8a3b21e7e94 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -93,6 +93,33 @@ module attributes {
       gpu.return
     }
 
+    // CHECK-LABEL: spirv.func @gpu_wmma_extract_op
+    // CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+    gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
+                                  %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+      %c0 = arith.constant 0 : index
+      %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+      memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_insert_op
+    // CHECK-SAME: %[[ARG0:.+]]: f16
+    // CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    gpu.func @gpu_wmma_insert_op(%val: f16,
+                                 %m: !gpu.mma_matrix<16x16xf16, "COp">,
+                                 %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %c0 = arith.constant 0 : index
+      %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+      gpu.subgroup_mma_store_matrix %s0, %ptr[%c0,%c0] {leadDimension = 16 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      gpu.return
+    }
+
     // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
     // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
     // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>

``````````

</details>


https://github.com/llvm/llvm-project/pull/139048


More information about the Mlir-commits mailing list