[Mlir-commits] [mlir] [MLIR][XeGPU] Decouple `inst_data` and `lane_layout` in propagation (PR #166941)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 7 06:34:07 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Artem Kroviakov (akroviakov)
<details>
<summary>Changes</summary>
Currently, layout propagation has a circular dependency.
Blocking needs `inst_data`, but propagating `inst_data` also attaches `lane_layout` and `lane_data`, and those two fields themselves depend on blocking having already been performed.
This patch decouples these layout fields during propagation and makes propagation a multi-step process, for example:
```
-xegpu-propagate-layout="layout-kind=inst" -xegpu-blocking -xegpu-propagate-layout="layout-kind=lane" -xegpu-subgroup-distribute
```
---
Patch is 69.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166941.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+11)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+120-68)
- (modified) mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir (+21-21)
- (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+1-579)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 9c35c07a7e587..3f27d690f949b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -379,6 +379,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
);
let builders = [
+ AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data),
+ [{
+ auto sg_layout = DenseI32ArrayAttr();
+ auto sg_data = DenseI32ArrayAttr();
+ auto order = DenseI32ArrayAttr();
+ auto lane_layout = DenseI32ArrayAttr();
+ auto lane_data = DenseI32ArrayAttr();
+ return $_get($_ctxt, sg_layout, sg_data,
+ DenseI32ArrayAttr::get($_ctxt, inst_data),
+ lane_layout, lane_data, order);
+ }]>,
AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data,
"llvm::ArrayRef<int32_t>": $lane_layout,
"llvm::ArrayRef<int32_t>": $lane_data),
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4e1a539771d2f..c00b08e0ca37f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -53,6 +53,8 @@ using namespace mlir::dataflow;
namespace {
+enum class LayoutKind { Lane, InstData };
+
//===----------------------------------------------------------------------===//
// LayoutInfo
//===----------------------------------------------------------------------===//
@@ -99,7 +101,8 @@ struct LayoutInfo {
bool isAssigned() const { return storage != nullptr; }
- LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
+ LayoutInfo transpose(ArrayRef<int64_t> permutation,
+ LayoutKind layoutKind) const;
SmallVector<int> getLaneLayout() const;
@@ -167,7 +170,8 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
}
/// Construct a new layout with the transposed lane layout and lane data.
-LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
+LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation,
+ LayoutKind layoutKind) const {
if (!isAssigned())
return {};
// Check if the permutation is valid.
@@ -186,12 +190,20 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
SmallVector<int32_t> laneData;
SmallVector<int32_t> instData;
for (int64_t idx : permutation) {
- laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
- laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
- instData.push_back(static_cast<int32_t>(getInstData()[idx]));
+ if (layoutKind == LayoutKind::Lane) {
+ laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
+ laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+ } else if (layoutKind == LayoutKind::InstData)
+ instData.push_back(static_cast<int32_t>(getInstData()[idx]));
+ }
+ xegpu::LayoutAttr layoutAttr;
+ if (layoutKind == LayoutKind::Lane) {
+ layoutAttr =
+ xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
+ } else if (layoutKind == LayoutKind::InstData) {
+ layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
}
- return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
- laneLayout, laneData));
+ return LayoutInfo(layoutAttr);
}
//===----------------------------------------------------------------------===//
@@ -213,15 +225,14 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
unsigned rank,
- const xegpu::uArch::uArch *uArch,
- ArrayRef<int> instData) {
+ const xegpu::uArch::uArch *uArch) {
assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
if (rank == 1) {
return LayoutInfo(
- xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
+ xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
}
- return LayoutInfo(xegpu::LayoutAttr::get(
- ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
+ return LayoutInfo(
+ xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
}
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
@@ -236,7 +247,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
const xegpu::uArch::uArch *uArch,
- ArrayRef<int> instData,
unsigned packingSize,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
@@ -247,16 +257,16 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (vectorTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
+ return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
if (isScattered) {
- return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
{uArch->getSubgroupSize(), 1},
{1, packingFactor}));
}
- return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
{1, uArch->getSubgroupSize()},
{1, packingFactor}));
}
@@ -275,7 +285,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (tdescTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData);
+ return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
int subgroupSize = uArch->getSubgroupSize();
@@ -298,7 +308,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
static LayoutInfo
getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
const xegpu::uArch::uArch *uArch,
- ArrayRef<int> instData, unsigned packingSize) {
+ unsigned packingSize) {
Type elementTy = vectorTy.getElementType();
assert(elementTy.isIntOrFloat() &&
"Expected int or float type in DPAS operands");
@@ -310,10 +320,10 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
{static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
1});
return LayoutInfo(
- xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data));
+ xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
}
// Otherwise, return the default layout for the vector type.
- return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize);
+ return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
}
//===----------------------------------------------------------------------===//
@@ -328,6 +338,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
class LayoutInfoPropagation
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
+ LayoutKind layoutKind;
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -380,8 +391,10 @@ class LayoutInfoPropagation
public:
LayoutInfoPropagation(DataFlowSolver &solver,
- SymbolTableCollection &symbolTable)
- : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+ SymbolTableCollection &symbolTable,
+ LayoutKind layoutKind)
+ : SparseBackwardDataFlowAnalysis(solver, symbolTable),
+ layoutKind(layoutKind) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
LogicalResult
@@ -627,14 +640,24 @@ void LayoutInfoPropagation::visitDpasOp(
SmallVector<int> instDataA = {maxALen, subgroupSize};
SmallVector<int> instDataB = {subgroupSize, maxBLen};
- propagateIfChanged(operands[0],
- operands[0]->meet(getSIMTLayoutInfoForDPASOperand(
- aTy, 0, uArch, instDataA,
- uArchInstruction->getPackedFormatBitSizeA())));
- propagateIfChanged(operands[1],
- operands[1]->meet(getSIMTLayoutInfoForDPASOperand(
- bTy, 1, uArch, instDataB,
- uArchInstruction->getPackedFormatBitSizeB())));
+ LayoutInfo dpasALayout;
+ LayoutInfo dpasBLayout;
+ LayoutInfo dpasCLayout;
+
+ if (layoutKind == LayoutKind::InstData) {
+ dpasALayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
+ dpasBLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
+ } else {
+ dpasALayout = getSIMTLayoutInfoForDPASOperand(
+ aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
+ dpasBLayout = getSIMTLayoutInfoForDPASOperand(
+ bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ }
+
+ propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
+ propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
if (operands.size() > 2) {
VectorType cTy = dpas.getAccType();
const unsigned dataCLen = bTy.getShape().back();
@@ -645,10 +668,15 @@ void LayoutInfoPropagation::visitDpasOp(
dpas.emitWarning(
"No suitable instruction multiple found for the given shape.");
SmallVector<int> instDataC = {maxALen, maxCLen};
- propagateIfChanged(operands[2],
- operands[2]->meet(getSIMTLayoutInfoForDPASOperand(
- cTy, 2, uArch, instDataC,
- uArchInstruction->getPackedFormatBitSizeB())));
+
+ if (layoutKind == LayoutKind::InstData)
+ dpasCLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
+ else
+ dpasCLayout = getSIMTLayoutInfoForDPASOperand(
+ cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+
+ propagateIfChanged(operands[2], operands[2]->meet(dpasCLayout));
}
}
@@ -685,9 +713,15 @@ void LayoutInfoPropagation::visitStoreNdOp(
"No suitable instruction multiple found for the given shape.");
instData = {instHeight, instWidth};
}
- LayoutInfo storeLayout =
- getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData,
- uArchInstruction->getPackedFormatBitSize());
+
+ LayoutInfo storeLayout;
+ if (layoutKind == LayoutKind::InstData)
+ storeLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
+ else
+ storeLayout =
+ getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
+ uArchInstruction->getPackedFormatBitSize());
// Both operands should have the same layout
for (LayoutInfoLattice *operand : operands)
propagateIfChanged(operand, operand->meet(storeLayout));
@@ -709,7 +743,7 @@ void LayoutInfoPropagation::visitLoadNdOp(
if (auto transpose = load.getTranspose()) {
load.emitWarning("Transpose effect is not expected for LoadNdOp at "
"LayoutInfoPropagation stage.");
- tensorDescLayout = valueLayout.transpose(transpose.value());
+ tensorDescLayout = valueLayout.transpose(transpose.value(), layoutKind);
}
// Propagate the new layout to the tensor descriptor operand.
propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
@@ -724,7 +758,8 @@ void LayoutInfoPropagation::visitTransposeOp(
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
return;
- LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
+ LayoutInfo newLayout =
+ resultLayout.transpose(transpose.getPermutation(), layoutKind);
// Propagate the new layout to the vector operand.
propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
@@ -818,9 +853,13 @@ void LayoutInfoPropagation::visitLoadGatherOp(
if (srcTdescTy.getChunkSizeAsInt() > 1)
instData.push_back(chunkSize);
}
- LayoutInfo layout = getDefaultSIMTLayoutInfo(
- payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
- /*scattered*/ true);
+ LayoutInfo layout;
+ if (layoutKind == LayoutKind::InstData)
+ layout = LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
+ else
+ layout = getDefaultSIMTLayoutInfo(payloadTy, uArch,
+ uArch->getGeneralPackedFormatBitSize(),
+ /*scattered*/ true);
// Mask operand should have 1D default layout.
LayoutInfo maskLayout =
@@ -864,33 +903,36 @@ void LayoutInfoPropagation::visitStoreScatterOp(
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
+ LayoutInfo payloadLayout;
auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
const int subgroupSize = uArch->getSubgroupSize();
- auto payloadShape = payloadTy.getShape();
- if (payloadShape.size() > 1)
- assert(
- payloadShape[0] == subgroupSize &&
- "Expected the first dimension of 2D tensor descriptor to be equal to "
- "subgroup size.");
-
- SmallVector<int> instData{subgroupSize};
- if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1)
- instData.push_back(chunkSize);
- else if (auto dstTdescTy =
- dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) {
- if (dstTdescTy.getChunkSizeAsInt() > 1)
- instData.push_back(chunkSize);
- }
-
- LayoutInfo payloadLayout;
-
if (auto layout = storeScatter.getLayoutAttr()) {
payloadLayout = LayoutInfo(layout);
} else {
- payloadLayout = getDefaultSIMTLayoutInfo(
- payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
- /*scattered=*/true);
+ if (layoutKind == LayoutKind::InstData) {
+ SmallVector<int> instData{subgroupSize};
+ if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
+ chunkSize > 1)
+ instData.push_back(chunkSize);
+ else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
+ storeScatter.getDestType())) {
+ if (dstTdescTy.getChunkSizeAsInt() > 1)
+ instData.push_back(chunkSize);
+ }
+ payloadLayout = LayoutInfo(
+ xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
+ } else {
+ auto payloadShape = payloadTy.getShape();
+ if (payloadShape.size() > 1)
+ assert(payloadShape[0] == subgroupSize &&
+ "Expected the first dimension of 2D tensor descriptor to be "
+ "equal to "
+ "subgroup size.");
+ payloadLayout = getDefaultSIMTLayoutInfo(
+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
+ /*scattered=*/true);
+ }
}
LayoutInfo maskLayout =
@@ -916,10 +958,10 @@ class RunLayoutInfoPropagation {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
- RunLayoutInfoPropagation(Operation *op) : target(op) {
+ RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
- solver.load<LayoutInfoPropagation>(symbolTable);
+ solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
(void)solver.initializeAndRun(op);
}
@@ -1159,7 +1201,19 @@ struct XeGPUPropagateLayoutPass final
} // namespace
void XeGPUPropagateLayoutPass::runOnOperation() {
- auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
+ LayoutKind layoutKind;
+ if (this->layoutKind == "lane")
+ layoutKind = LayoutKind::Lane;
+ else if (this->layoutKind == "inst")
+ layoutKind = LayoutKind::InstData;
+ else {
+ signalPassFailure();
+ getOperation()->emitError("Unsupported layout kind option: " +
+ this->layoutKind);
+ return;
+ }
+ RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
+ // auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
// Print the analysis result and exit. (for debugging purposes)
if (printOnly) {
auto &os = llvm::outs();
@@ -1173,8 +1227,6 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
return {};
xegpu::DistributeLayoutAttr layoutAttr =
cast<xegpu::DistributeLayoutAttr>(layout.get());
- if (this->layoutKind == "lane")
- layoutAttr = layoutAttr.dropInstData();
if (layout.isSliceLayout())
return cast<xegpu::SliceAttr>(layoutAttr);
return cast<xegpu::LayoutAttr>(layoutAttr);
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 58461b8be52c4..c31ef323a94d2 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -2,17 +2,17 @@
// CHECK-LABEL: func.func @dpas_f16(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32>
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
-// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] {layout_result_0 = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x16xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16]>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} :
+// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/166941
More information about the Mlir-commits
mailing list