[Mlir-commits] [mlir] [mlir][xegpu] Add definition of SliceAttr (PR #150146)
Chao Chen
llvmlistbot at llvm.org
Wed Jul 23 11:55:36 PDT 2025
https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/150146
>From 2bc70b6a8487a8ce0f0e7e0c5ac5bc59035465ab Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 22 Jul 2025 19:46:04 +0000
Subject: [PATCH 1/7] add definition draft of SliceAttr
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 21 +++++++++++++++++++
1 file changed, 21 insertions(+)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 42b5b7a0d4e3f..abbd227b9905f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -330,4 +330,25 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
let genVerifyDecl = 1;
}
+
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice"> {
+ let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
+
+ let description = [{
+ Like LayoutAttr, SliceAttr describes data distribution among subgroups or work-items.
+ However, whereas LayoutAttr requires the data to have the same rank as the attribute,
+ SliceAttr permits the data to have a lower rank. In this case, compute units in the
+ specified dimensions share the data, provided that the remaining ranks match the data
+ rank. SliceAttr is commonly used by operations such as vector.multi_reduction and
+ vector.broadcast.
+ }];
+
+ let parameters = (ins
+ "Attribute": $parent,
+ "DenseI64ArrayAttr": $dims
+ );
+
+ let assemblyFormat = "`<` $parent `,` `dim` `=` $dims `>`";
+}
+
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
>From 3959f9e5027f7c21f420c44a5e34501c115df361 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 22 Jul 2025 21:02:22 +0000
Subject: [PATCH 2/7] add layout traits
---
mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt | 6 ++++++
mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h | 1 +
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 11 +++++++++--
mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 1 +
5 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
index 3f8cac4dc07c3..bbbeb71410a9b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
@@ -12,3 +12,9 @@ mlir_tablegen(XeGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(XeGPUEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRXeGPUEnumsIncGen)
add_dependencies(mlir-headers MLIRXeGPUEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS XeGPUAttrs.td)
+mlir_tablegen(XeGPUAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(XeGPUAttrInterface.cpp.inc -gen-attr-interface-defs)
+add_public_tablegen_target(MLIRXeGPUAttrInterfaceIncGen)
+add_dependencies(mlir-headers MLIRXeGPUAttrInterfaceIncGen)
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 8e2784f40ad39..cc8d58d8975b4 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -25,6 +25,7 @@ class TensorDescType;
} // namespace xegpu
} // namespace mlir
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index abbd227b9905f..b15dd4a3177f9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -169,7 +169,14 @@ def XeGPU_FenceScopeAttr:
let assemblyFormat = "$value";
}
-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
+def LayoutTrait: AttrInterface<"LayoutTrait"> {
+ let cppNamespace = "::mlir::xegpu";
+ let description = [{
+ Common trait for all XeGPU layouts.
+ }];
+}
+
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
let summary = [{
Describes the data distribution to subgroups and work-items for a tensor
specified by the tensor descriptor.
@@ -331,7 +338,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
}
-def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice"> {
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
let description = [{
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 242a97ccfdf6d..89d986143e965 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUDialect
DEPENDS
MLIRXeGPUIncGen
+ MLIRXeGPUAttrInterfaceIncGen
MLIRXeGPUAttrsIncGen
MLIRXeGPUEnumsIncGen
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 78cbf884a1911..63160c98105c3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -753,6 +753,7 @@ LogicalResult ConvertLayoutOp::verify() {
} // namespace xegpu
} // namespace mlir
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
>From 2027cfc98321d8f68a713340cd652ab10625cfee Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 22 Jul 2025 23:46:10 +0000
Subject: [PATCH 3/7] add verifier and interface
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 54 ++++++++++++++++++-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 21 ++++++++
2 files changed, 74 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index b15dd4a3177f9..e3b06714bdcc2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -174,6 +174,17 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
let description = [{
Common trait for all XeGPU layouts.
}];
+
+ let methods = [
+ InterfaceMethod<"Get the effective sg layout",
+ "std::optional<llvm::SmallVector<int>>",
+ "getEffectiveSgLayout">,
+ InterfaceMethod<"Get the effective sg data",
+ "std::optional<llvm::SmallVector<int>>",
+ "getEffectiveSgData">,
+ ];
+
+
}
def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
@@ -331,6 +342,18 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
getLaneLayout(), getLaneData(), getOrder());
}
+
+ std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
+ if (DenseI32ArrayAttr layout = getSgLayout())
+ return llvm::to_vector(layout.asArrayRef());
+ return std::nullopt;
+ }
+
+ std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
+ if (DenseI32ArrayAttr data = getSgData())
+ return llvm::to_vector(data.asArrayRef());
+ return std::nullopt;
+ }
}];
let assemblyFormat = "`<` struct(params) `>`";
@@ -351,11 +374,40 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
}];
let parameters = (ins
- "Attribute": $parent,
+ "xegpu::LayoutAttr": $parent,
"DenseI64ArrayAttr": $dims
);
+ let extraClassDeclaration = [{
+ std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
+ if (DenseI32ArrayAttr layout = getParent().getSgLayout()) {
+ llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
+ llvm::SmallVector<int32_t> result;
+ for (auto [i, v]: llvm::enumerate(layout.asArrayRef())) {
+ if (!llvm::is_contained(dims, i))
+ result.push_back(v);
+ }
+ return result;
+ }
+ return std::nullopt;
+ }
+ std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
+ if (DenseI32ArrayAttr data = getParent().getSgData()) {
+ llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
+ llvm::SmallVector<int32_t> result;
+ for (auto [i, v]: llvm::enumerate(data.asArrayRef())) {
+ if (!llvm::is_contained(dims, i))
+ result.push_back(v);
+ }
+ return result;
+ }
+ return std::nullopt;
+
+ }
+ }];
+
let assemblyFormat = "`<` $parent `,` `dim` `=` $dims `>`";
+ let genVerifyDecl = 1;
}
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 642c393cbc2c8..7e293b6f0e1a3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -206,6 +206,27 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// XeGPU_SliceAttr
+//===----------------------------------------------------------------------===//
+LogicalResult
+SliceAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+ xegpu::LayoutAttr parent, DenseI64ArrayAttr dims) {
+ if (!parent || !dims)
+ return emitError() << "expected parent layout and dims attribute";
+
+ int rank = parent.getRank();
+ // Check that every element in dims is unique and smaller than rank.
+ llvm::SmallDenseSet<int64_t> seen;
+ for (int64_t dim : dims.asArrayRef()) {
+ if (dim >= rank)
+ return emitError() << "invalid dim: " << dim;
+ if (!seen.insert(dim).second)
+ return emitError() << "repeated dim: " << dim;
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//
>From 638c0853dc2b76fbc01d8410cd6bb52aa7d20891 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Jul 2025 15:52:26 +0000
Subject: [PATCH 4/7] add invalid unit test
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 4 ++--
mlir/test/Dialect/XeGPU/invalid.mlir | 19 +++++++++++++++++++
3 files changed, 22 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index e3b06714bdcc2..d0b2e936d6508 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -406,7 +406,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
}
}];
- let assemblyFormat = "`<` $parent `,` `dim` `=` $dims `>`";
+ let assemblyFormat = "`<` $parent `,` `dims` `=` $dims `>`";
let genVerifyDecl = 1;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 7e293b6f0e1a3..21007f98643bc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -220,9 +220,9 @@ SliceAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
llvm::SmallDenseSet<int64_t> seen;
for (int64_t dim : dims.asArrayRef()) {
if (dim >= rank)
- return emitError() << "invalid dim: " << dim;
+ return emitError() << "invalid dim (" << dim << ") in slice attribute.";
if (!seen.insert(dim).second)
- return emitError() << "repeated dim: " << dim;
+ return emitError() << "repeated dim (" << dim << ") in slice attribute.";
}
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index eb564d55bfd51..c4e72820e9aec 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -658,3 +658,22 @@ func.func @tensor_desc_invalid_sg_data(%src: ui64, %offsets: vector<16xindex>) {
#xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2], order = [0, 1, 2]>>
return
}
+
+// -----
+#l = #xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>
+// expected-error @+1 {{repeated dim (2) in slice attribute}}
+#s = #xegpu.slice<#l, dims = [2, 2]>
+func.func @slice_attr_repeat_dim() {
+ %offsets = arith.constant {layout_result_0 = #s} dense<0.8> : vector<16x8xindex>
+ return
+}
+
+// -----
+#l = #xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>
+// expected-error @+1 {{invalid dim (3) in slice attribute}}
+#s = #xegpu.slice<#l, dims = [3]>
+func.func @slice_attr_repeat_dim() {
+ %offsets = arith.constant {layout_result_0 = #s} dense<0.8> : vector<16x8xindex>
+ return
+}
+
>From 91048f06417bd8af3d58d35a516115da044e6451 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Jul 2025 16:06:59 +0000
Subject: [PATCH 5/7] add wrappers
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 13 +++++++++++--
1 file changed, 11 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index d0b2e936d6508..a38878bc6a61f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -183,8 +183,6 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
"std::optional<llvm::SmallVector<int>>",
"getEffectiveSgData">,
];
-
-
}
def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
@@ -402,7 +400,18 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
return result;
}
return std::nullopt;
+ }
+
+ DenseI32ArrayAttr getOrder() const {
+ return getParent().getOrder();
+ }
+
+ bool isWgLayout() const {
+ return getParent().isWgLayout();
+ }
+ bool isSgLayout() const {
+ return getParent().isSgLayout();
}
}];
>From ddc42c2886ae3c49f10032caea27817dc6d542de Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Jul 2025 17:51:42 +0000
Subject: [PATCH 6/7] update description
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 78a7c48af837e..8644be8e4204c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -187,7 +187,7 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
"getEffectiveSgLayout">,
InterfaceMethod<"Get the effective sg data",
"std::optional<llvm::SmallVector<int>>",
- "getEffectiveSgData">,
+ "getEffectiveSgData">
];
}
@@ -375,6 +375,16 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
specified dimensions share the data, provided that the remaining ranks match the data
rank. SliceAttr is commonly used by operations such as vector.multi_reduction and
vector.broadcast.
+
+ Example:
+ ```
+ #l = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
+ #r = #xegpu.slice<#l, dims = [0]>
+
+ %exp = math.exp %input {layout_result_0 = #l}: vector<256x128xf32>
+ %red = vector.multi_reduction<add>, %exp, %acc [0] {layout_result_0 = #r}: vector<256x128xf32> to vector<128xf32>
+ %bcast = vector.broadcast %red {layout_result_0 = #l} : vector<128xf32> to vector<256x128xf32>
+ ```
}];
let parameters = (ins
>From 36e2c3a118b0167c6e4f3341533f92353ddaebe2 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Jul 2025 18:44:08 +0000
Subject: [PATCH 7/7] refactor
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h | 6 +++---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 15 +++------------
.../include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td | 12 ++++++++++++
3 files changed, 18 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index cc8d58d8975b4..c2d546fa08fe0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -22,18 +22,18 @@
namespace mlir {
namespace xegpu {
class TensorDescType;
+class LayoutAttr;
} // namespace xegpu
} // namespace mlir
+#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
+
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.h.inc>
-
-#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
-
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.h.inc>
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 8644be8e4204c..36a12a2c2a029 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -396,24 +396,15 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
if (DenseI32ArrayAttr layout = getParent().getSgLayout()) {
llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
- llvm::SmallVector<int32_t> result;
- for (auto [i, v]: llvm::enumerate(layout.asArrayRef())) {
- if (!llvm::is_contained(dims, i))
- result.push_back(v);
- }
- return result;
+ return XeGPUDialect::dropDims(layout.asArrayRef(), dims);
}
return std::nullopt;
}
+
std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
if (DenseI32ArrayAttr data = getParent().getSgData()) {
llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
- llvm::SmallVector<int32_t> result;
- for (auto [i, v]: llvm::enumerate(data.asArrayRef())) {
- if (!llvm::is_contained(dims, i))
- result.push_back(v);
- }
- return result;
+ return XeGPUDialect::dropDims(data.asArrayRef(), dims);
}
return std::nullopt;
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index 549018b61d6fb..f07a758a59b96 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -41,6 +41,18 @@ def XeGPU_Dialect : Dialect {
/// Checks if the given shape can be evenly distributed based on the layout
/// and data factors provided by the LayoutAttr.
static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
+
+ /// Drops the elements of `data` at the positions listed in `dropPositions` and returns
+ /// the rest, e.g., for data = [32, 64, 8] and dropPositions = [0, 2], it returns [64].
+ template<typename T, typename U>
+ static llvm::SmallVector<T> dropDims(llvm::ArrayRef<T> data, llvm::ArrayRef<U> dropPositions) {
+ llvm::SmallVector<T> result;
+ for (auto [i, v]: llvm::enumerate(data)) {
+ if (!llvm::is_contained(dropPositions, i))
+ result.push_back(v);
+ }
+ return result;
+ }
}];
}
More information about the Mlir-commits
mailing list