[Mlir-commits] [mlir] 8e17f80 - [MLIR][XeGPU] Distribute vector.step & vector.shape_cast op from wg to sg (#155443)

llvmlistbot at llvm.org
Fri Sep 12 14:33:56 PDT 2025


Author: Nishant Patel
Date: 2025-09-12T14:33:52-07:00
New Revision: 8e17f80908abd5a22acf962584371b71dffe6d15

URL: https://github.com/llvm/llvm-project/commit/8e17f80908abd5a22acf962584371b71dffe6d15
DIFF: https://github.com/llvm/llvm-project/commit/8e17f80908abd5a22acf962584371b71dffe6d15.diff

LOG: [MLIR][XeGPU] Distribute vector.step & vector.shape_cast op from wg to sg (#155443)

This PR adds patterns to distribute the vector.step and vector.shape_cast
ops from the workgroup (wg) to the subgroup (sg) level. It also enables the
constant, broadcast, and elementwise op patterns to handle the slice
attribute.
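
For illustration, a minimal sketch of the vector.step rewrite (adapted
from the tests below; %offset stands in for the subgroup-offset
computation and this is not the exact IR the pass emits). A
workgroup-level

    %step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>} : vector<128xindex>

becomes, in each subgroup, a local 8-element step shifted by that
subgroup's starting position:

    %sgid   = gpu.subgroup_id : index
    // %offset = (%sgid * 8) mod 128, built from index/affine ops
    %base   = vector.step : vector<8xindex>
    %bcast  = vector.broadcast %offset : index to vector<8xindex>
    %result = arith.addi %base, %bcast : vector<8xindex>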

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3f48400fedf5e..d7592fed6d186 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -468,6 +468,7 @@ struct WgToSgVectorBroadcastOp
   LogicalResult
   matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
     VectorType resultType = op.getResult().getType();
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
@@ -476,43 +477,24 @@ struct WgToSgVectorBroadcastOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    // TODO: Currently only supports cases where the source and result ranks
-    // are the same.
-    auto srcType =
-        dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
-    if (!srcType || srcType.getRank() != resultType.getRank())
-      return failure();
-
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
-    // Check if the output layout is distributable
-    SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
-    if (sgLayout.empty())
-      return failure();
-
     if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
       return failure();
 
-    // Check if the srcShape has unit dim in dimensions being broadcasted,
-    // and the other dimensions are the same as the destination type
-    // TODO: Generalize it
-    auto srcShape = srcType.getShape();
-    for (size_t i = 0; i < srcShape.size(); ++i) {
-      if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
-        return failure();
-    }
-
     SmallVector<Value> newBroadcastOps;
     for (auto operand : adaptor.getOperands().front()) {
       auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                       newResultType, operand);
-      xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
-                                     layout.dropSgLayoutAndData());
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+                                       layout.dropSgLayoutAndData());
+
       newBroadcastOps.push_back(newBroadcast.getResult());
     }
-
     rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
     return success();
   }
@@ -564,9 +546,11 @@ struct WgToSgElementwiseOp : public ConversionPattern {
       // Copy all attributes, but update "layout_result_0" to drop
       // sgLayout/sgData
       for (auto attr : op->getAttrs()) {
-        if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
-          if (auto newLayout = layout.dropSgLayoutAndData())
-            state.addAttribute(attr.getName(), newLayout);
+        if (auto layout =
+                dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
+          if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+              !layout.getEffectiveInstDataAsInt().empty())
+            state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
         } else {
           state.addAttribute(attr.getName(), attr.getValue());
         }
@@ -757,8 +741,10 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto sgAttr = DenseElementsAttr::get(newType, singleVal);
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (auto newLayout = layout.dropSgLayoutAndData())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+        !layout.getEffectiveInstDataAsInt().empty())
+      xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
     SmallVector<Value> newConsts(count, cstOp);
 
     rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -919,6 +905,128 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
   }
 };
 
+// This pattern distributes vector.step ops to work at the subgroup level.
+struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType type = op.getResult().getType();
+    auto wgShape = type.getShape();
+    std::optional<SmallVector<int64_t>> sgShape =
+        getSgShapeAndCount(wgShape, layout).first;
+    if (!sgShape)
+      return failure();
+
+    Value sgId =
+        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+    auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+    if (failed(sgOffsets))
+      return failure();
+
+    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
+    auto steps = vector::StepOp::create(rewriter, loc, newTy);
+    SmallVector<Value> newOps;
+    for (auto offsets : *sgOffsets) {
+      // Broadcast the offset scalar to a vector and add it to the base steps.
+      auto bcastOffset =
+          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
+      auto finalSteps =
+          arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty()) {
+        xegpu::setDistributeLayoutAttr(steps->getResult(0),
+                                       layout.dropSgLayoutAndData());
+        xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
+                                       layout.dropSgLayoutAndData());
+        xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      }
+      newOps.push_back(finalSteps);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newOps});
+    return success();
+  }
+};
+
+// This pattern transforms vector.shape_cast ops to work at the subgroup level.
+struct WgToSgVectorShapeCastOp
+    : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    // TODO: Add check for compatible layouts in layout attr.
+    auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
+    if (!srcType)
+      return failure();
+
+    // Check that the shape_cast only adds or removes unit dimensions.
+    auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
+      // Remove all 1s from both shapes and compare the rest.
+      SmallVector<int64_t> srcNonUnit, dstNonUnit;
+      for (int64_t d : src)
+        if (d != 1)
+          srcNonUnit.push_back(d);
+      for (int64_t d : dst)
+        if (d != 1)
+          dstNonUnit.push_back(d);
+      return srcNonUnit == dstNonUnit;
+    };
+
+    if (!onlyUnitDims(srcType.getShape(), sgShape))
+      return failure();
+
+    // For rank-reducing or rank-increasing shape_cast ops, the lower-rank
+    // layout must be a slice of the higher-rank layout.
+    int64_t sourceRank = srcType.getRank();
+    int64_t resultRank = sgShape.size();
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(op.getSource());
+    if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
+      return failure();
+    if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
+      return failure();
+
+    SmallVector<Value> newShapeCastOps;
+    for (auto src : adaptor.getSource()) {
+      auto newShapeCast =
+          rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      newShapeCastOps.push_back(newShapeCast.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -932,7 +1040,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
            WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
-           WgToSgStoreMatrixOp>(patterns.getContext());
+           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1054,7 +1163,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         auto vecType = dyn_cast<VectorType>(op.getType());
         if (!vecType)
           return true;
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+
+        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
+        return isLegal(layout);
+      });
+
+  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+      [=](Operation *op) -> bool {
+        // Check for either a SliceAttr or LayoutAttr on the result.
+        auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+        return isLegal(layout);
       });
 
   target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index afb2bf876c18f..3478a9b91da5f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -2,6 +2,7 @@
 
 //CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
 //CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
+//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)>
 gpu.module @test_distribution {
   // CHECK-LABEL: create_nd_tdesc_no_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -365,4 +366,62 @@ gpu.module @test_distribution {
     xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_step_op_slice_attr
+  gpu.func @vector_step_op_slice_attr() {
+    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+    //CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
+    //CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK-DAG: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+    //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+    //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+    //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex>
+    //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+    %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
+    gpu.return
+  }
+
+  gpu.func @vector_step_op_layout_attr() {
+    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+    //CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
+    //CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index
+    //CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
+    //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+    //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+    //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex>
+    //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
+    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
+    %step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
+    gpu.return
+  }
+
+  // CHECK-LABEL: constant_with_slice_attr
+  gpu.func @constant_with_slice_attr() {
+    //CHECK: [[cst:%.+]] = arith.constant dense<10> : vector<1xindex>
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [1, 2, 3]>} dense<10> : vector<4xindex>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_shape_cast
+  gpu.func @vector_shape_cast() {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} dense<10> : vector<128xindex>
+    %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
+    %muli = arith.muli %cst, %step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
+    //CHECK: vector.shape_cast {{.*}} : vector<32xindex> to vector<1x1x1x32xindex>
+    %shape_cast = vector.shape_cast %muli {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>} : vector<128xindex> to vector<1x1x1x128xindex>
+    gpu.return
+  }
+
+  // CHECK-LABEL: vector_broadcast
+  gpu.func @vector_broadcast(%arg0: index, %arg1: index) {
+    %muli = arith.muli %arg0, %arg1 : index
+    // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex>
+    %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
+    gpu.return
+  }
 }

