[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] (PR #142797)

Nishant Patel llvmlistbot at llvm.org
Mon Jun 16 10:33:51 PDT 2025


================
@@ -328,6 +331,63 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// This pattern transforms elementwise ops in math/arith dialect
+struct WgToSgElementwiseOp : public ConversionPattern {
+  WgToSgElementwiseOp(MLIRContext *ctx)
+      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only match ops with elementwise trait
+    if (!OpTrait::hasElementwiseMappableTraits(op))
+      return rewriter.notifyMatchFailure(op, "Not an elementwise op");
+
+    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+    if (!layout || !layout.getSgLayout())
+      return rewriter.notifyMatchFailure(
+          op, "Operation does not have a valid layout attribute for subgroup "
+              "distribution");
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+    size_t numVariants = operands.empty() ? 0 : operands.front().size();
+    for (auto &operandVec : operands)
+      if (operandVec.size() != numVariants)
+        return rewriter.notifyMatchFailure(
+            op, "Operand lists have mismatched sizes");
+
+    SmallVector<Value> newResults;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    for (size_t i = 0; i < numVariants; ++i) {
+      SmallVector<Value> opOperands;
+      for (auto &operandVec : operands)
+        opOperands.push_back(operandVec[i]);
+
+      OperationState state(op->getLoc(), op->getName());
+      state.addOperands(opOperands);
+      state.addTypes(newResultType);
+      // Copy all attributes, but update "layout_result_0" to drop
+      // sgLayout/sgData
+      for (auto attr : op->getAttrs()) {
+        if (attr.getName() != "layout_result_0")
+          state.addAttribute(attr.getName(), attr.getValue());
+      }
+      Operation *newOp = rewriter.create(state);
+      xegpu::setLayoutAttr(newOp->getResult(0), layout.dropSgLayoutAndData());
----------------
nbpatel wrote:

The attribute is added to `newOp` (the transformed op), and the loop iterates over the attributes, not over `numVariants`.

https://github.com/llvm/llvm-project/pull/142797


More information about the Mlir-commits mailing list