[Mlir-commits] [mlir] [mlir][xegpu] refine basic routines (PR #138701)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 6 07:58:54 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

<details>
<summary>Changes</summary>

This PR adds two interfaces for `LayoutAttr` and updates the builder of `CreateNdDescOp` for convenience.


---
Full diff: https://github.com/llvm/llvm-project/pull/138701.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+34-1) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+1-6) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+16-27) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f1bed70253ef3..b4236af497587 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -253,6 +253,28 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
         return $_get($_ctxt, sg_layout, sg_data, inst_data,
                      DenseI32ArrayAttr::get($_ctxt, lane_layout),
                      DenseI32ArrayAttr::get($_ctxt, lane_data), order);
+      }]>,
+    AttrBuilder<(ins "llvm::ArrayRef<int>": $lane_layout,
+                     "llvm::ArrayRef<int>": $lane_data,
+                     "llvm::ArrayRef<int>": $order),
+      [{
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto inst_data = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data, inst_data,
+                     DenseI32ArrayAttr::get($_ctxt, lane_layout),
+                     DenseI32ArrayAttr::get($_ctxt, lane_data),
+                     DenseI32ArrayAttr::get($_ctxt, order));
+      }]>,
+    AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
+                     "DenseI32ArrayAttr": $lane_data,
+                     "DenseI32ArrayAttr": $order),
+      [{
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto inst_data = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data, inst_data,
+                     lane_layout, lane_data, order);
       }]>
   ];
 
@@ -262,7 +284,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
     }
 
     bool isSgLayout() {
-      return getSgLayout() == nullptr && getLaneLayout() != nullptr;
+      return !isWgLayout();
     }
 
     int64_t getRank() {
@@ -274,6 +296,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
         return attr.size();
       return 0;
     }
+
+    LayoutAttr dropSgLayoutAndData() {
+      return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
+                             getLaneLayout(), getLaneData(), getOrder());
+    }
+
+    LayoutAttr dropInstData() {
+      return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
+                             getLaneLayout(), getLaneData(), getOrder());
+    }
+
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5fa18754305ca..627de858d94aa 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -142,12 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets)>,
 
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets,
-                   "llvm::ArrayRef<OpFoldResult>": $shape,
-                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+    OpBuilder<(ins "Type": $tdesc, "Value": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets,
                    "llvm::ArrayRef<OpFoldResult>": $shape,
                    "llvm::ArrayRef<OpFoldResult>": $strides)>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f9d7e013826ed..7df6c794615a1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<MemRefType> source,
+                           Type tdesc, Value source,
                            llvm::ArrayRef<OpFoldResult> offsets,
                            llvm::ArrayRef<OpFoldResult> shape,
                            llvm::ArrayRef<OpFoldResult> strides) {
   assert(shape.size() && offsets.size() && strides.size() &&
          shape.size() == strides.size() && shape.size() == offsets.size());
 
-  llvm::SmallVector<int64_t> staticOffsets;
-  llvm::SmallVector<int64_t> staticShape;
-  llvm::SmallVector<int64_t> staticStrides;
+  auto intTy = dyn_cast<IntegerType>(source.getType());
+  auto memrefTy = dyn_cast<MemRefType>(source.getType());
+  assert(intTy || memrefTy && "Source has to be either int or memref.");
+
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<Value> dynamicShape;
   llvm::SmallVector<Value> dynamicStrides;
 
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
-  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
-  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
-  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
-  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
-        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<IntegerType> source,
-                           llvm::ArrayRef<OpFoldResult> offsets,
-                           llvm::ArrayRef<OpFoldResult> shape,
-                           llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && offsets.size() && strides.size() &&
-         shape.size() == strides.size() && shape.size() == offsets.size());
-
   llvm::SmallVector<int64_t> staticOffsets;
   llvm::SmallVector<int64_t> staticShape;
   llvm::SmallVector<int64_t> staticStrides;
-  llvm::SmallVector<Value> dynamicOffsets;
-  llvm::SmallVector<Value> dynamicShape;
-  llvm::SmallVector<Value> dynamicStrides;
 
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
@@ -190,6 +168,17 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
   auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
 
+  if (memrefTy) {
+    auto memrefShape = memrefTy.getShape();
+    auto [memrefStrides, offset] = memrefTy.getStridesAndOffset();
+
+    // if shape and strides are from Memref, we don't need attributes for them
+    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+      staticShapeAttr = DenseI64ArrayAttr();
+      staticStridesAttr = DenseI64ArrayAttr();
+    }
+  }
+
   build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
         dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/138701


More information about the Mlir-commits mailing list