[Mlir-commits] [mlir] 8063bd1 - [MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] (#142797)
llvmlistbot at llvm.org
Tue Jun 17 09:55:05 PDT 2025
Author: Nishant Patel
Date: 2025-06-17T09:55:02-07:00
New Revision: 8063bd153c6aca43869d96aee64aeceb9be98ca5
URL: https://github.com/llvm/llvm-project/commit/8063bd153c6aca43869d96aee64aeceb9be98ca5
DIFF: https://github.com/llvm/llvm-project/commit/8063bd153c6aca43869d96aee64aeceb9be98ca5.diff
LOG: [MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] (#142797)
This PR adds support for lowering elementwise operations (unary and binary)
from the workgroup level to the subgroup level.
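
For reference, a minimal sketch of the rewrite this pass performs (shapes and
layout attributes are taken from the added test; the SSA value names are
illustrative): an elementwise op whose result carries an sg_layout over the
workgroup-sized vector is re-created on the per-subgroup shape, with
sg_layout/sg_data dropped from its layout_result_0 attribute.

    // Workgroup level (before distribution):
    %addf = arith.addf %load_a, %load_b
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8],
                                       lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<24x32xf32>

    // Subgroup level (after distribution): each subgroup computes a 12x8 tile,
    // and sg_layout/sg_data are removed from the layout attribute.
    %addf_sg = arith.addf %load_a_sg, %load_b_sg
      {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
      : vector<12x8xf32>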
Added:
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
Modified:
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a26c6b52f0ddc..e3563d10bc6f1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -8,10 +8,12 @@
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -19,6 +21,7 @@
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
namespace mlir {
namespace xegpu {
@@ -328,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+// This pattern transforms elementwise ops to work at subgroup level.
+struct WgToSgElementwiseOp : public ConversionPattern {
+ WgToSgElementwiseOp(MLIRContext *ctx)
+ : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only match ops with elementwise trait and single result.
+ if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+ return failure();
+
+ auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+ assert(resultType && "Expected result to be a VectorType");
+
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+ if (!layout || !layout.getSgLayout())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+ size_t numVariants = operands.empty() ? 0 : operands.front().size();
+
+ if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
+ return operandVec.size() != numVariants;
+ }))
+ return failure();
+
+ SmallVector<Value> newResults;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+
+ for (size_t i = 0; i < numVariants; ++i) {
+ SmallVector<Value> opOperands;
+ for (auto &operandVec : operands)
+ opOperands.push_back(operandVec[i]);
+
+ OperationState state(op->getLoc(), op->getName());
+ state.addOperands(opOperands);
+ state.addTypes(newResultType);
+ // Copy all attributes, but update "layout_result_0" to drop
+ // sgLayout/sgData
+ for (auto attr : op->getAttrs()) {
+ if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue()))
+ state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
+ else
+ state.addAttribute(attr.getName(), attr.getValue());
+ }
+ Operation *newOp = rewriter.create(state);
+ newResults.push_back(newOp->getResult(0));
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newResults});
+ return success();
+ }
+};
+
// Handles UnrealizedConversionCastOp generated during
// SCFStructuralTypeConversions (step 1). This op may appear as either a
// target or source materialization for Vector values, e.g.:
@@ -411,7 +473,8 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- UnrealizedConversionCastOpPattern>(patterns.getContext());
+ UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -518,6 +581,30 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
+ target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+ [=](Operation *op) -> std::optional<bool> {
+ // Only handle elementwise mappable ops
+ if (!OpTrait::hasElementwiseMappableTraits(op))
+ return true;
+
+ VectorType resultType =
+ dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultType)
+ return true;
+
+ // Check if all operands are vectors of the same shape
+ // TODO: Support other types.
+ for (Value operand : op->getOperands()) {
+ VectorType operandType = dyn_cast<VectorType>(operand.getType());
+ if (!operandType || operandType.getShape() != resultType.getShape()) {
+ return true;
+ }
+ }
+
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+ return isLegal(layout);
+ });
+
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
[=](UnrealizedConversionCastOp op) {
return llvm::is_contained(existingCastOps, op.getOperation());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
new file mode 100644
index 0000000000000..64f01d61d6e80
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_elementwise_ops {
+ // CHECK-LABEL: unary_ops
+ gpu.func @unary_ops(%a: memref<24x32xf32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ // CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ %exp = math.exp %load_a
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ // CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ %negf = arith.negf %load_a
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: binary_ops
+ gpu.func @binary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x8xf32>
+ %addf = arith.addf %load_a, %load_b
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x8xf32>
+ %powf = math.powf %load_a, %load_b
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: ternary_ops
+ gpu.func @ternary_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi1>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi1>
+ -> !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_c = xegpu.load_nd %tdesc_c
+ : !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi1>
+ // CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x8xi1>, vector<12x8xf32>
+ %select = arith.select %load_c, %load_a, %load_b
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi1>, vector<24x32xf32>
+ // CHECK: math.fma {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x8xf32>
+ %fma = math.fma %load_a, %load_b, %load_a
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: type_conversion_ops
+ gpu.func @type_conversion_ops(%a: memref<24x32xf32>, %b: memref<24x32xi32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ // CHECK: arith.truncf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x8xf32> to vector<12x8xf16>
+ %truncf = arith.truncf %load_a
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32> to vector<24x32xf16>
+ // CHECK: arith.bitcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x8xi32> to vector<12x8xf32>
+ %bitcast = arith.bitcast %load_b
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32> to vector<24x32xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: comparison_ops
+ gpu.func @comparison_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_c = xegpu.load_nd %tdesc_c
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_d = xegpu.load_nd %tdesc_d
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x8xf32>
+ %cmpf = arith.cmpf ult, %load_a, %load_b
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ // CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x8xi32>
+ %cmpi = arith.cmpi eq, %load_c, %load_d
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xi32>
+ gpu.return
+ }
+
+ // 1 to N decomposition of elementwise operations
+ // CHECK-LABEL: elementwise_ops_rr_assignment
+ gpu.func @elementwise_ops_rr_assignment(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+ // CHECK-NOT: arith.negf
+ %negf = arith.negf %load_a
+ {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+ // CHECK-NOT: math.powf
+ %powf = math.powf %load_a, %load_b
+ {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ gpu.return
+ }
+}