[Mlir-commits] [mlir] [MLIR][XeGPU] Remove leading unit dims from vector ops before unrolling (PR #165030)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 24 12:23:17 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

<details>
<summary>Changes</summary>

This PR uses the upstream `populateCastAwayVectorLeadingOneDimPatterns` to remove leading unit dims from vector ops before performing the unrolling/blocking.

---
Full diff: https://github.com/llvm/llvm-project/pull/165030.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp (+7) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (+46-13) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 2c37140ad9c76..ec5feb8bc8c4a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -344,6 +344,13 @@ void XeGPUBlockingPass::runOnOperation() {
 
   xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
 
+  // Remove leading unit dimensions from vector ops and then
+  // do the unrolling.
+  {
+    RewritePatternSet patterns(ctx);
+    vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+    (void)applyPatternsGreedily(op, std::move(patterns));
+  }
   xegpu::UnrollOptions options;
   options.setFilterConstraint(
       [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 7e742af754fbe..d61908b422194 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -715,7 +715,8 @@ gpu.module @test_kernel {
 gpu.module @test_kernel {
   // CHECK-LABEL: load_store_nd_with_offsets
   // CHECK-SAME: [[arg0:%.+]]: memref<1024x1024xf32>, [[arg1:%.+]]: memref<1024x1024xf32>, [[arg2:%.+]]: memref<1024x1024xf32>
-  // CHECK-DAG: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+  // CHECK-DAG: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+  // CHECK-DAG: [[cst_0:%.+]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
   // CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
   // CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
   // CHECK: [[tdesc_a:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
@@ -723,20 +724,27 @@ gpu.module @test_kernel {
   // CHECK: [[tdesc_c:%.+]] = xegpu.create_nd_tdesc [[arg2]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
   // CHECK: [[ld_a0:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c0]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
   // CHECK: [[ld_a1:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c16]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+  // CHECK: [[ins_a0:%.+]] = vector.insert_strided_slice [[ld_a0]], [[cst_0]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
+  // CHECK: [[ins_a1:%.+]] = vector.insert_strided_slice [[ld_a1]], [[ins_a0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
   // CHECK: [[ld_b0:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c0]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
   // CHECK: [[ld_b1:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c16]]]  : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
-  // CHECK: [[cast_a0:%.+]] = vector.shape_cast [[ld_a0]] : vector<1x16xf32> to vector<16xf32>
-  // CHECK: [[cast_b0:%.+]] = vector.shape_cast [[ld_b0]] : vector<1x16xf32> to vector<16xf32>
-  // CHECK: [[add0:%.+]] = arith.addf [[cast_a0]], [[cast_b0]] : vector<16xf32>
-  // CHECK: [[ins0:%.+]] = vector.insert_strided_slice [[add0]], [[cst]] {offsets = [0, 0], strides = [1]} : vector<16xf32> into vector<1x32xf32>
-  // CHECK: [[cast_a1:%.+]] = vector.shape_cast [[ld_a1]] : vector<1x16xf32> to vector<16xf32>
-  // CHECK: [[cast_b1:%.+]] = vector.shape_cast [[ld_b1]] : vector<1x16xf32> to vector<16xf32>
-  // CHECK: [[add1:%.+]] = arith.addf [[cast_a1]], [[cast_b1]] : vector<16xf32>
-  // CHECK: [[ins1:%.+]] = vector.insert_strided_slice [[add1]], [[ins0]] {offsets = [0, 16], strides = [1]} : vector<16xf32> into vector<1x32xf32>
-  // CHECK: [[ext0:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
-  // CHECK: [[ext1:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 16], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
-  // CHECK: xegpu.store_nd [[ext0]], [[tdesc_c]][[[c0]], [[c0]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
-  // CHECK: xegpu.store_nd [[ext1]], [[tdesc_c]][[[c0]], [[c16]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+  // CHECK: [[ins_b0:%.+]] = vector.insert_strided_slice [[ld_b0]], [[cst_0]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
+  // CHECK: [[ins_b1:%.+]] = vector.insert_strided_slice [[ld_b1]], [[ins_b0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
+  // CHECK: [[ext_a:%.+]] = vector.extract [[ins_a1]][0] : vector<32xf32> from vector<1x32xf32>
+  // CHECK: [[ext_b:%.+]] = vector.extract [[ins_b1]][0] : vector<32xf32> from vector<1x32xf32>
+  // CHECK: [[slice_a0:%.+]] = vector.extract_strided_slice [[ext_a]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+  // CHECK: [[slice_b0:%.+]] = vector.extract_strided_slice [[ext_b]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+  // CHECK: [[add0:%.+]] = arith.addf [[slice_a0]], [[slice_b0]] : vector<16xf32>
+  // CHECK: [[ins_add0:%.+]] = vector.insert_strided_slice [[add0]], [[cst]] {offsets = [0], strides = [1]} : vector<16xf32> into vector<32xf32>
+  // CHECK: [[slice_a1:%.+]] = vector.extract_strided_slice [[ext_a]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+  // CHECK: [[slice_b1:%.+]] = vector.extract_strided_slice [[ext_b]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+  // CHECK: [[add1:%.+]] = arith.addf [[slice_a1]], [[slice_b1]] : vector<16xf32>
+  // CHECK: [[ins_add1:%.+]] = vector.insert_strided_slice [[add1]], [[ins_add0]] {offsets = [16], strides = [1]} : vector<16xf32> into vector<32xf32>
+  // CHECK: [[broadcast:%.+]] = vector.broadcast [[ins_add1]] : vector<32xf32> to vector<1x32xf32>
+  // CHECK: [[ext_result0:%.+]] = vector.extract_strided_slice [[broadcast]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+  // CHECK: [[ext_result1:%.+]] = vector.extract_strided_slice [[broadcast]] {offsets = [0, 16], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+  // CHECK: xegpu.store_nd [[ext_result0]], [[tdesc_c]][[[c0]], [[c0]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+  // CHECK: xegpu.store_nd [[ext_result1]], [[tdesc_c]][[[c0]], [[c16]]]  : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
   gpu.func @load_store_nd_with_offsets(%A: memref<1024x1024xf32>, %B: memref<1024x1024xf32>, %C: memref<1024x1024xf32>) {
     %c0 = arith.constant 0 : index
 
@@ -752,3 +760,28 @@ gpu.module @test_kernel {
     gpu.return
   }
 }
+
+// -----
+#inst_data = #xegpu.layout<inst_data = [1, 1, 32]>
+gpu.module @test_kernel {
+  // CHECK-LABEL: load_add_store_leading_unit_dims
+  // CHECK-SAME: [[arg0:%.+]]: ui64, [[arg1:%.+]]: ui64, [[arg2:%.+]]: ui64
+  // CHECK: [[mask:%.+]] = arith.constant dense<true> : vector<32xi1>
+  // CHECK: [[offsets:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<32xindex>
+  // CHECK: [[a:%.+]] = xegpu.load [[arg0]][[[offsets]]], [[mask]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+  // CHECK: [[b:%.+]] = xegpu.load [[arg1]][[[offsets]]], [[mask]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+  // CHECK: [[add:%.+]] = arith.addf [[a]], [[b]] : vector<32xf32>
+  // CHECK: xegpu.store [[add]], [[arg2]][[[offsets]]], [[mask]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1>
+  gpu.func @load_add_store_leading_unit_dims(%A: ui64, %B: ui64, %C: ui64) {
+    %cst = arith.constant {layout_result_0 = #inst_data} dense<[
+      [[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120,
+        128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]]
+    ]> : vector<1x1x32xindex>
+    %mask = arith.constant {layout_result_0 = #inst_data} dense<true> : vector<1x1x32xi1>
+    %a = xegpu.load %A[%cst], %mask {chunk_size = 1, layout_result_0 = #inst_data, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+    %b = xegpu.load %B[%cst], %mask {chunk_size = 1, layout_result_0 = #inst_data, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+    %addf = arith.addf %a, %b {layout_result_0 = #inst_data} : vector<1x1x32xf32>
+    xegpu.store %addf, %C[%cst], %mask {chunk_size = 1, layout_operand_0 = #inst_data, layout_operand_2 = #inst_data, layout_operand_3 = #inst_data, l1_hint = #xegpu.cache_hint<cached>} : vector<1x1x32xf32>, ui64, vector<1x1x32xindex>, vector<1x1x32xi1>
+    gpu.return
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/165030


More information about the Mlir-commits mailing list