[Mlir-commits] [mlir] 20d6def - [mlir][xegpu] refine basic routines (#138701)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 9 07:41:25 PDT 2025
Author: Chao Chen
Date: 2025-05-09T09:41:21-05:00
New Revision: 20d6def0ae45b0c7ebcc1627d299689fa34e8cc8
URL: https://github.com/llvm/llvm-project/commit/20d6def0ae45b0c7ebcc1627d299689fa34e8cc8
DIFF: https://github.com/llvm/llvm-project/commit/20d6def0ae45b0c7ebcc1627d299689fa34e8cc8.diff
LOG: [mlir][xegpu] refine basic routines (#138701)
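This PR adds two interfaces to `LayoutAttr` (`dropSgLayoutAndData` and
`dropInstData`) and, for convenience, merges the memref-source and
integer-source builders of `CreateNdDescOp` into a single builder that
takes a generic `Value` source.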
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f1bed70253ef3..6d04ee5599a23 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -243,8 +243,8 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
);
let builders = [
- AttrBuilder<(ins "llvm::ArrayRef<int>": $lane_layout,
- "llvm::ArrayRef<int>": $lane_data),
+ AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+ "llvm::ArrayRef<int32_t>": $lane_data),
[{
auto sg_layout = DenseI32ArrayAttr();
auto sg_data = DenseI32ArrayAttr();
@@ -253,6 +253,25 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
return $_get($_ctxt, sg_layout, sg_data, inst_data,
DenseI32ArrayAttr::get($_ctxt, lane_layout),
DenseI32ArrayAttr::get($_ctxt, lane_data), order);
+ }]>,
+ AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+ "llvm::ArrayRef<int32_t>": $lane_data,
+ "llvm::ArrayRef<int32_t>": $order),
+ [{
+ return $_get($_ctxt,
+ /*sg_layout =*/ nullptr,
+ /*sg_data =*/ nullptr,
+ /*inst_data =*/ nullptr,
+ DenseI32ArrayAttr::get($_ctxt, lane_layout),
+ DenseI32ArrayAttr::get($_ctxt, lane_data),
+ DenseI32ArrayAttr::get($_ctxt, order));
+ }]>,
+ AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
+ "DenseI32ArrayAttr": $lane_data,
+ "DenseI32ArrayAttr": $order),
+ [{
+ return $_get($_ctxt, /*sg_layout =*/ nullptr, /*sg_data =*/ nullptr,
+ /*inst_data =*/ nullptr, lane_layout, lane_data, order);
}]>
];
@@ -262,7 +281,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
}
bool isSgLayout() {
- return getSgLayout() == nullptr && getLaneLayout() != nullptr;
+ return !isWgLayout();
}
int64_t getRank() {
@@ -274,6 +293,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
return attr.size();
return 0;
}
+
+ LayoutAttr dropSgLayoutAndData() {
+ return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
+ getLaneLayout(), getLaneData(), getOrder());
+ }
+
+ LayoutAttr dropInstData() {
+ return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
+ getLaneLayout(), getLaneData(), getOrder());
+ }
+
}];
let assemblyFormat = "`<` struct(params) `>`";
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index a892f701f724e..238bb1567d301 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -142,12 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
"llvm::ArrayRef<OpFoldResult>": $offsets)>,
- OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
- "llvm::ArrayRef<OpFoldResult>": $offsets,
- "llvm::ArrayRef<OpFoldResult>": $shape,
- "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
- OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+ OpBuilder<(ins "Type": $tdesc, "Value": $source,
"llvm::ArrayRef<OpFoldResult>": $offsets,
"llvm::ArrayRef<OpFoldResult>": $shape,
"llvm::ArrayRef<OpFoldResult>": $strides)>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f9d7e013826ed..f2cfa50e102f8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<MemRefType> source,
+ Type tdesc, Value source,
llvm::ArrayRef<OpFoldResult> offsets,
llvm::ArrayRef<OpFoldResult> shape,
llvm::ArrayRef<OpFoldResult> strides) {
assert(shape.size() && offsets.size() && strides.size() &&
shape.size() == strides.size() && shape.size() == offsets.size());
- llvm::SmallVector<int64_t> staticOffsets;
- llvm::SmallVector<int64_t> staticShape;
- llvm::SmallVector<int64_t> staticStrides;
+ Type srcTy = source.getType();
+ assert(isa<IntegerType>(srcTy) ||
+ isa<MemRefType>(srcTy) && "Source has to be either int or memref.");
+
llvm::SmallVector<Value> dynamicOffsets;
llvm::SmallVector<Value> dynamicShape;
llvm::SmallVector<Value> dynamicStrides;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
- dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
- auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
- auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
- auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
- build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
- dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<IntegerType> source,
- llvm::ArrayRef<OpFoldResult> offsets,
- llvm::ArrayRef<OpFoldResult> shape,
- llvm::ArrayRef<OpFoldResult> strides) {
- assert(shape.size() && offsets.size() && strides.size() &&
- shape.size() == strides.size() && shape.size() == offsets.size());
-
llvm::SmallVector<int64_t> staticOffsets;
llvm::SmallVector<int64_t> staticShape;
llvm::SmallVector<int64_t> staticStrides;
- llvm::SmallVector<Value> dynamicOffsets;
- llvm::SmallVector<Value> dynamicShape;
- llvm::SmallVector<Value> dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
@@ -190,6 +168,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
+ if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
+ auto memrefShape = memrefTy.getShape();
+ auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
+
+ // if shape and strides are from Memref, we don't need attributes for them
+ // to keep the IR print clean.
+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ staticShapeAttr = DenseI64ArrayAttr();
+ staticStridesAttr = DenseI64ArrayAttr();
+ }
+ }
+
build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
More information about the Mlir-commits mailing list