[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute non-splat constant from wg to sg (PR #161416)

Igor Zamyatin llvmlistbot at llvm.org
Mon Oct 6 14:40:12 PDT 2025


================
@@ -733,22 +733,172 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     int count;
     std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
 
-    // Current limitation: constant of vector with single value.
-    // TODO: support more complex cases, e.g., vector with multiple values.
-    Attribute singleVal = vecAttr.getSplatValue<Attribute>();
-
     auto newType = VectorType::get(sgShape, vecType.getElementType());
-    auto sgAttr = DenseElementsAttr::get(newType, singleVal);
-    auto cstOp =
-        arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-        !layout.getEffectiveInstDataAsInt().empty())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
-                                     layout.dropSgLayoutAndData());
-    SmallVector<Value> newConsts(count, cstOp);
+    Location loc = op.getLoc();
+    auto eltType = vecType.getElementType();
 
-    rewriter.replaceOpWithMultiple(op, {newConsts});
-    return success();
+    auto setLayoutIfNeeded = [&](Value val) {
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty()) {
+        xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+                                       layout.dropSgLayoutAndData());
+      }
+    };
+
+    if (vecAttr.isSplat()) {
+      // Splat: single value for all subgroups
+      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+      setLayoutIfNeeded(cstOp->getResult(0));
+      rewriter.replaceOp(op, cstOp);
+      return success();
+    } else if (sgShape == wgShape) { // if the entire vector is shared by all
+                                     // subgroups, don't distribute
+      auto newConstOp =
+          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
+      setLayoutIfNeeded(newConstOp->getResult(0));
+      rewriter.replaceOp(op, newConstOp);
+      return success();
+    } else {
+      // Non-splat constant
+      // Only supports 1D & 2D
+      // TODO: support other cases that require SLM access
+      if (!eltType.isIndex())
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported element type for non-splat constant op.");
+
+      if (wgShape.size() > 2)
+        return rewriter.notifyMatchFailure(
+            op, "Only 1D & 2D vector constant supported");
+
+      SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
+      int64_t stride = 0;
+      int64_t rowStride = 0, colStride = 0;
+      if (wgShape.size() == 1) {
+        // 1D case: single stride
+        if (values.size() > 1) {
+          stride = cast<IntegerAttr>(values[1]).getInt() -
+                   cast<IntegerAttr>(values[0]).getInt();
+          for (size_t i = 2; i < values.size(); ++i) {
+            int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
+                           cast<IntegerAttr>(values[i - 1]).getInt();
+            if (diff != stride)
+              return rewriter.notifyMatchFailure(
+                  op, "Non-constant stride in non-splat constant op.");
+          }
+        }
+      } else if (wgShape.size() == 2) {
+        // 2D case: row stride and column stride
+        int64_t rows = wgShape[0], cols = wgShape[1];
+        // Compute col stride (stride between elements in a column)
+        if (cols > 1) {
+          colStride = cast<IntegerAttr>(values[1]).getInt() -
+                      cast<IntegerAttr>(values[0]).getInt();
+          for (int64_t r = 0; r < rows; ++r) {
+            for (int64_t c = 1; c < cols; ++c) {
+              int64_t idx = r * cols + c;
+              int64_t prevIdx = r * cols + (c - 1);
+              int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+                             cast<IntegerAttr>(values[prevIdx]).getInt();
+              if (diff != colStride)
+                return rewriter.notifyMatchFailure(
+                    op, "Non-constant column stride in 2D constant op.");
+            }
+          }
+        }
+        // Compute row stride (stride between elements in a row)
+        if (rows > 1) {
+          rowStride = cast<IntegerAttr>(values[cols]).getInt() -
+                      cast<IntegerAttr>(values[0]).getInt();
+          for (int64_t c = 0; c < cols; ++c) {
+            for (int64_t r = 1; r < rows; ++r) {
+              int64_t idx = r * cols + c;
+              int64_t prevIdx = (r - 1) * cols + c;
+              int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
+                             cast<IntegerAttr>(values[prevIdx]).getInt();
+              if (diff != rowStride)
+                return rewriter.notifyMatchFailure(
+                    op, "Non-constant row stride in 2D constant op.");
+            }
+          }
+        }
+      }
+
+      // Determine the shape of the base tile for each subgroup.
+      SmallVector<int64_t> baseTileShape;
----------------
Garra1980 wrote:

Can you just use sgShape directly instead of introducing a new variable (baseTileShape)?

https://github.com/llvm/llvm-project/pull/161416


More information about the Mlir-commits mailing list