[Mlir-commits] [mlir] [mlir][xegpu] Add definition of SliceAttr (PR #150146)

Chao Chen llvmlistbot at llvm.org
Wed Jul 23 11:55:36 PDT 2025


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/150146

>From 2bc70b6a8487a8ce0f0e7e0c5ac5bc59035465ab Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 22 Jul 2025 19:46:04 +0000
Subject: [PATCH 1/7] add definition draft of SliceAttr

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 42b5b7a0d4e3f..abbd227b9905f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -330,4 +330,25 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
   let genVerifyDecl = 1;
 }
 
+
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice"> {
+  let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
+
+  let description = [{
+    Like LayoutAttr, SliceAttr describes data distribution among subgroups or work-items.
+    However, whereas LayoutAttr requires the data to have the same rank as the attribute,
+    SliceAttr permits the data to have a lower rank. In this case, compute units in the
+    specified dimensions share the data, provided that the remaining ranks match the data
+    rank. SliceAttr is commonly used by operations such as vector.multi_reduction and
+    vector.broadcast.
+  }];
+
+  let parameters = (ins
+    "Attribute": $parent,
+    "DenseI64ArrayAttr": $dims
+  );
+
+  let assemblyFormat = "`<` $parent `,` `dim` `=` $dims `>`";
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

>From 3959f9e5027f7c21f420c44a5e34501c115df361 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 22 Jul 2025 21:02:22 +0000
Subject: [PATCH 2/7] add layout traits

---
 mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt |  6 ++++++
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h        |  1 +
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td  | 11 +++++++++--
 mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt          |  1 +
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp            |  1 +
 5 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
index 3f8cac4dc07c3..bbbeb71410a9b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
@@ -12,3 +12,9 @@ mlir_tablegen(XeGPUEnums.h.inc -gen-enum-decls)
 mlir_tablegen(XeGPUEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRXeGPUEnumsIncGen)
 add_dependencies(mlir-headers MLIRXeGPUEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS XeGPUAttrs.td)
+mlir_tablegen(XeGPUAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(XeGPUAttrInterface.cpp.inc -gen-attr-interface-defs)
+add_public_tablegen_target(MLIRXeGPUAttrInterfaceIncGen)
+add_dependencies(mlir-headers MLIRXeGPUAttrInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 8e2784f40ad39..cc8d58d8975b4 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -25,6 +25,7 @@ class TensorDescType;
 } // namespace xegpu
 } // namespace mlir
 
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
 #define GET_ATTRDEF_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index abbd227b9905f..b15dd4a3177f9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -169,7 +169,14 @@ def XeGPU_FenceScopeAttr:
     let assemblyFormat = "$value";
 }
 
-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
+def LayoutTrait: AttrInterface<"LayoutTrait"> {
+  let cppNamespace = "::mlir::xegpu";
+  let description = [{
+    Common trait for all XeGPU layouts.
+  }];
+}
+
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
   let summary = [{
     Describes the data distribution to subgroups and work-items for a tensor
     specified by the tensor descriptor.
@@ -331,7 +338,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
 }
 
 
-def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice"> {
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
   let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
 
   let description = [{
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 242a97ccfdf6d..89d986143e965 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUDialect
 
   DEPENDS
   MLIRXeGPUIncGen
+  MLIRXeGPUAttrInterfaceIncGen
   MLIRXeGPUAttrsIncGen
   MLIRXeGPUEnumsIncGen
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 78cbf884a1911..63160c98105c3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -753,6 +753,7 @@ LogicalResult ConvertLayoutOp::verify() {
 } // namespace xegpu
 } // namespace mlir
 
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
 #define GET_OP_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>

>From 2027cfc98321d8f68a713340cd652ab10625cfee Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 22 Jul 2025 23:46:10 +0000
Subject: [PATCH 3/7] add verifier and interface

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 54 ++++++++++++++++++-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 21 ++++++++
 2 files changed, 74 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index b15dd4a3177f9..e3b06714bdcc2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -174,6 +174,17 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
   let description = [{
     Common trait for all XeGPU layouts.
   }];
+
+  let methods = [
+    InterfaceMethod<"Get the effective sg layout",
+                    "std::optional<llvm::SmallVector<int>>",
+                    "getEffectiveSgLayout">,
+    InterfaceMethod<"Get the effective sg data",
+                    "std::optional<llvm::SmallVector<int>>",
+                    "getEffectiveSgData">,
+  ];
+
+
 }
 
 def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
@@ -331,6 +342,18 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
       return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
                              getLaneLayout(), getLaneData(), getOrder());
     }
+
+    std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
+      if (DenseI32ArrayAttr layout = getSgLayout())
+        return llvm::to_vector(layout.asArrayRef());
+      return std::nullopt;
+    }
+
+    std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
+      if (DenseI32ArrayAttr data = getSgData())
+        return llvm::to_vector(data.asArrayRef());
+      return std::nullopt;
+    }
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";
@@ -351,11 +374,40 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
   }];
 
   let parameters = (ins
-    "Attribute": $parent,
+    "xegpu::LayoutAttr": $parent,
     "DenseI64ArrayAttr": $dims
   );
 
+  let extraClassDeclaration = [{
+    std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
+      if (DenseI32ArrayAttr layout = getParent().getSgLayout()) {
+        llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
+        llvm::SmallVector<int32_t> result;
+        for (auto [i, v]: llvm::enumerate(layout.asArrayRef())) {
+          if (!llvm::is_contained(dims, i))
+            result.push_back(v);
+        }
+        return result;
+      }
+      return std::nullopt;
+    }
+    std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
+      if (DenseI32ArrayAttr data = getParent().getSgData()) {
+        llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
+        llvm::SmallVector<int32_t> result;
+        for (auto [i, v]: llvm::enumerate(data.asArrayRef())) {
+          if (!llvm::is_contained(dims, i))
+            result.push_back(v);
+        }
+        return result;
+      }
+      return std::nullopt;
+
+    }
+  }];
+
   let assemblyFormat = "`<` $parent `,` `dim` `=` $dims `>`";
+  let genVerifyDecl = 1;
 }
 
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 642c393cbc2c8..7e293b6f0e1a3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -206,6 +206,27 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// XeGPU_SliceAttr
+//===----------------------------------------------------------------------===//
+LogicalResult
+SliceAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+                  xegpu::LayoutAttr parent, DenseI64ArrayAttr dims) {
+  if (!parent || !dims)
+    return emitError() << "expected parent layout and dims attribute";
+
+  int rank = parent.getRank();
+  // Reject any dim that is repeated or not smaller than the parent's rank.
+  llvm::SmallDenseSet<int64_t> seen;
+  for (int64_t dim : dims.asArrayRef()) {
+    if (dim >= rank)
+      return emitError() << "invalid dim: " << dim;
+    if (!seen.insert(dim).second)
+      return emitError() << "repeated dim: " << dim;
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_TensorDescType
 //===----------------------------------------------------------------------===//

>From 638c0853dc2b76fbc01d8410cd6bb52aa7d20891 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Jul 2025 15:52:26 +0000
Subject: [PATCH 4/7] add invalid unit test

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  2 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  4 ++--
 mlir/test/Dialect/XeGPU/invalid.mlir          | 19 +++++++++++++++++++
 3 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index e3b06714bdcc2..d0b2e936d6508 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -406,7 +406,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
     }
   }];
 
-  let assemblyFormat = "`<` $parent `,` `dim` `=` $dims `>`";
+  let assemblyFormat = "`<` $parent `,` `dims` `=` $dims `>`";
   let genVerifyDecl = 1;
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 7e293b6f0e1a3..21007f98643bc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -220,9 +220,9 @@ SliceAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
   llvm::SmallDenseSet<int64_t> seen;
   for (int64_t dim : dims.asArrayRef()) {
-    if (dim >= rank)
-      return emitError() << "invalid dim: " << dim;
+    if (dim < 0 || dim >= rank) // each dim must index into the parent layout
+      return emitError() << "invalid dim (" << dim << ") in slice attribute.";
     if (!seen.insert(dim).second)
-      return emitError() << "repeated dim: " << dim;
+      return emitError() << "repeated dim (" << dim << ") in slice attribute.";
   }
   return success();
 }
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index eb564d55bfd51..c4e72820e9aec 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -658,3 +658,22 @@ func.func @tensor_desc_invalid_sg_data(%src: ui64, %offsets: vector<16xindex>) {
         #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2], order = [0, 1, 2]>>
   return
 }
+
+// -----
+#l = #xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>
+// expected-error@+1 {{repeated dim (2) in slice attribute}}
+#s = #xegpu.slice<#l, dims = [2, 2]>
+func.func @slice_attr_repeat_dim() {
+  %offsets = arith.constant {layout_result_0 = #s} dense<0> : vector<16x8xindex>
+  return
+}
+
+// -----
+#l = #xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>
+// expected-error@+1 {{invalid dim (3) in slice attribute}}
+#s = #xegpu.slice<#l, dims = [3]>
+func.func @slice_attr_invalid_dim() {
+  %offsets = arith.constant {layout_result_0 = #s} dense<0> : vector<16x8xindex>
+  return
+}
+

>From 91048f06417bd8af3d58d35a516115da044e6451 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Jul 2025 16:06:59 +0000
Subject: [PATCH 5/7] add wrappers

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index d0b2e936d6508..a38878bc6a61f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -183,8 +183,6 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
                     "std::optional<llvm::SmallVector<int>>",
                     "getEffectiveSgData">,
   ];
-
-
 }
 
 def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
@@ -402,7 +400,18 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
         return result;
       }
       return std::nullopt;
+    }
+
+    DenseI32ArrayAttr getOrder() const {
+      return getParent().getOrder();
+    }
+
+    bool isWgLayout() const {
+      return getParent().isWgLayout();
+    }
 
+    bool isSgLayout() const {
+      return getParent().isSgLayout();
     }
   }];
 

>From ddc42c2886ae3c49f10032caea27817dc6d542de Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Jul 2025 17:51:42 +0000
Subject: [PATCH 6/7] update description

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 78a7c48af837e..8644be8e4204c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -187,7 +187,7 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
                     "getEffectiveSgLayout">,
     InterfaceMethod<"Get the effective sg data",
                     "std::optional<llvm::SmallVector<int>>",
-                    "getEffectiveSgData">,
+                    "getEffectiveSgData">
   ];
 }
 
@@ -375,6 +375,16 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
     specified dimensions share the data, provided that the remaining ranks match the data
     rank. SliceAttr is commonly used by operations such as vector.multi_reduction and
     vector.broadcast.
+
+    Example:
+    ```
+    #l = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
+    #r = #xegpu.slice<#l, dims = [0]>
+
+    %exp = math.exp %input {layout_result_0 = #l}: vector<256x128xf32>
+    %red = vector.multi_reduction<add>, %exp, %acc [0] {layout_result_0 = #r}: vector<256x128xf32> to vector<128xf32>
+    %bcast = vector.broadcast %red {layout_result_0 = #l} : vector<128xf32> to vector<256x128xf32>
+    ```
   }];
 
   let parameters = (ins

>From 36e2c3a118b0167c6e4f3341533f92353ddaebe2 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Jul 2025 18:44:08 +0000
Subject: [PATCH 7/7] refactor

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h        |  6 +++---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td  | 15 +++------------
 .../include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td | 12 ++++++++++++
 3 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index cc8d58d8975b4..c2d546fa08fe0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -22,18 +22,18 @@
 namespace mlir {
 namespace xegpu {
 class TensorDescType;
+class LayoutAttr;
 } // namespace xegpu
 } // namespace mlir
 
+#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
+
 #define GET_ATTRDEF_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
 #define GET_TYPEDEF_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.h.inc>
-
-#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
-
 #define GET_OP_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPU.h.inc>
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 8644be8e4204c..36a12a2c2a029 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -396,24 +396,15 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
     std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
       if (DenseI32ArrayAttr layout = getParent().getSgLayout()) {
         llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
-        llvm::SmallVector<int32_t> result;
-        for (auto [i, v]: llvm::enumerate(layout.asArrayRef())) {
-          if (!llvm::is_contained(dims, i))
-            result.push_back(v);
-        }
-        return result;
+        return XeGPUDialect::dropDims(layout.asArrayRef(), dims);
       }
       return std::nullopt;
     }
+
     std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
       if (DenseI32ArrayAttr data = getParent().getSgData()) {
         llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
-        llvm::SmallVector<int32_t> result;
-        for (auto [i, v]: llvm::enumerate(data.asArrayRef())) {
-          if (!llvm::is_contained(dims, i))
-            result.push_back(v);
-        }
-        return result;
+        return XeGPUDialect::dropDims(data.asArrayRef(), dims);
       }
       return std::nullopt;
     }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index 549018b61d6fb..f07a758a59b96 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -41,6 +41,18 @@ def XeGPU_Dialect : Dialect {
       /// Checks if the given shape can be evenly distributed based on the layout
       /// and data factors provided by the LayoutAttr.
       static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
+
+      /// Returns the elements of `data` whose positions are not listed in
+      /// `dropPositions`, e.g. data = [32, 64, 8], dropPositions = [0, 2] -> [64].
+      template<typename T, typename U>
+      static llvm::SmallVector<T> dropDims(llvm::ArrayRef<T> data, llvm::ArrayRef<U> dropPositions) {
+        llvm::SmallVector<T> result;
+        for (auto [i, v]: llvm::enumerate(data)) {
+          if (!llvm::is_contained(dropPositions, static_cast<U>(i)))
+            result.push_back(v);
+        }
+        return result;
+      }
     }];
 }
 



More information about the Mlir-commits mailing list