[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute non-splat constant from wg to sg (PR #161416)

Nishant Patel llvmlistbot at llvm.org
Mon Oct 6 13:21:18 PDT 2025


================
@@ -733,22 +733,116 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     int count;
     std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
 
-    // Current limitation: constant of vector with single value.
-    // TODO: support more complex cases, e.g., vector with multiple values.
-    Attribute singleVal = vecAttr.getSplatValue<Attribute>();
-
     auto newType = VectorType::get(sgShape, vecType.getElementType());
-    auto sgAttr = DenseElementsAttr::get(newType, singleVal);
-    auto cstOp =
-        arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
-        !layout.getEffectiveInstDataAsInt().empty())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
-                                     layout.dropSgLayoutAndData());
-    SmallVector<Value> newConsts(count, cstOp);
+    Location loc = op.getLoc();
+    auto eltType = vecType.getElementType();
 
-    rewriter.replaceOpWithMultiple(op, {newConsts});
-    return success();
+    auto setLayoutIfNeeded = [&](Value val) {
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty()) {
+        xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
+                                       layout.dropSgLayoutAndData());
+      }
+    };
+
+    if (vecAttr.isSplat()) {
+      // Splat: single value for all subgroups
+      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
+      setLayoutIfNeeded(cstOp->getResult(0));
+      rewriter.replaceOp(op, cstOp);
+      return success();
+    } else if (sgShape == wgShape) { // if the entire vector is shared by all
+                                     // subgroups, don't distribute
+      auto newConstOp =
+          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
+      setLayoutIfNeeded(newConstOp->getResult(0));
+      rewriter.replaceOp(op, newConstOp);
+      return success();
+    } else {
+      // Non-splat constant
+      // Only supports 1D & 2D (with one unit dim)
+      // TODO: support other cases that require SLM access
+      if (!eltType.isIndex())
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported element type for non-splat constant op.");
+
+      SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
+      if (wgShape.size() > 2)
+        return rewriter.notifyMatchFailure(
+            op, "Only 1D & 2D vector constant supported");
+
+      // allow 2D vector/distributions with one unit dim
+      auto hasTwoNonUnitDims = [](ArrayRef<int64_t> dims) {
+        return dims.size() == 2 && dims[0] != 1 && dims[1] != 1;
+      };
+      if (hasTwoNonUnitDims(wgShape) || hasTwoNonUnitDims(sgLayout))
+        return rewriter.notifyMatchFailure(
+            op, "2D vector/distribution only supported with 1 unit dim");
+
+      int64_t nonUnitDim = 0;
+      if (wgShape.size() == 2)
+        nonUnitDim = wgShape[0] != 1 ? 0 : 1;
+
+      SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
+      int64_t stride = 0;
+      if (values.size() > 1) {
+        stride = cast<IntegerAttr>(values[1]).getInt() -
+                 cast<IntegerAttr>(values[0]).getInt();
+        for (size_t i = 2; i < values.size(); ++i) {
+          int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
+                         cast<IntegerAttr>(values[i - 1]).getInt();
+          if (diff != stride)
+            return rewriter.notifyMatchFailure(
+                op, "Non-constant stride in non-splat constant op.");
+        }
+      }
+
+      int sgData = 1;
+      if (sgShape.size() == 1) {
+        sgData = static_cast<int>(sgShape[0]);
+      } else if (sgShape.size() == 2) {
+        sgData = static_cast<int>(sgShape[0] != 1 ? sgShape[0] : sgShape[1]);
+      } else {
+        return rewriter.notifyMatchFailure(
+            op, "Only 1D or 2D vector constant supported");
+      }
+
+      // Create a constant for the base tile
+      SmallVector<Attribute> baseTileValues;
+      for (int i = 0; i < sgData; ++i)
+        baseTileValues.push_back(values[i]);
+      auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType),
+                                             baseTileValues);
+      auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+
+      // Get subgroup id
+      Value sgId =
+          gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+      auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+      if (failed(sgOffsets))
+        return failure();
+
+      auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
+      SmallVector<Value> newConstOps;
+      for (auto offsets : *sgOffsets) {
+        // Multiply offset with stride, broadcast it and add to baseConstVec
+        Value mulOffset = rewriter.create<arith::MulIOp>(
+            loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
----------------
nbpatel wrote:

I changed it to support 2D vectors. The high-level logic now uses two strides, rowStride and colStride. It computes each subgroup's offset value as rowOffset * rowStride + colOffset * colStride, broadcasts that value to the shape of baseConstVec, and adds it to baseConstVec.

https://github.com/llvm/llvm-project/pull/161416


More information about the Mlir-commits mailing list