[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] (PR #142797)
Chao Chen
llvmlistbot at llvm.org
Fri Jun 13 12:07:49 PDT 2025
================
@@ -328,6 +331,63 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+// This pattern transforms elementwise ops in the math/arith dialects
+struct WgToSgElementwiseOp : public ConversionPattern {
+ WgToSgElementwiseOp(MLIRContext *ctx)
+ : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only match ops with elementwise trait
+ if (!OpTrait::hasElementwiseMappableTraits(op))
+ return rewriter.notifyMatchFailure(op, "Not an elementwise op");
+
+ auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+ if (!layout || !layout.getSgLayout())
+ return rewriter.notifyMatchFailure(
+ op, "Operation does not have a valid layout attribute for subgroup "
+ "distribution");
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+ size_t numVariants = operands.empty() ? 0 : operands.front().size();
+ for (auto &operandVec : operands)
+ if (operandVec.size() != numVariants)
+ return rewriter.notifyMatchFailure(
+ op, "Operand lists have mismatched sizes");
+
+ SmallVector<Value> newResults;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+
+ for (size_t i = 0; i < numVariants; ++i) {
+ SmallVector<Value> opOperands;
+ for (auto &operandVec : operands)
+ opOperands.push_back(operandVec[i]);
+
+ OperationState state(op->getLoc(), op->getName());
+ state.addOperands(opOperands);
+ state.addTypes(newResultType);
+ // Copy all attributes, but update "layout_result_0" to drop
+ // sgLayout/sgData
+ for (auto attr : op->getAttrs()) {
+ if (attr.getName() != "layout_result_0")
----------------
chencha3 wrote:
Using `isa<LayoutAttr>()` to identify the layout attribute is more robust. Try to avoid referring to `layout_result_...` attribute names directly.
https://github.com/llvm/llvm-project/pull/142797
More information about the Mlir-commits
mailing list