[Mlir-commits] [mlir] 8063bd1 - [MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] (#142797)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 17 09:55:05 PDT 2025


Author: Nishant Patel
Date: 2025-06-17T09:55:02-07:00
New Revision: 8063bd153c6aca43869d96aee64aeceb9be98ca5

URL: https://github.com/llvm/llvm-project/commit/8063bd153c6aca43869d96aee64aeceb9be98ca5
DIFF: https://github.com/llvm/llvm-project/commit/8063bd153c6aca43869d96aee64aeceb9be98ca5.diff

LOG: [MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] (#142797)

This PR adds support for lowering elementwise operations (unary and binary)
from workgroup level to subgroup level.

Added: 
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a26c6b52f0ddc..e3563d10bc6f1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -8,10 +8,12 @@
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -19,6 +21,7 @@
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <optional>
 
 namespace mlir {
 namespace xegpu {
@@ -328,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// Distributes a workgroup-level elementwise op into one subgroup-level op per
+// converted operand variant. Matches any op; ops that lack the elementwise
+// mappable traits, a single vector result, or an sg_layout are rejected.
+struct WgToSgElementwiseOp : public ConversionPattern {
+  WgToSgElementwiseOp(MLIRContext *ctx)
+      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only match ops with elementwise trait and single result.
+    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+      return failure();
+
+    // Bail out instead of asserting: this pattern matches any op, and an
+    // assert is compiled out of release builds, leaving a null type behind.
+    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    if (!resultType)
+      return failure();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+    // Every operand must have been converted into the same number of values.
+    size_t numVariants = operands.empty() ? 0 : operands.front().size();
+    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
+          return operandVec.size() != numVariants;
+        }))
+      return failure();
+
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+    SmallVector<Value> newResults;
+    newResults.reserve(numVariants);
+    for (size_t i = 0; i < numVariants; ++i) {
+      SmallVector<Value> opOperands;
+      for (const ValueRange &operandVec : operands)
+        opOperands.push_back(operandVec[i]);
+      OperationState state(op->getLoc(), op->getName());
+      state.addOperands(opOperands);
+      state.addTypes(newResultType);
+      // Copy all attributes, but strip sgLayout/sgData from layout attributes.
+      for (NamedAttribute attr : op->getAttrs()) {
+        Attribute value = attr.getValue();
+        if (auto attrLayout = dyn_cast<xegpu::LayoutAttr>(value))
+          value = attrLayout.dropSgLayoutAndData();
+        state.addAttribute(attr.getName(), value);
+      }
+      newResults.push_back(rewriter.create(state)->getResult(0));
+    }
+    rewriter.replaceOpWithMultiple(op, {newResults});
+    return success();
+  }
+};
+
 // Handles UnrealizedConversionCastOp generated during
 // SCFStructuralTypeConversions (step 1). This op may appear as either a
 // target or source materialization for Vector values, e.g.:
@@ -411,7 +473,8 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern>(patterns.getContext());
+               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -518,6 +581,30 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+      [=](Operation *op) -> std::optional<bool> {
+        // Only handle elementwise mappable ops
+        if (!OpTrait::hasElementwiseMappableTraits(op))
+          return true;
+
+        VectorType resultType =
+            dyn_cast<VectorType>(op->getResult(0).getType());
+        if (!resultType)
+          return true;
+
+        // Check if all operands are vectors of the same shape
+        // TODO: Support other types.
+        for (Value operand : op->getOperands()) {
+          VectorType operandType = dyn_cast<VectorType>(operand.getType());
+          if (!operandType || operandType.getShape() != resultType.getShape()) {
+            return true;
+          }
+        }
+
+        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+        return isLegal(layout);
+      });
+
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
       [=](UnrealizedConversionCastOp op) {
         return llvm::is_contained(existingCastOps, op.getOperation());

diff  --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
new file mode 100644
index 0000000000000..64f01d61d6e80
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_elementwise_ops {
+  // CHECK-LABEL: unary_ops
+  gpu.func @unary_ops(%a: memref<24x32xf32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    // CHECK: math.exp {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+    %exp = math.exp %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // CHECK: arith.negf {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+    %negf = arith.negf %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: binary_ops
+  gpu.func @binary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %addf = arith.addf %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %powf = math.powf %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: ternary_ops
+  gpu.func @ternary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi1>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi1>
+      -> !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_c = xegpu.load_nd %tdesc_c
+      : !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi1>
+    // CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xi1>, vector<12x8xf32>
+    %select = arith.select %load_c, %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xi1>, vector<24x32xf32>
+    // CHECK: math.fma  {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %fma = math.fma %load_a, %load_b, %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: type_conversion_ops
+  gpu.func @type_conversion_ops(%a: memref<24x32xf32>, %b: memref<24x32xi32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32>
+      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi32>
+    // CHECK: arith.truncf {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32> to vector<12x8xf16>
+    %truncf = arith.truncf %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32> to vector<24x32xf16>
+    // CHECK: arith.bitcast {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xi32> to vector<12x8xf32>
+    %bitcast = arith.bitcast %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xi32> to vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: comparison_ops
+  gpu.func @comparison_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+      -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_c = xegpu.load_nd %tdesc_c
+      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi32>
+    %load_d = xegpu.load_nd %tdesc_d
+      : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xi32>
+    // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>
+    %cmpf = arith.cmpf ult, %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xi32>
+    %cmpi = arith.cmpi eq, %load_c, %load_d
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+      : vector<24x32xi32>
+    gpu.return
+  }
+
+  // 1 to N decomposition of elementwise operations
+  // CHECK-LABEL: elementwise_ops_rr_assignment
+  gpu.func @elementwise_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    // FileCheck has no SAME-COUNT directive, so the result type is matched inline on each of the 12 copies.
+    // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+    // CHECK-NOT: arith.negf
+    %negf = arith.negf %load_a
+      {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    // FileCheck has no SAME-COUNT directive, so the result type is matched inline on each of the 12 copies.
+    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+    // CHECK-NOT: math.powf
+    %powf = math.powf %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>
+    gpu.return
+  }
+}


        


More information about the Mlir-commits mailing list