[Mlir-commits] [mlir] [MLIR][XeGPU] Propagate layout from anchor ops in WG IR (PR #179490)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 3 08:19:22 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

<details>
<summary>Changes</summary>



---

Patch is 20.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/179490.diff


6 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp (+4-7) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+5-6) 
- (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+13) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir (+13-12) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (+3-3) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+6-16) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index c00b7d42d48a6..6ab473704645d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -286,13 +286,10 @@ void XeGPUBlockingPass::runOnOperation() {
   MLIRContext *ctx = &getContext();
   Operation *op = getOperation();
 
-  // TODO-LayoutRefactor: unify the local propagation for layout preprocessing
-  // replace the function with recoverTemporaryLayouts
-  // if (!xegpu::recoverTemporaryLayouts(op)) {
-  //   signalPassFailure();
-  //   return;
-  // }
-  xegpu::recoverTemporaryLayoutsDeprecated(op);
+  if (!xegpu::recoverTemporaryLayouts(op)) {
+    signalPassFailure();
+    return;
+  }
 
   auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
                                  xegpu::LayoutAttr layout) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 45a002b63abd6..adaf4ff881d69 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1723,12 +1723,11 @@ struct XeGPUWgToSgDistributePass
 
 void XeGPUWgToSgDistributePass::runOnOperation() {
 
-  // TODO-LayoutRefactor: unify the local propagation for layout preprocessing
-  // Operation *op = getOperation();
-  // if (!xegpu::recoverTemporaryLayouts(op)) {
-  //   signalPassFailure();
-  //   return;
-  // }
+  Operation *op = getOperation();
+  if (!xegpu::recoverTemporaryLayouts(op)) {
+    signalPassFailure();
+    return;
+  }
 
   // Track existing UnrealizedConversionCastOps
   SmallVector<Operation *> existingCastOps;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 7e28c756f2d72..282d9253d7afe 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -166,6 +166,18 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
           defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
       return layout;
     }
+
+    // Handle scf.for results by looking at yielded values
+    if (auto forOp = dyn_cast<scf::ForOp>(defOp)) {
+      auto resultIndex = result.getResultNumber();
+      if (auto yieldOp =
+              dyn_cast<scf::YieldOp>(forOp.getBody()->getTerminator())) {
+        if (resultIndex < yieldOp.getOperands().size()) {
+          auto yieldedValue = yieldOp.getOperand(resultIndex);
+          return getDistributeLayoutAttr(yieldedValue);
+        }
+      }
+    }
   }
 
   if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -179,6 +191,7 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
 
   return nullptr;
 }
+
 xegpu::DistributeLayoutAttr
 xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   Operation *op = opr.getOwner();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
index 6e9711442b92d..608f1dc2a1253 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -27,11 +27,11 @@ gpu.module @test_elementwise_ops {
     %load_a = xegpu.load_nd %tdesc_a {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
       -> vector<24x32xf32>
-    // CHECK: math.exp {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+    // CHECK: math.exp {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
     %exp = math.exp %load_a
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
       : vector<24x32xf32>
-    // CHECK: arith.negf {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+    // CHECK: arith.negf {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
     %negf = arith.negf %load_a
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
       : vector<24x32xf32>
@@ -50,12 +50,12 @@ gpu.module @test_elementwise_ops {
     %load_b = xegpu.load_nd %tdesc_b {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
       -> vector<24x32xf32>
-    // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK: arith.addf {{.*}}, {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xf32>
     %addf = arith.addf %load_a, %load_b
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
       : vector<24x32xf32>
-    // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK: math.powf {{.*}}, {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xf32>
     %powf = math.powf %load_a, %load_b
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
@@ -80,12 +80,12 @@ gpu.module @test_elementwise_ops {
     %load_c = xegpu.load_nd %tdesc_c {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<24x32xi1, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
       -> vector<24x32xi1>
-    // CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK: arith.select {{.*}}, {{.*}}, {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_operand_2 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xi1>, vector<12x8xf32>
     %select = arith.select %load_c, %load_a, %load_b
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
       : vector<24x32xi1>, vector<24x32xf32>
-    // CHECK: math.fma  {{.*}}, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK: math.fma  {{.*}}, {{.*}}, {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_operand_2 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xf32>
     %fma = math.fma %load_a, %load_b, %load_a
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
@@ -105,12 +105,12 @@ gpu.module @test_elementwise_ops {
     %load_b = xegpu.load_nd %tdesc_b {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
       -> vector<24x32xi32>
-    // CHECK: arith.truncf {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK: arith.truncf {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xf32> to vector<12x8xf16>
     %truncf = arith.truncf %load_a
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
       : vector<24x32xf32> to vector<24x32xf16>
-    // CHECK: arith.bitcast {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK: arith.bitcast {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xi32> to vector<12x8xf32>
     %bitcast = arith.bitcast %load_b
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
@@ -140,12 +140,12 @@ gpu.module @test_elementwise_ops {
     %load_d = xegpu.load_nd %tdesc_d {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
       -> vector<24x32xi32>
-    // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK: arith.cmpf ult, {{.*}}, {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xf32>
     %cmpf = arith.cmpf ult, %load_a, %load_b
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
       : vector<24x32xf32>
-    // CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
+    // CHECK: arith.cmpi eq, {{.*}}, {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xi32>
     %cmpi = arith.cmpi eq, %load_c, %load_d
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
@@ -166,12 +166,12 @@ gpu.module @test_elementwise_ops {
     %load_b = xegpu.load_nd %tdesc_b {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       -> vector<24x32xf32>
-    // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+    // CHECK-COUNT-12: arith.negf {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
     // CHECK-NOT: arith.negf
     %negf = arith.negf %load_a
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<24x32xf32>
-    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
+    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
     // CHECK-NOT: math.powf
     %powf = math.powf %load_a, %load_b
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
@@ -179,3 +179,4 @@ gpu.module @test_elementwise_ops {
     gpu.return
   }
 }
+
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 4b11270373f95..f794f8951cf94 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -26,14 +26,14 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: store_nd_with_offset
   gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) {
-    // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.store_nd
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       -> vector<256x128xf32>
-    xegpu.store_nd %load, %tdesc[0, 0]
+    xegpu.store_nd %load, %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
       : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
@@ -155,7 +155,7 @@ gpu.module @test_distribution {
     %cst_0 = arith.constant {layout_result_0=#xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} dense<0xFF800000> : vector<256xf32>
     %block_id_x = gpu.block_id  x
     %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
-    %1 = xegpu.load_nd %0[%block_id_x, 0]  : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
+    %1 = xegpu.load_nd %0[%block_id_x, 0] {layout = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>}  : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
     %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
     %3 = vector.shape_cast %2 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} : vector<256xf32> to vector<256x1xf32>
     %4 = vector.broadcast %3 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 1fc2328d09046..b761a3c0f87b7 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -58,13 +58,13 @@ gpu.module @test_distribution {
   // CHECK-LABEL: store_nd_with_offsets
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
-    //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]  : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}>  : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
       -> vector<256x128xf32>
-    xegpu.store_nd %load, %tdesc[0, 0]
+    xegpu.store_nd %load, %tdesc[0, 0] 
       : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
 }
@@ -130,17 +130,6 @@ gpu.module @test_distribution {
     gpu.return
   }
 
-  // CHECK-LABEL: dpas_with_no_create_nd_desc
-  gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) {
-    // CHECK-NOT: vector<32x32xf32>
-    %dpas = xegpu.dpas %a, %b
-      {layout_a =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>, 
-       layout_b =  #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
-       layout_cd =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
-      : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
-    gpu.return
-  }
-
   // CHECK-LABEL: broadcast_dim1
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
   gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
@@ -209,14 +198,15 @@ gpu.module @test_distribution {
       %9 = xegpu.dpas %arg4, %arg5, %arg6 
           {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>,
            layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>,
-           layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}
+           layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>,
+           layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}
           : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
       %10 = xegpu.load_nd %3[%arg3, %c0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>}: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
       %11 = xegpu.load_nd %4[%c0, %arg3] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>}: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
       scf.yield %10, %11, %9 : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32>
     }
     // store_nd with offset
-    xegpu.store_nd %8#2, %2[%0, %1] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    xegpu.store_nd %8#2, %2[%0, %1] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
     gpu.return
   }
 
@@ -726,7 +716,7 @@ gpu.module @test_distribution {
     %cst = arith.constant {layout_result_0 = #xegpu.slice<...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/179490


More information about the Mlir-commits mailing list