[Mlir-commits] [mlir] [MLIR][XeGPU] Remove use-by-broadcast-only restriction for ShapeCast op in Wg-to-Sg distribution pass (PR #193640)

Jianhui Li llvmlistbot at llvm.org
Wed Apr 22 18:00:18 PDT 2026


https://github.com/Jianhui-Li created https://github.com/llvm/llvm-project/pull/193640

The WgToSgVectorShapeCastOp pattern previously required that vector.shape_cast operations expanding unit dimensions be used only by vector.broadcast operations. This constraint is no longer necessary after the recent refactoring.

From 30a93d72c689f98a31e6c5bdb7053067c87cda00 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 23 Apr 2026 00:32:19 +0000
Subject: [PATCH] [MLIR][XeGPU] Remove restrictive broadcast-only constraint in
 WgToSgVectorShapeCastOp

The WgToSgVectorShapeCastOp pattern previously required that vector.shape_cast
operations expanding unit dimensions be used only by vector.broadcast
operations. This constraint was overly restrictive and caused legitimate IR
patterns to fail during the xegpu-wg-to-sg-distribute pass.

The constraint rejected shape_cast operations used directly by elementwise
operations (such as arith.addi), even though such usage is semantically valid.
When a shape_cast only expands unit dimensions, the sg_data for those
dimensions is 1, so the data distribution does not change. Therefore,
elementwise operations can safely consume these shape_cast results.

This patch removes the usedByBroadcastOp check, allowing shape_cast operations
to be used by any valid consumer operation, including elementwise ops.

Added a regression test @shape_cast_used_by_elementwise to ensure this pattern
continues to work correctly.

Fixes: vector.shape_cast -> arith.addi -> xegpu.store pattern
---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 11 ---------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 24 +++++++++++++++++++
 2 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index e083507173d31..9114e37b0e42b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1161,17 +1161,6 @@ struct WgToSgVectorShapeCastOp
       xegpu::DistributeLayoutAttr sourceLayout =
           xegpu::getTemporaryLayout(op->getOpOperand(0));
 
-      auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
-        return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
-          return isa<vector::BroadcastOp>(user);
-        });
-      };
-
-      if (!usedByBroadcastOp(op))
-        return rewriter.notifyMatchFailure(
-            op, "ShapeCast ops that expand unit dimensions and are used by "
-                "non-broadcast operations are not supported.");
-
       if (!sourceLayout.isSliceOf(layout))
         return rewriter.notifyMatchFailure(
             op, "The ShapeCast op only expands dimensions, the input layout "
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 3bc43b780ade2..c3eb59adee2a6 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -960,4 +960,28 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: @shape_cast_used_by_elementwise
+  gpu.func @shape_cast_used_by_elementwise(%dst: memref<1x1x16xf32>) {
+    // Regression test: shape_cast expanding unit dimensions can be used by elementwise ops
+    // This previously failed with "ShapeCast ops that expand unit dimensions and are used by
+    // non-broadcast operations are not supported."
+
+    // CHECK: vector.step : vector<16xindex>
+    // CHECK: vector.shape_cast {{.*}} : vector<16xindex> to vector<1x1x16xindex>
+    // CHECK: arith.addi {{.*}} : vector<1x1x16xindex>
+    // CHECK: xegpu.store {{.*}} : vector<1x1x16xf32>, i64, vector<1x1x16xindex>, vector<1x1x16xi1>
+    %step = vector.step : vector<16xindex>
+    %shape_cast = vector.shape_cast %step : vector<16xindex> to vector<1x1x16xindex>
+    %cst = arith.constant dense<10> : vector<1x1x16xindex>
+    %add = arith.addi %shape_cast, %cst : vector<1x1x16xindex>
+
+    %cst_val = arith.constant dense<1.0> : vector<1x1x16xf32>
+    %intptr = memref.extract_aligned_pointer_as_index %dst : memref<1x1x16xf32> -> index
+    %ptr = arith.index_cast %intptr : index to i64
+    %mask = arith.constant dense<true> : vector<1x1x16xi1>
+
+    xegpu.store %cst_val, %ptr[%add], %mask {layout = #xegpu.layout<sg_layout = [1, 1, 1], sg_data = [1, 1, 16]>} : vector<1x1x16xf32>, i64, vector<1x1x16xindex>, vector<1x1x16xi1>
+    gpu.return
+  }
+
 }


