[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] (PR #142797)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Jun 10 09:31:08 PDT 2025


================
@@ -0,0 +1,1048 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @elementwise_ops {
+   // CHECK-LABEL: elemwise_ops
+   gpu.func @elemwise_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+        %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+          -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+        %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+          -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+        %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+          -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+        %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+          -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+
+        %load_a = xegpu.load_nd %tdesc_a
+          : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+          -> vector<24x32xf32>
+        %load_b = xegpu.load_nd %tdesc_b
+          : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+          -> vector<24x32xf32>
+        %load_c = xegpu.load_nd %tdesc_c
+          : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+          -> vector<24x32xi32>
+        %load_d = xegpu.load_nd %tdesc_d
+          : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+          -> vector<24x32xi32>
+
+        // Floating point ops
----------------
adam-smnk wrote:

Kudos for the coverage, but that's a bit of overkill 😅
I'd suggest one op per dialect (arith, math) and per arity (unary, binary), plus any ops that have a pattern or behavior specific to them (if there are any), plus a negative case such as not matching a ternary op (like `arith.select`).
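
For illustration, a reduced test along those lines might look roughly like the sketch below. It is not taken from the PR: the module and function names and the `#layout` alias are hypothetical, the `vector<12x8xf32>` CHECK patterns assume each subgroup receives the `sg_data = [12, 8]` tile, and any layout attributes the pass may require on the elementwise ops themselves are left out.

```mlir
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s

// Hypothetical alias for the layout used throughout the original test.
#layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>

gpu.module @elementwise_ops_reduced {
  // CHECK-LABEL: unary_and_binary
  gpu.func @unary_and_binary(%a: memref<24x32xf32>, %b: memref<24x32xf32>) {
    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #layout>
    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #layout>
    %load_a = xegpu.load_nd %tdesc_a
      : !xegpu.tensor_desc<24x32xf32, #layout> -> vector<24x32xf32>
    %load_b = xegpu.load_nd %tdesc_b
      : !xegpu.tensor_desc<24x32xf32, #layout> -> vector<24x32xf32>

    // arith, unary. Assumed distributed type: vector<12x8xf32>.
    // CHECK: arith.negf {{.*}} : vector<12x8xf32>
    %negf = arith.negf %load_a : vector<24x32xf32>
    // arith, binary.
    // CHECK: arith.addf {{.*}} : vector<12x8xf32>
    %addf = arith.addf %load_a, %load_b : vector<24x32xf32>
    // math, unary.
    // CHECK: math.exp {{.*}} : vector<12x8xf32>
    %exp = math.exp %load_a : vector<24x32xf32>
    // math, binary.
    // CHECK: math.powf {{.*}} : vector<12x8xf32>
    %powf = math.powf %load_a, %load_b : vector<24x32xf32>
    gpu.return
  }

  // CHECK-LABEL: ternary_select
  gpu.func @ternary_select(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %cond: memref<24x32xi1>) {
    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #layout>
    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #layout>
    %tdesc_cond = xegpu.create_nd_tdesc %cond[0, 0] : memref<24x32xi1>
      -> !xegpu.tensor_desc<24x32xi1, #layout>
    %load_a = xegpu.load_nd %tdesc_a
      : !xegpu.tensor_desc<24x32xf32, #layout> -> vector<24x32xf32>
    %load_b = xegpu.load_nd %tdesc_b
      : !xegpu.tensor_desc<24x32xf32, #layout> -> vector<24x32xf32>
    %load_cond = xegpu.load_nd %tdesc_cond
      : !xegpu.tensor_desc<24x32xi1, #layout> -> vector<24x32xi1>

    // Ternary case. No CHECK line is given for the op itself, because the
    // expected output depends on whether the pass matches ternary ops.
    %select = arith.select %load_cond, %load_a, %load_b
      : vector<24x32xi1>, vector<24x32xf32>
    gpu.return
  }
}
```

Four elementwise ops plus one ternary case would cover both dialects and both arities while keeping the file far shorter than the current 1048 lines.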

https://github.com/llvm/llvm-project/pull/142797

