[Mlir-commits] [mlir] [MLIR][XeGPU] Remove use-by-broadcast-only restriction for ShapeCast op in Wg-to-Sg distribution pass (PR #193640)
Jianhui Li
llvmlistbot at llvm.org
Wed Apr 22 18:00:18 PDT 2026
https://github.com/Jianhui-Li created https://github.com/llvm/llvm-project/pull/193640
The WgToSgVectorShapeCastOp pattern previously required that vector.shape_cast operations expanding unit dimensions could only be used by vector.broadcast operations. This constraint is no longer necessary after the recent refactoring.
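The reasoning in the commit message is that an inserted unit dimension always has sg_data 1, so the shape_cast does not change how data is distributed across subgroups. The sketch below illustrates this in plain C++. The helper names `findInsertedUnitDims` and `expandSgData` are hypothetical and not part of the MLIR/XeGPU API. Shapes are modeled as `std::vector<int64_t>`. The sketch assumes the source layout is the slice of the result layout over the inserted dims, which is what the remaining `sourceLayout.isSliceOf(layout)` check in the pattern guards.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper, not an MLIR/XeGPU API. If `result` is `source` with
// only unit (size-1) dimensions inserted, returns the positions in `result`
// of those inserted dims; otherwise returns std::nullopt.
std::optional<std::vector<std::size_t>>
findInsertedUnitDims(const std::vector<int64_t> &source,
                     const std::vector<int64_t> &result) {
  std::vector<std::size_t> inserted;
  std::size_t s = 0;
  for (std::size_t r = 0; r < result.size(); ++r) {
    if (s < source.size() && result[r] == source[s]) {
      // Dim carried over from the source. Greedy matching is safe: only equal
      // sizes match, and a matched 1 is interchangeable with an inserted 1.
      ++s;
    } else if (result[r] == 1) {
      inserted.push_back(r); // New unit dim introduced by the shape_cast.
    } else {
      return std::nullopt; // Non-unit dim that does not come from the source.
    }
  }
  if (s != source.size())
    return std::nullopt; // Some source dims were dropped or merged.
  return inserted;
}

// Hypothetical helper: builds the result's per-subgroup (sg_data) shape by
// inserting 1 into the source's sg_data at every inserted dim.
// Assumes srcSgData.size() + insertedDims.size() == resultRank.
std::vector<int64_t>
expandSgData(const std::vector<int64_t> &srcSgData,
             const std::vector<std::size_t> &insertedDims,
             std::size_t resultRank) {
  std::vector<int64_t> out;
  out.reserve(resultRank);
  std::size_t s = 0, k = 0;
  for (std::size_t r = 0; r < resultRank; ++r) {
    if (k < insertedDims.size() && insertedDims[k] == r) {
      out.push_back(1); // Inserted dim: each subgroup holds exactly 1 element.
      ++k;
    } else {
      out.push_back(srcSgData[s++]); // Existing dim: distribution unchanged.
    }
  }
  return out;
}
```

With the shapes from `@shape_cast_used_by_elementwise`, `findInsertedUnitDims({16}, {1, 1, 16})` yields `{0, 1}`, and expanding the source sg_data `[16]` gives `[1, 1, 16]`, matching the store layout. Because the inserted dims carry only 1s, each subgroup's shape_cast is purely local, which is why the kind of consumer (broadcast or elementwise) does not matter.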
From 30a93d72c689f98a31e6c5bdb7053067c87cda00 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 23 Apr 2026 00:32:19 +0000
Subject: [PATCH] [MLIR][XeGPU] Remove restrictive broadcast-only constraint in
WgToSgVectorShapeCastOp
The WgToSgVectorShapeCastOp pattern previously required that vector.shape_cast
operations expanding unit dimensions could only be used by vector.broadcast
operations. This constraint was overly restrictive and caused legitimate IR
patterns to fail during the xegpu-wg-to-sg-distribute pass.
The constraint rejected shape_cast operations used directly by elementwise
operations (such as arith.addi), even though such usage is semantically valid.
When a shape_cast only expands unit dimensions, the sg_data for those
dimensions is 1, so the distribution of data across subgroups does not
change. Therefore, elementwise operations can safely consume these shape_cast
results.
This patch removes the usedByBroadcastOp check, allowing shape_cast operations
to be used by any valid consumer operation, including elementwise ops.
Added a regression test @shape_cast_used_by_elementwise to ensure this pattern
continues to work correctly.
Fixes: vector.shape_cast -> arith.addi -> xegpu.store pattern
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 11 ---------
.../XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 24 +++++++++++++++++++
2 files changed, 24 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index e083507173d31..9114e37b0e42b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1161,17 +1161,6 @@ struct WgToSgVectorShapeCastOp
xegpu::DistributeLayoutAttr sourceLayout =
xegpu::getTemporaryLayout(op->getOpOperand(0));
- auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
- return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
- return isa<vector::BroadcastOp>(user);
- });
- };
-
- if (!usedByBroadcastOp(op))
- return rewriter.notifyMatchFailure(
- op, "ShapeCast ops that expand unit dimensions and are used by "
- "non-broadcast operations are not supported.");
-
if (!sourceLayout.isSliceOf(layout))
return rewriter.notifyMatchFailure(
op, "The ShapeCast op only expands dimensions, the input layout "
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 3bc43b780ade2..c3eb59adee2a6 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -960,4 +960,28 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: @shape_cast_used_by_elementwise
+ gpu.func @shape_cast_used_by_elementwise(%dst: memref<1x1x16xf32>) {
+ // Regression test: shape_cast expanding unit dimensions can be used by elementwise ops
+ // This previously failed with "ShapeCast ops that expand unit dimensions and are used by
+ // non-broadcast operations are not supported."
+
+ // CHECK: vector.step : vector<16xindex>
+ // CHECK: vector.shape_cast {{.*}} : vector<16xindex> to vector<1x1x16xindex>
+ // CHECK: arith.addi {{.*}} : vector<1x1x16xindex>
+ // CHECK: xegpu.store {{.*}} : vector<1x1x16xf32>, i64, vector<1x1x16xindex>, vector<1x1x16xi1>
+ %step = vector.step : vector<16xindex>
+ %shape_cast = vector.shape_cast %step : vector<16xindex> to vector<1x1x16xindex>
+ %cst = arith.constant dense<10> : vector<1x1x16xindex>
+ %add = arith.addi %shape_cast, %cst : vector<1x1x16xindex>
+
+ %cst_val = arith.constant dense<1.0> : vector<1x1x16xf32>
+ %intptr = memref.extract_aligned_pointer_as_index %dst : memref<1x1x16xf32> -> index
+ %ptr = arith.index_cast %intptr : index to i64
+ %mask = arith.constant dense<true> : vector<1x1x16xi1>
+
+ xegpu.store %cst_val, %ptr[%add], %mask {layout = #xegpu.layout<sg_layout = [1, 1, 1], sg_data = [1, 1, 16]>} : vector<1x1x16xf32>, i64, vector<1x1x16xindex>, vector<1x1x16xi1>
+ gpu.return
+ }
+
}