[Mlir-commits] [mlir] [MLIR][XeGPU] Decouple `inst_data` and `lane_layout` in propagation (PR #166941)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 7 06:34:07 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

<details>
<summary>Changes</summary>

Currently, we have a circular dependency:
blocking needs `inst_data` (and propagating `inst_data` currently also attaches `lane_layout` and `lane_data`), but `lane_layout` and `lane_data` themselves need blocking.
This patch decouples these layout fields in the propagation and makes propagation a multi-step process.
For example, the passes can be run as:
```
-xegpu-propagate-layout="layout-kind=inst" -xegpu-blocking -xegpu-propagate-layout="layout-kind=lane" -xegpu-subgroup-distribute
```

---

Patch is 69.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166941.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+11) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+120-68) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir (+21-21) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+1-579) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 9c35c07a7e587..3f27d690f949b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -379,6 +379,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
   );
 
   let builders = [
+    AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data),
+      [{
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto order = DenseI32ArrayAttr();
+        auto lane_layout = DenseI32ArrayAttr();
+        auto lane_data = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data,
+                     DenseI32ArrayAttr::get($_ctxt, inst_data),
+                     lane_layout, lane_data, order);
+      }]>,
     AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data,
                       "llvm::ArrayRef<int32_t>": $lane_layout,
                      "llvm::ArrayRef<int32_t>": $lane_data),
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4e1a539771d2f..c00b08e0ca37f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -53,6 +53,8 @@ using namespace mlir::dataflow;
 
 namespace {
 
+enum class LayoutKind { Lane, InstData };
+
 //===----------------------------------------------------------------------===//
 // LayoutInfo
 //===----------------------------------------------------------------------===//
@@ -99,7 +101,8 @@ struct LayoutInfo {
 
   bool isAssigned() const { return storage != nullptr; }
 
-  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
+  LayoutInfo transpose(ArrayRef<int64_t> permutation,
+                       LayoutKind layoutKind) const;
 
   SmallVector<int> getLaneLayout() const;
 
@@ -167,7 +170,8 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
 }
 
 /// Construct a new layout with the transposed lane layout and lane data.
-LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
+LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation,
+                                 LayoutKind layoutKind) const {
   if (!isAssigned())
     return {};
   // Check if the permutation is valid.
@@ -186,12 +190,20 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
   SmallVector<int32_t> laneData;
   SmallVector<int32_t> instData;
   for (int64_t idx : permutation) {
-    laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
-    laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
-    instData.push_back(static_cast<int32_t>(getInstData()[idx]));
+    if (layoutKind == LayoutKind::Lane) {
+      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
+      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+    } else if (layoutKind == LayoutKind::InstData)
+      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
+  }
+  xegpu::LayoutAttr layoutAttr;
+  if (layoutKind == LayoutKind::Lane) {
+    layoutAttr =
+        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
+  } else if (layoutKind == LayoutKind::InstData) {
+    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
   }
-  return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
-                                           laneLayout, laneData));
+  return LayoutInfo(layoutAttr);
 }
 
 //===----------------------------------------------------------------------===//
@@ -213,15 +225,14 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
 static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                            unsigned rank,
-                                           const xegpu::uArch::uArch *uArch,
-                                           ArrayRef<int> instData) {
+                                           const xegpu::uArch::uArch *uArch) {
   assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
   if (rank == 1) {
     return LayoutInfo(
-        xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
+        xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
   }
-  return LayoutInfo(xegpu::LayoutAttr::get(
-      ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
+  return LayoutInfo(
+      xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
 }
 
 static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
@@ -236,7 +247,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
 /// Helper to get the default layout for a vector type.
 static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
                                            const xegpu::uArch::uArch *uArch,
-                                           ArrayRef<int> instData,
                                            unsigned packingSize,
                                            bool isScattered = false) {
   // Expecting a 1D or 2D vector.
@@ -247,16 +257,16 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
+    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
   if (isScattered) {
-    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                              {uArch->getSubgroupSize(), 1},
                                              {1, packingFactor}));
   }
-  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                            {1, uArch->getSubgroupSize()},
                                            {1, packingFactor}));
 }
@@ -275,7 +285,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (tdescTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData);
+    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
   int subgroupSize = uArch->getSubgroupSize();
@@ -298,7 +308,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
 static LayoutInfo
 getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
                                 const xegpu::uArch::uArch *uArch,
-                                ArrayRef<int> instData, unsigned packingSize) {
+                                unsigned packingSize) {
   Type elementTy = vectorTy.getElementType();
   assert(elementTy.isIntOrFloat() &&
          "Expected int or float type in DPAS operands");
@@ -310,10 +320,10 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
         {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
          1});
     return LayoutInfo(
-        xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data));
+        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
   }
   // Otherwise, return the default layout for the vector type.
-  return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize);
+  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
 }
 
 //===----------------------------------------------------------------------===//
@@ -328,6 +338,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
 class LayoutInfoPropagation
     : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
 private:
+  LayoutKind layoutKind;
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                    ArrayRef<const LayoutInfoLattice *> results);
 
@@ -380,8 +391,10 @@ class LayoutInfoPropagation
 
 public:
   LayoutInfoPropagation(DataFlowSolver &solver,
-                        SymbolTableCollection &symbolTable)
-      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+                        SymbolTableCollection &symbolTable,
+                        LayoutKind layoutKind)
+      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
+        layoutKind(layoutKind) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
 
   LogicalResult
@@ -627,14 +640,24 @@ void LayoutInfoPropagation::visitDpasOp(
   SmallVector<int> instDataA = {maxALen, subgroupSize};
   SmallVector<int> instDataB = {subgroupSize, maxBLen};
 
-  propagateIfChanged(operands[0],
-                     operands[0]->meet(getSIMTLayoutInfoForDPASOperand(
-                         aTy, 0, uArch, instDataA,
-                         uArchInstruction->getPackedFormatBitSizeA())));
-  propagateIfChanged(operands[1],
-                     operands[1]->meet(getSIMTLayoutInfoForDPASOperand(
-                         bTy, 1, uArch, instDataB,
-                         uArchInstruction->getPackedFormatBitSizeB())));
+  LayoutInfo dpasALayout;
+  LayoutInfo dpasBLayout;
+  LayoutInfo dpasCLayout;
+
+  if (layoutKind == LayoutKind::InstData) {
+    dpasALayout =
+        LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
+    dpasBLayout =
+        LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
+  } else {
+    dpasALayout = getSIMTLayoutInfoForDPASOperand(
+        aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
+    dpasBLayout = getSIMTLayoutInfoForDPASOperand(
+        bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+  }
+
+  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
+  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
   if (operands.size() > 2) {
     VectorType cTy = dpas.getAccType();
     const unsigned dataCLen = bTy.getShape().back();
@@ -645,10 +668,15 @@ void LayoutInfoPropagation::visitDpasOp(
       dpas.emitWarning(
           "No suitable instruction multiple found for the given shape.");
     SmallVector<int> instDataC = {maxALen, maxCLen};
-    propagateIfChanged(operands[2],
-                       operands[2]->meet(getSIMTLayoutInfoForDPASOperand(
-                           cTy, 2, uArch, instDataC,
-                           uArchInstruction->getPackedFormatBitSizeB())));
+
+    if (layoutKind == LayoutKind::InstData)
+      dpasCLayout =
+          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
+    else
+      dpasCLayout = getSIMTLayoutInfoForDPASOperand(
+          cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+
+    propagateIfChanged(operands[2], operands[2]->meet(dpasCLayout));
   }
 }
 
@@ -685,9 +713,15 @@ void LayoutInfoPropagation::visitStoreNdOp(
           "No suitable instruction multiple found for the given shape.");
     instData = {instHeight, instWidth};
   }
-  LayoutInfo storeLayout =
-      getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData,
-                               uArchInstruction->getPackedFormatBitSize());
+
+  LayoutInfo storeLayout;
+  if (layoutKind == LayoutKind::InstData)
+    storeLayout =
+        LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
+  else
+    storeLayout =
+        getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
+                                 uArchInstruction->getPackedFormatBitSize());
   // Both operands should have the same layout
   for (LayoutInfoLattice *operand : operands)
     propagateIfChanged(operand, operand->meet(storeLayout));
@@ -709,7 +743,7 @@ void LayoutInfoPropagation::visitLoadNdOp(
   if (auto transpose = load.getTranspose()) {
     load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                      "LayoutInfoPropagation stage.");
-    tensorDescLayout = valueLayout.transpose(transpose.value());
+    tensorDescLayout = valueLayout.transpose(transpose.value(), layoutKind);
   }
   // Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
@@ -724,7 +758,8 @@ void LayoutInfoPropagation::visitTransposeOp(
   LayoutInfo resultLayout = results[0]->getValue();
   if (!resultLayout.isAssigned())
     return;
-  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
+  LayoutInfo newLayout =
+      resultLayout.transpose(transpose.getPermutation(), layoutKind);
   // Propagate the new layout to the vector operand.
   propagateIfChanged(operands[0], operands[0]->meet(newLayout));
 }
@@ -818,9 +853,13 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     if (srcTdescTy.getChunkSizeAsInt() > 1)
       instData.push_back(chunkSize);
   }
-  LayoutInfo layout = getDefaultSIMTLayoutInfo(
-      payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
-      /*scattered*/ true);
+  LayoutInfo layout;
+  if (layoutKind == LayoutKind::InstData)
+    layout = LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
+  else
+    layout = getDefaultSIMTLayoutInfo(payloadTy, uArch,
+                                      uArch->getGeneralPackedFormatBitSize(),
+                                      /*scattered*/ true);
 
   // Mask operand should have 1D default layout.
   LayoutInfo maskLayout =
@@ -864,33 +903,36 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
     return;
   }
+  LayoutInfo payloadLayout;
   auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
   const int subgroupSize = uArch->getSubgroupSize();
 
-  auto payloadShape = payloadTy.getShape();
-  if (payloadShape.size() > 1)
-    assert(
-        payloadShape[0] == subgroupSize &&
-        "Expected the first dimension of 2D tensor descriptor to be equal to "
-        "subgroup size.");
-
-  SmallVector<int> instData{subgroupSize};
-  if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1)
-    instData.push_back(chunkSize);
-  else if (auto dstTdescTy =
-               dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) {
-    if (dstTdescTy.getChunkSizeAsInt() > 1)
-      instData.push_back(chunkSize);
-  }
-
-  LayoutInfo payloadLayout;
-
   if (auto layout = storeScatter.getLayoutAttr()) {
     payloadLayout = LayoutInfo(layout);
   } else {
-    payloadLayout = getDefaultSIMTLayoutInfo(
-        payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
-        /*scattered=*/true);
+    if (layoutKind == LayoutKind::InstData) {
+      SmallVector<int> instData{subgroupSize};
+      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
+          chunkSize > 1)
+        instData.push_back(chunkSize);
+      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
+                   storeScatter.getDestType())) {
+        if (dstTdescTy.getChunkSizeAsInt() > 1)
+          instData.push_back(chunkSize);
+      }
+      payloadLayout = LayoutInfo(
+          xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
+    } else {
+      auto payloadShape = payloadTy.getShape();
+      if (payloadShape.size() > 1)
+        assert(payloadShape[0] == subgroupSize &&
+               "Expected the first dimension of 2D tensor descriptor to be "
+               "equal to "
+               "subgroup size.");
+      payloadLayout = getDefaultSIMTLayoutInfo(
+          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
+          /*scattered=*/true);
+    }
   }
 
   LayoutInfo maskLayout =
@@ -916,10 +958,10 @@ class RunLayoutInfoPropagation {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
 
-  RunLayoutInfoPropagation(Operation *op) : target(op) {
+  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
     SymbolTableCollection symbolTable;
     loadBaselineAnalyses(solver);
-    solver.load<LayoutInfoPropagation>(symbolTable);
+    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
     (void)solver.initializeAndRun(op);
   }
 
@@ -1159,7 +1201,19 @@ struct XeGPUPropagateLayoutPass final
 } // namespace
 
 void XeGPUPropagateLayoutPass::runOnOperation() {
-  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
+  LayoutKind layoutKind;
+  if (this->layoutKind == "lane")
+    layoutKind = LayoutKind::Lane;
+  else if (this->layoutKind == "inst")
+    layoutKind = LayoutKind::InstData;
+  else {
+    signalPassFailure();
+    getOperation()->emitError("Unsupported layout kind option: " +
+                              this->layoutKind);
+    return;
+  }
+  RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
+  // auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
   // Print the analysis result and exit. (for debugging purposes)
   if (printOnly) {
     auto &os = llvm::outs();
@@ -1173,8 +1227,6 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
       return {};
     xegpu::DistributeLayoutAttr layoutAttr =
         cast<xegpu::DistributeLayoutAttr>(layout.get());
-    if (this->layoutKind == "lane")
-      layoutAttr = layoutAttr.dropInstData();
     if (layout.isSliceLayout())
       return cast<xegpu::SliceAttr>(layoutAttr);
     return cast<xegpu::LayoutAttr>(layoutAttr);
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 58461b8be52c4..c31ef323a94d2 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -2,17 +2,17 @@
 
 // CHECK-LABEL: func.func @dpas_f16(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32>
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
-// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]]  {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]]  {layout_result_0 = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x16xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16]>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]]  {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} :
+// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/166941


More information about the Mlir-commits mailing list