[Mlir-commits] [mlir] 20d6def - [mlir][xegpu] refine basic routines (#138701)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 9 07:41:25 PDT 2025


Author: Chao Chen
Date: 2025-05-09T09:41:21-05:00
New Revision: 20d6def0ae45b0c7ebcc1627d299689fa34e8cc8

URL: https://github.com/llvm/llvm-project/commit/20d6def0ae45b0c7ebcc1627d299689fa34e8cc8
DIFF: https://github.com/llvm/llvm-project/commit/20d6def0ae45b0c7ebcc1627d299689fa34e8cc8.diff

LOG: [mlir][xegpu] refine basic routines (#138701)

This PR adds two interfaces to `LayoutAttr` (`dropSgLayoutAndData` and
`dropInstData`) and updates the builder of `CreateNdDescOp` for convenience.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f1bed70253ef3..6d04ee5599a23 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -243,8 +243,8 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
   );
 
   let builders = [
-    AttrBuilder<(ins "llvm::ArrayRef<int>": $lane_layout,
-                     "llvm::ArrayRef<int>": $lane_data),
+    AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+                     "llvm::ArrayRef<int32_t>": $lane_data),
       [{
         auto sg_layout = DenseI32ArrayAttr();
         auto sg_data = DenseI32ArrayAttr();
@@ -253,6 +253,25 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
         return $_get($_ctxt, sg_layout, sg_data, inst_data,
                      DenseI32ArrayAttr::get($_ctxt, lane_layout),
                      DenseI32ArrayAttr::get($_ctxt, lane_data), order);
+      }]>,
+    AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
+                     "llvm::ArrayRef<int32_t>": $lane_data,
+                     "llvm::ArrayRef<int32_t>": $order),
+      [{
+        return $_get($_ctxt,
+                     /*sg_layout =*/ nullptr,
+                     /*sg_data   =*/ nullptr,
+                     /*inst_data =*/ nullptr,
+                     DenseI32ArrayAttr::get($_ctxt, lane_layout),
+                     DenseI32ArrayAttr::get($_ctxt, lane_data),
+                     DenseI32ArrayAttr::get($_ctxt, order));
+      }]>,
+    AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
+                     "DenseI32ArrayAttr": $lane_data,
+                     "DenseI32ArrayAttr": $order),
+      [{
+        return $_get($_ctxt, /*sg_layout =*/ nullptr, /*sg_data =*/ nullptr,
+                  /*inst_data =*/ nullptr, lane_layout, lane_data, order);
       }]>
   ];
 
@@ -262,7 +281,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
     }
 
     bool isSgLayout() {
-      return getSgLayout() == nullptr && getLaneLayout() != nullptr;
+      return !isWgLayout();
     }
 
     int64_t getRank() {
@@ -274,6 +293,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
         return attr.size();
       return 0;
     }
+
+    LayoutAttr dropSgLayoutAndData() {
+      return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
+                             getLaneLayout(), getLaneData(), getOrder());
+    }
+
+    LayoutAttr dropInstData() {
+      return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
+                             getLaneLayout(), getLaneData(), getOrder());
+    }
+
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index a892f701f724e..238bb1567d301 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -142,12 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets)>,
 
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets,
-                   "llvm::ArrayRef<OpFoldResult>": $shape,
-                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+    OpBuilder<(ins "Type": $tdesc, "Value": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets,
                    "llvm::ArrayRef<OpFoldResult>": $shape,
                    "llvm::ArrayRef<OpFoldResult>": $strides)>

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f9d7e013826ed..f2cfa50e102f8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<MemRefType> source,
+                           Type tdesc, Value source,
                            llvm::ArrayRef<OpFoldResult> offsets,
                            llvm::ArrayRef<OpFoldResult> shape,
                            llvm::ArrayRef<OpFoldResult> strides) {
   assert(shape.size() && offsets.size() && strides.size() &&
          shape.size() == strides.size() && shape.size() == offsets.size());
 
-  llvm::SmallVector<int64_t> staticOffsets;
-  llvm::SmallVector<int64_t> staticShape;
-  llvm::SmallVector<int64_t> staticStrides;
+  Type srcTy = source.getType();
+  assert(isa<IntegerType>(srcTy) ||
+         isa<MemRefType>(srcTy) && "Source has to be either int or memref.");
+
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<Value> dynamicShape;
   llvm::SmallVector<Value> dynamicStrides;
 
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
-  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
-  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
-  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
-  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
-        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<IntegerType> source,
-                           llvm::ArrayRef<OpFoldResult> offsets,
-                           llvm::ArrayRef<OpFoldResult> shape,
-                           llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && offsets.size() && strides.size() &&
-         shape.size() == strides.size() && shape.size() == offsets.size());
-
   llvm::SmallVector<int64_t> staticOffsets;
   llvm::SmallVector<int64_t> staticShape;
   llvm::SmallVector<int64_t> staticStrides;
-  llvm::SmallVector<Value> dynamicOffsets;
-  llvm::SmallVector<Value> dynamicShape;
-  llvm::SmallVector<Value> dynamicStrides;
 
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
@@ -190,6 +168,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
   auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
 
+  if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
+    auto memrefShape = memrefTy.getShape();
+    auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
+
+    // if shape and strides are from Memref, we don't need attributes for them
+    // to keep the IR print clean.
+    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+      staticShapeAttr = DenseI64ArrayAttr();
+      staticStridesAttr = DenseI64ArrayAttr();
+    }
+  }
+
   build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
         dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
 }


        


More information about the Mlir-commits mailing list