[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] (PR #142797)

Chao Chen llvmlistbot at llvm.org
Tue Jun 10 12:23:22 PDT 2025


================
@@ -314,6 +317,90 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// This pattern transforms elementwise ops (unary/binary) in the math/arith dialects.
+template <typename Op>
+struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+  using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // All operands/results must be 1D or 2D vectors
+    auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType || (resultType.getRank() != 1 && resultType.getRank() != 2))
+      return rewriter.notifyMatchFailure(
+          op, "Result type is not a 1D or 2D vector");
+
+    ArrayRef<int64_t> shape = resultType.getShape();
+    for (Value operand : op->getOperands()) {
+      auto operandType = dyn_cast<VectorType>(operand.getType());
+      if (!operandType || operandType.getRank() != resultType.getRank() ||
+          operandType.getShape() != shape) {
+        return rewriter.notifyMatchFailure(
+            op, "Operand type is not a 1D or 2D vector with the same shape as "
+                "result type");
+      }
+    }
+
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    if (!layout || !layout.getSgLayout())
+      return rewriter.notifyMatchFailure(
+          op, "Operation does not have a valid layout attribute for subgroup "
+              "distribution");
+
+    // Extract sgShape from layout
+    SmallVector<int64_t> sgShape;
+    if (auto sgDataAttr = layout.getSgData()) {
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    } else {
+      auto sgLayoutArr = layout.getSgLayout();
+      sgShape.reserve(shape.size());
+      for (size_t i = 0; i < shape.size(); ++i) {
+        assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
+        sgShape.push_back(shape[i] / sgLayoutArr[i]);
+      }
+    }
+
+    size_t numVariants = adaptor.getOperands().empty()
+                             ? 0
+                             : adaptor.getOperands().front().size();
+    for (auto &operandVec : adaptor.getOperands())
+      if (operandVec.size() != numVariants)
+        return rewriter.notifyMatchFailure(
+            op, "Operand lists have mismatched sizes");
+
+    SmallVector<Value> newResults;
+
+    auto origResultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    VectorType newResultType =
+        origResultType
+            ? VectorType::get(sgShape, origResultType.getElementType())
+            : VectorType::get(sgShape, resultType.getElementType());
+
+    for (size_t i = 0; i < numVariants; ++i) {
+      SmallVector<Value> operands;
+      for (auto &operandVec : adaptor.getOperands())
+        operands.push_back(operandVec[i]);
+
+      auto newOp = rewriter.create<Op>(op.getLoc(), newResultType, operands);
+
+      // Copy all attributes except "layout", and add "layout_result_0" with
+      // sgLayout/data dropped
+      for (auto attr : op->getAttrs()) {
+        if (attr.getName() != "layout")
+          newOp->setAttr(attr.getName(), attr.getValue());
+      }
+      newOp->setAttr("layout_result_0", layout.dropSgLayoutAndData());
----------------
chencha3 wrote:

Consider using the `xegpu::setLayoutAttr(OpResult, LayoutAttr)` interface here instead of setting `layout_result_0` by name.
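
Something like the following could replace the manual attribute handling — a rough sketch only, assuming the `xegpu::setLayoutAttr(OpResult, LayoutAttr)` helper named above attaches the layout to that result the same way the hand-written `layout_result_0` attribute does:

```cpp
// Rough sketch, not tested: same loop body as in the diff, but the result
// layout is attached through the suggested helper instead of a hand-written
// "layout_result_0" attribute.
auto newOp = rewriter.create<Op>(op.getLoc(), newResultType, operands);

// Copy all attributes except "layout".
for (NamedAttribute attr : op->getAttrs()) {
  if (attr.getName() != "layout")
    newOp->setAttr(attr.getName(), attr.getValue());
}

// Assumed signature from the comment above: (OpResult, LayoutAttr).
xegpu::setLayoutAttr(newOp->getResult(0), layout.dropSgLayoutAndData());
```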

https://github.com/llvm/llvm-project/pull/142797

