[Mlir-commits] [mlir] [MLIR][XeGPU] Support order attribute and add pattern for vector.transpose in WgToSg Pass (PR #165307)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Mon Oct 27 12:54:46 PDT 2025
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
---
Patch is 67.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165307.diff
8 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+65-15) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+92-12) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir (+19-20) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir (+2-4) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+36-52) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (+35-29) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+79-66) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+55-68) 
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 24e909548fe0b..dfd4093905875 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -270,26 +270,76 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
 FailureOr<SmallVector<Value>>
 LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
                                   Value linearId) {
-  // delinearizeSubgroupId is only available for
-  // workgroup-level layout attribute
+  // Subgroup ids can only be delinearized for a workgroup-level layout.
   if (!isForWorkgroup())
     return failure();
 
-  // TODO: handle order attribute
-  auto hasDefaultOrder = [&]() {
-    DenseI32ArrayAttr order = getOrder();
-    return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
-                         llvm::reverse(order.asArrayRef())));
-  };
-  if (!hasDefaultOrder())
-    return mlir::emitError(loc, "order attribute is currently not supported.");
+  // Split `linearId` into one coordinate per sg_layout dimension, visiting
+  // the dimensions in `order` (fastest-varying first). Returns failure if the
+  // order attribute is malformed.
+  SmallVector<int64_t> sgLayoutInt = getEffectiveSgLayoutAsInt();
+  DenseI32ArrayAttr orderAttr = getOrder();
 
-  auto dims =
-      llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
-        return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
-      });
+  // `order` lists dimension indices from fastest- to slowest-varying. A
+  // missing or empty order attribute falls back to row-major order:
+  // [1, 0] for 2D, [2, 1, 0] for 3D, and so on.
+  SmallVector<int64_t> order;
+  if (orderAttr && !orderAttr.empty()) {
+    order = llvm::to_vector_of<int64_t>(orderAttr.asArrayRef());
+  } else {
+    order = llvm::to_vector(
+        llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
+  }
 
-  return affine::delinearizeIndex(builder, loc, linearId, dims);
+  // Reject a malformed order: a rank mismatch, or a repeated or out-of-range
+  // entry, would index sgLayoutInt out of bounds or leave a coordinate unset.
+  if (order.size() != sgLayoutInt.size() || !isPermutationVector(order))
+    return failure();
+
+  // One coordinate per sg_layout dimension, filled in `order` sequence below.
+  SmallVector<Value> result(sgLayoutInt.size());
+  // Part of the id not yet consumed by the dimensions processed so far.
+  Value remaining = linearId;
+
+  // Peel coordinates off the linear id starting with the fastest-varying
+  // dimension (order[0]): the coordinate is `remaining % size`, and
+  // `remaining / size` is carried on to the next dimension in `order`.
+  //
+  // Example: linearId = 22, sgLayout = [2, 4, 4], order = [2, 1, 0]
+  //   i = 0, dim 2 (size 4): result[2] = 22 % 4 = 2, remaining = 22 / 4 = 5
+  //   i = 1, dim 1 (size 4): result[1] =  5 % 4 = 1, remaining =  5 / 4 = 1
+  //   i = 2, dim 0 (size 2): result[0] =  1 % 2 = 1  (last: no division)
+  //   => result = [1, 1, 2]
+  //
+  // With order = [0, 1, 2] (dim 0 fastest) the same id instead yields:
+  //   i = 0, dim 0 (size 2): result[0] = 22 % 2 = 0, remaining = 22 / 2 = 11
+  //   i = 1, dim 1 (size 4): result[1] = 11 % 4 = 3, remaining = 11 / 4 = 2
+  //   i = 2, dim 2 (size 4): result[2] =  2 % 4 = 2  (last: no division)
+  //   => result = [0, 3, 2]
+  //
+  // The last dimension is reduced modulo its size as well, so every
+  // coordinate stays within sgLayout even if linearId exceeds the number of
+  // subgroups in the layout.
+  for (size_t i = 0, e = order.size(); i < e; ++i) {
+    int64_t dimIdx = order[i];
+    int64_t dimSize = sgLayoutInt[dimIdx];
+
+    Value dimSizeVal =
+        builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
+
+    // Position of this subgroup within dimension `dimIdx`.
+    result[dimIdx] =
+        builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);
+
+    // Drop this dimension from the running id. Skipped for the last
+    // (slowest-varying) dimension: no coordinate is extracted after it, so
+    // the division would only create a dead op.
+    if (i + 1 < e) {
+      remaining =
+          builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
+    }
+  }
+  return result;
 }
 
 /// Implements DistributeLayoutAttr::getOffsets to generate
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9af5c7b..88d3ed743628e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1217,6 +1217,93 @@ struct WgToSgMultiDimReductionOp
   }
 };
 
+// Distributes a workgroup-level vector.transpose: every subgroup transposes
+// each of its source tiles, provided the result layout is the permuted source.
+struct WgToSgVectorTransposeOp
+    : public OpConversionPattern<vector::TransposeOp> {
+  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(op.getVector());
+    if (!sourceLayout || !sourceLayout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sourceSgLayout =
+        sourceLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> sourceSgData = sourceLayout.getEffectiveSgDataAsInt();
+    SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> resultSgData = layout.getEffectiveSgDataAsInt();
+    DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
+    DenseI32ArrayAttr resultOrder = layout.getOrder();
+
+    if (!sourceOrder || !resultOrder)
+      return rewriter.notifyMatchFailure(
+          op, "Both source and result must have order attributes");
+
+    SmallVector<int64_t> sourceOrderVec =
+        llvm::to_vector_of<int64_t>(sourceOrder.asArrayRef());
+    SmallVector<int64_t> resultOrderVec =
+        llvm::to_vector_of<int64_t>(resultOrder.asArrayRef());
+
+    ArrayRef<int64_t> permutation = op.getPermutation();
+    size_t expectedSize = permutation.size();
+    if (sourceSgLayout.size() != expectedSize ||
+        sourceSgData.size() != expectedSize ||
+        resultSgLayout.size() != expectedSize ||
+        resultSgData.size() != expectedSize ||
+        sourceOrderVec.size() != expectedSize ||
+        resultOrderVec.size() != expectedSize) {
+      return rewriter.notifyMatchFailure(
+          op, "All layouts and permutation must have the same rank");
+    }
+
+    // Result dim i is source dim permutation[i]. sg_layout/sg_data compare
+    // positionally; `order` holds dim indices, so it must be renamed instead:
+    // permutation[resultOrder[k]] == sourceOrder[k] (matters from rank 3 on).
+    for (size_t i = 0; i < expectedSize; ++i) {
+      int64_t srcDim = permutation[i];
+      if (static_cast<uint64_t>(resultOrderVec[i]) >= expectedSize ||
+          resultSgLayout[i] != sourceSgLayout[srcDim] ||
+          resultSgData[i] != sourceSgData[srcDim] ||
+          permutation[resultOrderVec[i]] != sourceOrderVec[i]) {
+        return rewriter.notifyMatchFailure(
+            op, "Result layout is not a valid transpose of source layout "
+                "according to permutation");
+      }
+    }
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+    SmallVector<Value> newTransposeOps;
+    for (auto src : adaptor.getVector()) {
+      auto newTranspose = vector::TransposeOp::create(
+          rewriter, op.getLoc(), newResultType, src, permutation);
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      newTransposeOps.push_back(newTranspose.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newTransposeOps});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -1231,7 +1318,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
            WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
-           WgToSgMultiDimReductionOp>(patterns.getContext());
+           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1358,7 +1446,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+                               vector::TransposeOp, vector::BroadcastOp,
+                               vector::MultiDimReductionOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
@@ -1377,16 +1467,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::BroadcastOp>(
-      [=](vector::BroadcastOp op) -> bool {
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
-      });
-
-  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
-      [=](vector::MultiDimReductionOp op) -> bool {
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
-      });
-
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
index b73bc69393dab..02c5f71d5c83d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
@@ -1,33 +1,32 @@
 // RUN: mlir-opt --test-xegpu-layout-interface --cse -split-input-file %s | FileCheck %s
 
-//CHECk: #map = affine_map<()[s0] -> (s0 floordiv 8)>
 gpu.module @test {
   gpu.func @slice_attr() -> vector<128xindex> {
-    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
-    //CHECK: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
-    //CHECK: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
-    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
-    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
-    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C8:.*]]
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU]], %[[C4:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
+    // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
+    // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
+    // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
+    // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
     %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
     gpu.return %step : vector<128xindex>
   }
 
   gpu.func @nested_slice_attr() -> vector<128xindex> {
-    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
-    //CHECK: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
-    //CHECK: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
-    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
-    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
-    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[DIVU1:.*]] = index.divu %[[SGID]], %[[C1:.*]]
+    // CHECK-DAG: %[[DIVU2:.*]] = index.divu %[[DIVU1]], %[[C8:.*]]
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU2]], %[[C4:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
+    // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
+    // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
+    // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
+    // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
     %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 1], sg_data = [32, 32, 1]>, dims = [2]>, dims = [1]>} : vector<128xindex>
     gpu.return %0 : vector<128xindex>
   }
 
-}
\ No newline at end of file
+}
+
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
index 09df1e4da43e2..9580769d37313 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -166,14 +166,12 @@ gpu.module @test_elementwise_ops {
     %load_b = xegpu.load_nd %tdesc_b
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       -> vector<24x32xf32>
-    // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
-    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+    // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
     // CHECK-NOT: arith.negf
     %negf = arith.negf %load_a
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<24x32xf32>
-    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
-    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
     // CHECK-NOT: math.powf
     %powf = math.powf %load_a, %load_b
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d2d250cbe0f66..01134d8eaabec 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,14 +1,10 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
-#map = affine_map<()[s0] -> (s0 floordiv 4)>
-#map1 = affine_map<()[s0] -> (s0 mod 4)>
-
 gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
-      // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
-      // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-NOT: xegpu.create_nd_tdesc
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -16,22 +12,23 @@ gpu.module @test_round_robin_assignment {
     }
 
   // CHECK-LABEL: create_nd_tdesc_with_shared_data
-  // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) {
-    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[IdY:%.+]] = affine.apply #map()[[[sgId]]]
-    //CHECK: [[IdX:%.+]] = affine.apply #map1()[[[sgId]]]
-    //CHECK: [[C16:%.+]] = arith.constant 16 : index
-    //CHECK: [[LY:%.+]] = index.mul [[IdY]], [[C16]]
-    //CHECK: [[C64:%.+]] = arith.constant 64 : index
-    //CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
-    //CHECK: [[C0:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[offY:%.+]] = index.remu [[LY]], [[C128]]
-    //CHECK: [[C64_2:%.+]] = arith.constant 64 : index
-    //CHECK: [[offX:%.+]] = index.remu [[LX]], [[C64_2]]
-    //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK: %[[IDX:.*]] = index.remu %[[SGID]], %[[C4]]
+    // CHECK: %[[IDY_DIV:.*]] = index.divu %[[SGID]], %[[C4]]
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[IDY:.*]] = index.remu %[[IDY_DIV]], %[[C8]]
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK: %[[LY:.*]] = index.mul %[[IDY]], %[[C16]]
+    // CHECK: %[[C64:.*]] = arith.constant 64 : index
+    // CHECK: %[[LX:.*]] = index.mul %[[IDX]], %[[C64]]
+    // CHECK: %[[C128:.*]] = arith.constant 128 : index
+    // CHECK: %[[OFFY:.*]] = index.remu %[[LY]], %[[C128]]
+    // CHECK: %[[C64_1:.*]] = arith.constant 64 : index
+    // CHECK: %[[OFFX:.*]] = index.remu %[[LX]], %[[C64_1]]
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]][%[[OFFY]], %[[OFFX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
     gpu.return
@@ -42,9 +39,7 @@ gpu.module @test_round_robin_assignment {
   gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
-      // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+      // CHECK-COUNT-4: xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32>
       // CHECK-NOT: xegpu.load_nd
       %load =  xegpu.load_nd %tdesc
         : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -57,9 +52,8 @@ gpu.module @test_round_robin_assignment {
   gpu.func @store_nd(%src: memref<256x128xf32>) {
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
-      // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-NOT : xegpu.store_nd
+      // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}} : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-NOT: xegpu.store_nd
       %load = xegpu.load_nd %tdesc
         : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
         -> vector<256x128xf32>
@@ -73,8 +67,7 @@ gpu.module @test_round_robin_assignment {
   gpu.func @update_nd(%src: memref<256x128xf32>){
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       ->  !xegpu.tensor_desc<2...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/165307
    
    
More information about the Mlir-commits mailing list