[all-commits] [llvm/llvm-project] 7bcdec: [MLIR][XeGPU] Enable WG-level mxfp GEMM via gener...
Jianhui Li via All-commits
all-commits at lists.llvm.org
Wed Jun 10 12:18:00 PDT 2026
Branch: refs/heads/main
Home: https://github.com/llvm/llvm-project
Commit: 7bcdec0b48ee8f64c16d1c13d7940073c3cb03a7
https://github.com/llvm/llvm-project/commit/7bcdec0b48ee8f64c16d1c13d7940073c3cb03a7
Author: Jianhui Li <jian.hui.li at intel.com>
Date: 2026-06-10 (Wed, 10 Jun 2026)
Changed paths:
M mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
M mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
M mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
M mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
M mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
M mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
M mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
M mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
M mlir/test/Dialect/XeGPU/propagate-layout.mlir
A mlir/test/Integration/Dialect/XeGPU/WG/simple_mxfp_gemm_dequantizeB_F4.mlir
A mlir/test/Integration/Dialect/XeGPU/WG/simple_mxfp_gemm_quantizeA_F4.mlir
Log Message:
-----------
[MLIR][XeGPU] Enable WG-level mxfp GEMM via generalized shape_cast collapse inference (#201496)
Summary
Bringing up two WG-level mxfp GEMM integration tests (simple_mxfp_gemm_quantizeA_F4 and
simple_mxfp_gemm_dequantizeB_F4) exposed several gaps in the XeGPU layout-propagation and unroll
paths that previously kept them from compiling end-to-end. This PR lands those two tests as the
motivating workloads, plus the supporting changes:
1. A generalized shape_cast collapse layout inference. It is required because the mxfp lowering
   inserts vector.shape_cast ops that collapse multiple src dims into a single dst dim, with
   non-trivial sg / lane layouts spanning them. The previous matchCollapseToInnermostDim only
   covered the narrow […] → [N] / [1, N] shape and could not infer correct source layouts for
   these patterns.
2. A small primitive (expandDims) on the layout attribute, so the new code stays as concise as the
   code that uses collapseDims.
3. Bug fixes uncovered while running these workloads end-to-end (transpose layout check,
   layout-attr unroll cast crash, drop-dims order pollution).
What's in this PR
Motivating integration tests (the driving force)
- mlir/test/Integration/Dialect/XeGPU/WG/simple_mxfp_gemm_quantizeA_F4.mlir
- mlir/test/Integration/Dialect/XeGPU/WG/simple_mxfp_gemm_dequantizeB_F4.mlir
These exercise WG-level GEMM with mxfp quantization (BF16 × F4 paths). They depend on every other
change in this PR; without those changes, layout propagation crashes or yields conflicting layouts
on the inserted vector.shape_cast and xegpu.load_matrix / xegpu.store_matrix ops.
Generalized shape_cast collapse inference
- New utility xegpu::matchDimCollapse(srcShape, resShape, collapseDims) in XeGPUUtils.{h,cpp}: the
  dual of matchSplitDimExpansion, returning per-dst-dim groups of src indices.
- inferShapeCastSourceLayout use case 3 now handles arbitrary collapse patterns:
  - sg_layout / lane_layout spread outer-to-inner, so each subgroup / lane owns a contiguous run
    in the collapsed dst dim's row-major linearization.
  - sg_data / lane_data / inst_data fill innermost-first, with per-dim caps from any layout
    already placed.
  - inst_data is seeded from lane_layout * lane_data per dim; the remaining factor spreads
    innermost-first.
  - order is rewritten by walking dst order fastest-first and emitting each group's src dims
    innermost-fastest.
- Net effect for the mxfp tests: no data movement across sg / lane boundaries when shape_cast
  collapses dims.
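To make the "per-dst-dim groups of src indices" result concrete, here is a minimal standalone sketch of a collapse matcher. It is not the code from this PR: matchDimCollapseSketch is a hypothetical name, it uses plain std::vector instead of MLIR types, and it assumes static, positive shapes and a greedy left-to-right grouping (unit dims in particular could be grouped differently by the real utility).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for xegpu::matchDimCollapse; not the code from this PR.
// On success, collapseDims[d] lists the src dims that collapse into dst dim d.
bool matchDimCollapseSketch(const std::vector<int64_t> &srcShape,
                            const std::vector<int64_t> &resShape,
                            std::vector<std::vector<int64_t>> &collapseDims) {
  collapseDims.clear();
  std::size_t src = 0;
  for (int64_t target : resShape) {
    assert(target > 0 && "sketch handles static, positive dims only");
    std::vector<int64_t> group;
    int64_t prod = 1;
    // Greedily absorb src dims until their product reaches the dst dim.
    while (prod < target && src < srcShape.size()) {
      prod *= srcShape[src];
      group.push_back(static_cast<int64_t>(src));
      ++src;
    }
    if (prod != target)
      return false; // overshoot, or ran out of src dims: not a pure collapse
    collapseDims.push_back(group); // empty group: unit dst dim with no src dim
  }
  // Leftover src dims are accepted only if they are unit dims; they join the
  // last group.
  for (; src < srcShape.size(); ++src) {
    if (srcShape[src] != 1 || collapseDims.empty())
      return false;
    collapseDims.back().push_back(static_cast<int64_t>(src));
  }
  return true;
}
```

Under these assumptions, collapsing [2, 3, 4] into [6, 4] yields the groups {0, 1} and {2}, while [8] into [1, 8] yields an empty group for the unit dst dim followed by {0}. A reshape such as [2, 3, 4] into [4, 6] is rejected because no run of src dims multiplies to 4.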
Refactor: expandDims interface method
- Added expandDims(int64_t dim, ArrayRef<int64_t> targetShape) to the DistributeLayoutAttr
  interface, with implementations on both LayoutAttr and SliceAttr. It's the rank-increasing dual
  of collapseDims and bakes in the distribution policy above.
- inferShapeCastSourceLayout use case 3 now mirrors use case 2's per-group loop:

      auto srcLayout = resLayout;
      for (dst dim in reverse) {
        if (group.empty())
          srcLayout = srcLayout.dropDims({dstIdx});
        else if (group.size() > 1)
          srcLayout = srcLayout.expandDims(dstIdx, targetShape);
      }
      return srcLayout;

- Replaces ~190 lines of inlined per-field distribution logic with a handful of lines.
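The distribution policy that expandDims is described as baking in can be illustrated on a single collapsed dim. The sketch below is an assumption-laden illustration, not the PR's implementation: spreadOuterToInner, fillInnermostFirst, and gcd64 are hypothetical helpers, and taking the gcd at each dim is only one plausible way to split a factor.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Small helper so the sketch does not depend on C++17 std::gcd.
static int64_t gcd64(int64_t a, int64_t b) {
  while (b != 0) {
    int64_t t = a % b;
    a = b;
    b = t;
  }
  return a;
}

// Hypothetical: split a layout factor (e.g. sg_layout of the collapsed dim)
// across the expanded dims, outermost dim first. Returns false if the factor
// cannot be fully placed.
bool spreadOuterToInner(int64_t factor, const std::vector<int64_t> &shape,
                        std::vector<int64_t> &out) {
  assert(factor > 0 && "factor must be positive");
  out.assign(shape.size(), 1);
  for (std::size_t i = 0; i < shape.size(); ++i) {
    int64_t take = gcd64(factor, shape[i]);
    out[i] = take;
    factor /= take;
  }
  return factor == 1;
}

// Hypothetical: split a data factor (e.g. sg_data) across the expanded dims,
// innermost dim first, capped per dim by what the placed layout leaves over.
bool fillInnermostFirst(int64_t factor, const std::vector<int64_t> &shape,
                        const std::vector<int64_t> &placed,
                        std::vector<int64_t> &out) {
  assert(factor > 0 && shape.size() == placed.size());
  out.assign(shape.size(), 1);
  for (std::size_t i = shape.size(); i-- > 0;) {
    assert(placed[i] > 0 && "placed layout entries must be positive");
    int64_t cap = shape[i] / placed[i]; // room left in this dim
    int64_t take = gcd64(factor, cap);
    out[i] = take;
    factor /= take;
  }
  return factor == 1;
}
```

For a dst dim of 64 elements expanded to [4, 16], a layout factor of 8 spreads to [4, 2], and a data factor of 8 then fills to [1, 8]. In that evenly divisible case each subgroup covers 8 consecutive elements of one row, which is a contiguous run of the row-major linearization, matching the intent stated above.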
Bug fixes uncovered while bringing up the integration tests
- LayoutAttr::isTransposeOf: corrected the per-dim check to match vector.transpose semantics
  (dst[i] = src[perm[i]]); the old comparison indexed src and dst inversely.
- LayoutAttr::dropDims: stop synthesizing a default [rank-1,...,0] order when the input had none;
  that synthesized order tripped collapseDims's adjacency check downstream.
- UnrollLoadMatrixOp / UnrollStoreMatrixOp: stop assuming the op's layout is always a LayoutAttr.
  Use DistributeLayoutAttr and guard dropInstData() so SliceAttr / missing-layout inputs no longer
  crash unrolling.
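The isTransposeOf fix comes down to which side the permutation indexes. A minimal sketch of the corrected per-dim check, applied to one plain integer field rather than a real LayoutAttr (isTransposeOfSketch is a hypothetical name), could look like this:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-field check following vector.transpose semantics:
// dst[i] must equal src[perm[i]] for every dim i.
bool isTransposeOfSketch(const std::vector<int64_t> &src,
                         const std::vector<int64_t> &dst,
                         const std::vector<int64_t> &perm) {
  assert(perm.size() == src.size() && "perm needs one entry per src dim");
  if (dst.size() != src.size())
    return false;
  const int64_t rank = static_cast<int64_t>(src.size());
  for (std::size_t i = 0; i < perm.size(); ++i) {
    int64_t p = perm[i];
    if (p < 0 || p >= rank)
      return false; // not a valid permutation entry
    if (dst[i] != src[static_cast<std::size_t>(p)])
      return false;
  }
  return true;
}
```

With src = [2, 4, 8] and perm = [1, 2, 0], the check accepts dst = [4, 8, 2] and rejects [8, 2, 4], which is what the inverse indexing (dst[perm[i]] = src[i]) would produce. For a rank-2 swap, perm = [1, 0] is its own inverse, so both conventions agree and only permutations of rank 3 or higher can tell them apart.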
Unit-test coverage
- New shape_cast collapse coverage in both propagate-layout-subgroup.mlir and
  propagate-layout-inst-data.mlir for: plain innermost collapse, layout spill across multiple src
  dims, and multi-group collapse.
- Updated one lane_layout expectation in propagate-layout.mlir to reflect the generalized
  distribution.
Files changed
- mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td: interface + class declarations for expandDims
- mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h: declaration for matchDimCollapse
- mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp: expandDims impls, dropDims order fix, isTransposeOf fix
- mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp: refactored use case 3
- mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp: load/store_matrix unroll hardening
- mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp: matchDimCollapse impl
- mlir/test/Dialect/XeGPU/propagate-layout-{subgroup,inst-data,}.mlir: new tests / updated expectation
- mlir/test/Integration/Dialect/XeGPU/WG/simple_mxfp_gemm_{dequantizeB_F4,quantizeA_F4}.mlir: new
  motivating integration tests
---------
Co-authored-by: Claude Opus 4.7 (1M context) <noreply at anthropic.com>