[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute vector.step & vector.shape_cast op from wg to sg (PR #155443)

Adam Siemieniuk llvmlistbot at llvm.org
Mon Sep 8 10:15:20 PDT 2025


================
@@ -919,6 +905,118 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
   }
 };
 
+// This pattern distributes the vector.step ops to work at subgroup level.
+//
+// A single vector.step of the subgroup shape is created. Then, for every
+// offset list returned by the layout's getOffsets() for the current subgroup
+// id, the first offset is broadcast to the subgroup vector type and added to
+// those base steps. The original op is replaced (1:N) by the resulting values.
+struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only results that carry a workgroup-level layout are distributed.
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType type = op.getResult().getType();
+    auto wgShape = type.getShape();
+    std::optional<SmallVector<int64_t>> sgShape =
+        getSgShapeAndCount(wgShape, layout).first;
+    if (!sgShape)
+      return failure();
+
+    Value sgId =
+        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+    auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+    if (failed(sgOffsets))
+      return failure();
+
+    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
+    auto steps = vector::StepOp::create(rewriter, loc, newTy);
+
+    // Whether a layout is forwarded to the new ops, and which one, does not
+    // depend on the loop below, so it is computed once. `steps` is shared by
+    // all iterations and therefore only needs to be annotated once as well.
+    const bool forwardLayout = !layout.getLaneLayoutAsInt().empty() ||
+                               !layout.getLaneDataAsInt().empty();
+    xegpu::DistributeLayoutAttr sgLevelLayout;
+    if (forwardLayout) {
+      sgLevelLayout = layout.dropSgLayoutAndData();
+      xegpu::setDistributeLayoutAttr(steps->getResult(0), sgLevelLayout);
+    }
+
+    SmallVector<Value> newOps;
+    // Iterate by const reference to avoid copying each offset list.
+    for (const auto &offsets : *sgOffsets) {
+      // Broadcast the offset scalar to a vector & add to the base steps
+      auto bcastOffset =
+          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
+      auto finalSteps =
+          arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
+      if (forwardLayout) {
+        xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
+                                       sgLevelLayout);
+        xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
+                                       sgLevelLayout);
+      }
+      newOps.push_back(finalSteps);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newOps});
+    return success();
+  }
+};
+
+// This pattern transforms vector.shape_cast ops to work at subgroup level.
+struct WgToSgVectorShapeCastOp
+    : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    // TODO: Add check for compatible layouts in layout attr.
+    auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+    if (!srcType)
+      return failure();
+
+    // Check that shape_cast only adds/removes unit dimensions,
+    auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
+      // Remove all 1s from both shapes and compare the rest.
+      SmallVector<int64_t> srcNonUnit, dstNonUnit;
+      for (int64_t d : src)
+        if (d != 1)
+          srcNonUnit.push_back(d);
+      for (int64_t d : dst)
+        if (d != 1)
+          dstNonUnit.push_back(d);
+      return srcNonUnit == dstNonUnit;
+    };
+
+    if (!onlyUnitDims(srcType.getShape(), sgShape) ||
+        !onlyUnitDims(sgShape, srcType.getShape()))
----------------
adam-smnk wrote:

Why does it have to be called twice?

https://github.com/llvm/llvm-project/pull/155443


More information about the Mlir-commits mailing list