[Mlir-commits] [mlir] [mlir][xegpu] refine basic routines (PR #138701)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 6 07:58:54 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Chao Chen (chencha3)
<details>
<summary>Changes</summary>
This PR adds two interfaces for `LayoutAttr` and updates the builder of `CreateNdDescOp` for convenience.
---
Full diff: https://github.com/llvm/llvm-project/pull/138701.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+34-1)
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+1-6)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+16-27)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f1bed70253ef3..b4236af497587 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -253,6 +253,28 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
return $_get($_ctxt, sg_layout, sg_data, inst_data,
DenseI32ArrayAttr::get($_ctxt, lane_layout),
DenseI32ArrayAttr::get($_ctxt, lane_data), order);
+ }]>,
+ AttrBuilder<(ins "llvm::ArrayRef<int>": $lane_layout,
+ "llvm::ArrayRef<int>": $lane_data,
+ "llvm::ArrayRef<int>": $order),
+ [{
+ auto sg_layout = DenseI32ArrayAttr();
+ auto sg_data = DenseI32ArrayAttr();
+ auto inst_data = DenseI32ArrayAttr();
+ return $_get($_ctxt, sg_layout, sg_data, inst_data,
+ DenseI32ArrayAttr::get($_ctxt, lane_layout),
+ DenseI32ArrayAttr::get($_ctxt, lane_data),
+ DenseI32ArrayAttr::get($_ctxt, order));
+ }]>,
+ AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
+ "DenseI32ArrayAttr": $lane_data,
+ "DenseI32ArrayAttr": $order),
+ [{
+ auto sg_layout = DenseI32ArrayAttr();
+ auto sg_data = DenseI32ArrayAttr();
+ auto inst_data = DenseI32ArrayAttr();
+ return $_get($_ctxt, sg_layout, sg_data, inst_data,
+ lane_layout, lane_data, order);
}]>
];
@@ -262,7 +284,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
}
bool isSgLayout() {
- return getSgLayout() == nullptr && getLaneLayout() != nullptr;
+ return !isWgLayout();
}
int64_t getRank() {
@@ -274,6 +296,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
return attr.size();
return 0;
}
+
+ LayoutAttr dropSgLayoutAndData() {
+ return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
+ getLaneLayout(), getLaneData(), getOrder());
+ }
+
+ LayoutAttr dropInstData() {
+ return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
+ getLaneLayout(), getLaneData(), getOrder());
+ }
+
}];
let assemblyFormat = "`<` struct(params) `>`";
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5fa18754305ca..627de858d94aa 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -142,12 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
"llvm::ArrayRef<OpFoldResult>": $offsets)>,
- OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
- "llvm::ArrayRef<OpFoldResult>": $offsets,
- "llvm::ArrayRef<OpFoldResult>": $shape,
- "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
- OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+ OpBuilder<(ins "Type": $tdesc, "Value": $source,
"llvm::ArrayRef<OpFoldResult>": $offsets,
"llvm::ArrayRef<OpFoldResult>": $shape,
"llvm::ArrayRef<OpFoldResult>": $strides)>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f9d7e013826ed..7df6c794615a1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<MemRefType> source,
+ Type tdesc, Value source,
llvm::ArrayRef<OpFoldResult> offsets,
llvm::ArrayRef<OpFoldResult> shape,
llvm::ArrayRef<OpFoldResult> strides) {
assert(shape.size() && offsets.size() && strides.size() &&
shape.size() == strides.size() && shape.size() == offsets.size());
- llvm::SmallVector<int64_t> staticOffsets;
- llvm::SmallVector<int64_t> staticShape;
- llvm::SmallVector<int64_t> staticStrides;
+ auto intTy = dyn_cast<IntegerType>(source.getType());
+ auto memrefTy = dyn_cast<MemRefType>(source.getType());
+ assert(intTy || memrefTy && "Source has to be either int or memref.");
+
llvm::SmallVector<Value> dynamicOffsets;
llvm::SmallVector<Value> dynamicShape;
llvm::SmallVector<Value> dynamicStrides;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
- dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
- auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
- auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
- auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
- build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
- dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<IntegerType> source,
- llvm::ArrayRef<OpFoldResult> offsets,
- llvm::ArrayRef<OpFoldResult> shape,
- llvm::ArrayRef<OpFoldResult> strides) {
- assert(shape.size() && offsets.size() && strides.size() &&
- shape.size() == strides.size() && shape.size() == offsets.size());
-
llvm::SmallVector<int64_t> staticOffsets;
llvm::SmallVector<int64_t> staticShape;
llvm::SmallVector<int64_t> staticStrides;
- llvm::SmallVector<Value> dynamicOffsets;
- llvm::SmallVector<Value> dynamicShape;
- llvm::SmallVector<Value> dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
@@ -190,6 +168,17 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
+ if (memrefTy) {
+ auto memrefShape = memrefTy.getShape();
+ auto [memrefStrides, offset] = memrefTy.getStridesAndOffset();
+
+ // if shape and strides are from Memref, we don't need attributes for them
+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ staticShapeAttr = DenseI64ArrayAttr();
+ staticStridesAttr = DenseI64ArrayAttr();
+ }
+ }
+
build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/138701
More information about the Mlir-commits
mailing list