[Mlir-commits] [mlir] 9fa55ec - [MLIR][XeGPU] Add sg_map attribute to support Work Item level semanti… (#110876)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 2 11:00:14 PDT 2024


Author: Petr Kurapov
Date: 2024-10-02T13:00:10-05:00
New Revision: 9fa55ec3d93435a790f9990b1c6565e3ee689b2c

URL: https://github.com/llvm/llvm-project/commit/9fa55ec3d93435a790f9990b1c6565e3ee689b2c
DIFF: https://github.com/llvm/llvm-project/commit/9fa55ec3d93435a790f9990b1c6565e3ee689b2c.diff

LOG: [MLIR][XeGPU] Add sg_map attribute to support Work Item level semanti… (#110876)

Bring back #108864

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
    mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
    mlir/test/Dialect/XeGPU/XeGPUOps.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 26eec0d4f2082a..2aaa7fd4221ab1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -142,4 +142,36 @@ def XeGPU_FenceScopeAttr:
     let assemblyFormat = "$value";
 }
 
+def XeGPU_SGMapAttr : XeGPUAttr<"SGMap", "sg_map"> {
+  let summary = [{
+    Describes the mapping between work items (WIs) and the 2D tensor specified by the tensor descriptor.
+  }];
+  let description = [{
+    To distribute an XeGPU operation to work items, the tensor_desc must carry the sg_map
+    attribute when the tensor descriptor is created.
+    Within the `sg_map`, `wi_layout` specifies the layout of work items,
+    i.e. how the work items of a subgroup are mapped onto the tensor.
+    wi_layout[0] x wi_layout[1] must equal the total number of work items within a subgroup.
+    `wi_data` specifies the minimum number of data elements assigned to each work item for a single distribution.
+
+    E.g., #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+    In this example, the subgroup has 16 work items in wi_layout=[1, 16],
+    each accessing 1 element as specified by wi_data=[1, 1].
+
+    `wi_data[0] * wi_data[1]` can be greater than 1, meaning that each work item operates on multiple elements,
+    which are eventually lowered to a "SIMT-flavor" vector (e.g. a SPIR-V or LLVM vector) or packed into a storage data type.
+    The multiple elements indicated by `wi_data` can only come from one dimension (either one) and must be contiguous in memory.
+
+    Both `wi_layout` and `wi_data` must have exactly two entries; the attribute verifier rejects any other size.
+  }];
+  let parameters = (ins
+    ArrayRefParameter<"uint32_t">:$wi_layout,
+    ArrayRefParameter<"uint32_t">:$wi_data);
+
+  // Only the default get(context, wi_layout, wi_data) builder is generated. A body-less
+  // AttrBuilder<(ins)> would merely declare a get(context) overload that C++ must define.
+  let hasCustomAssemblyFormat = 1;
+  let genVerifyDecl = 1;
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 0ce1211664b5ba..d09c5c1870d506 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -63,7 +63,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     element-type ::= float-type | integer-type | index-type
     dim-list := (static-dim-list `x`)?
     static-dim-list ::= decimal-literal `x` decimal-literal
-    attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
+    attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)? (, sg_map `<` wi_layout = value, wi_data = value `>`)?
     ```
 
     Examples:
@@ -77,12 +77,16 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
 
     // A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
     xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_space = slm>>
+
+    // A TensorDesc with a sg_map
+    xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
     ```
   }];
 
   let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
                         "mlir::Type": $elementType,
-                        OptionalParameter<"mlir::Attribute">: $encoding);
+                        OptionalParameter<"mlir::Attribute">: $encoding,
+                        OptionalParameter<"mlir::Attribute">: $sg_map);
 
   let builders = [
     TypeBuilderWithInferredContext<(ins
@@ -90,14 +94,16 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       "mlir::Type": $elementType,
       CArg<"int", "1">: $array_length,
       CArg<"bool", "true">: $boundary_check,
-      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>,
+      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
+      CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>,
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>": $shape,
       "mlir::Type": $elementType,
       CArg<"int", "1">: $chunk_size,
-      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>
+      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
+      CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
   ];
-
+  
   let extraClassDeclaration = [{
     using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
@@ -121,6 +127,10 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
     }
 
+    SGMapAttr getSGMapAttr() const {
+      return llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
+    }
+
     xegpu::MemorySpace getMemorySpace() const {
       auto block_attr = getEncodingAsBlockTensorDescAttr();
       if (block_attr && block_attr.getMemorySpace())

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 1dfbaed454c193..eb01b15de75c60 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -55,6 +55,77 @@ ScatterTensorDescAttr::get(mlir::MLIRContext *context,
   return Base::get(context, scopeAttr, chunkSizeAttr);
 }
 
+//===----------------------------------------------------------------------===//
+// XeGPU_SGMapAttr
+//===----------------------------------------------------------------------===//
+namespace {
+/// Parses `<fieldName> = [<int>, <int>, ...]` and appends the integers to
+/// `result`. The AsmParser primitives used below already emit a diagnostic on
+/// failure, so this helper only propagates failures instead of reporting a
+/// second, redundant error.
+template <typename T, unsigned N>
+LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
+                                 llvm::SmallVector<T, N> &result,
+                                 llvm::StringRef fieldName) {
+  // parseKeyword and parseEqual each emit their own "expected ..." error on
+  // mismatch; emitting another one here would duplicate the diagnostic.
+  if (failed(parser.parseKeyword(fieldName)) || failed(parser.parseEqual()))
+    return failure();
+
+  auto elemParser = [&]() -> llvm::ParseResult {
+    T elem = 0;
+    if (failed(parser.parseInteger(elem)))
+      return failure();
+    // Record the element only after it has been parsed successfully.
+    result.push_back(elem);
+    return success();
+  };
+
+  return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
+                                        elemParser, fieldName);
+}
+} // namespace
+
+/// Parses the body of `#xegpu.sg_map<wi_layout = [...], wi_data = [...]>`.
+/// Every component, including the closing '>', is consumed so the accepted
+/// syntax mirrors SGMapAttr::print and trailing tokens are rejected.
+mlir::Attribute SGMapAttr::parse(::mlir::AsmParser &parser,
+                                 ::mlir::Type attrType) {
+  llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
+  if (failed(parser.parseLess()) ||
+      failed(parseIntArrayField(parser, wi_layout, "wi_layout")) ||
+      failed(parser.parseComma()) ||
+      failed(parseIntArrayField(parser, wi_data, "wi_data")) ||
+      failed(parser.parseGreater()))
+    return {};
+
+  // getChecked runs SGMapAttr::verify; on failure it emits a diagnostic
+  // (located at parser.getNameLoc()) and returns a null attribute.
+  return SGMapAttr::getChecked(
+      [&]() { return parser.emitError(parser.getNameLoc()); },
+      parser.getContext(), wi_layout, wi_data);
+}
+
+void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
+  // Textual form: `<wi_layout = [a, b], wi_data = [c, d]>`.
+  printer << "<";
+  printer.printKeywordOrString("wi_layout");
+  printer << " = [" << getWiLayout() << "], ";
+  printer.printKeywordOrString("wi_data");
+  printer << " = [" << getWiData() << "]>";
+}
+
+/// Accepts the attribute only if wi_layout and wi_data both have 2 entries.
+LogicalResult SGMapAttr::verify(
+    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
+    llvm::ArrayRef<uint32_t> wi_layout, llvm::ArrayRef<uint32_t> wi_data) {
+  if (wi_layout.size() == 2 && wi_data.size() == 2)
+    return success();
+  if (wi_layout.size() != 2) // Reported first when both are malformed.
+    return emitError() << "expected wi_layout of size 2";
+  return emitError() << "expected wi_data of size 2";
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_TensorDescType
 //===----------------------------------------------------------------------===//
@@ -63,6 +134,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   llvm::SmallVector<int64_t> shape;
   mlir::Type elementType;
   mlir::FailureOr<mlir::Attribute> encoding;
+  mlir::FailureOr<mlir::Attribute> sg_map;
 
   // Parse literal '<'
   if (parser.parseLess())
@@ -81,14 +153,22 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   }
 
   // parse optional attributes
-  if (mlir::succeeded(parser.parseOptionalComma())) {
-    encoding = mlir::FieldParser<mlir::Attribute>::parse(parser);
-    if (mlir::failed(encoding)) {
-      parser.emitError(
-          parser.getCurrentLocation(),
-          "Failed to parse the attribute field for TensorDescType.\n");
-      return {};
+  while (mlir::succeeded(parser.parseOptionalComma())) {
+    mlir::Attribute attr;
+    ParseResult res = parser.parseAttribute(attr);
+    if (mlir::succeeded(res)) {
+      if (mlir::isa<SGMapAttr>(attr)) {
+        sg_map = attr;
+        continue;
+      }
+      if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
+        encoding = attr;
+        continue;
+      }
     }
+    parser.emitError(parser.getCurrentLocation(),
+                     "Failed to parse the attribute.\n");
+    return {};
   }
 
   // Parse literal '>'
@@ -96,7 +176,8 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
     return {};
 
   return TensorDescType::get(parser.getContext(), shape, elementType,
-                             encoding.value_or(mlir::Attribute()));
+                             encoding.value_or(mlir::Attribute()),
+                             sg_map.value_or(mlir::Attribute()));
 }
 
 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
@@ -116,25 +197,30 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
   if (auto encoding = getEncoding())
     printer << ", " << encoding;
 
+  if (auto sg_map = getSgMap())
+    printer << ", " << sg_map;
+
   printer << ">";
 }
 
 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                    mlir::Type elementType, int array_length,
                                    bool boundary_check,
-                                   MemorySpace memory_space) {
+                                   MemorySpace memory_space,
+                                   mlir::Attribute sg_map) { // null = no sg_map
   auto context = elementType.getContext();
   auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
                                        boundary_check);
-  return Base::get(context, shape, elementType, attr);
+  return Base::get(context, shape, elementType, attr, sg_map); // attr -> encoding
 }
 
 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                    mlir::Type elementType, int chunk_size,
-                                   MemorySpace memory_space) {
+                                   MemorySpace memory_space,
+                                   mlir::Attribute sg_map) { // null = no sg_map
   auto context = elementType.getContext();
   auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
-  return Base::get(context, shape, elementType, attr);
+  return Base::get(context, shape, elementType, attr, sg_map); // attr -> encoding
 }
 
 } // namespace xegpu

diff  --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 6db57aad773aa8..a4587faa3345cb 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -13,6 +13,14 @@ gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+    !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) {
 gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index
@@ -43,6 +51,13 @@ gpu.func @test_create_nd_tdesc_vc_5(%src: memref<2x24x32xf32, 3>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_nd_tdesc_vc_6(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_nd_tdesc_vc_6(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_prefetch_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @test_prefetch_nd_vc(%src: memref<24x32xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -120,6 +135,15 @@ gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_tdesc_vc_with_sg_map(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_tdesc_vc_with_sg_map(%src: ui64) { // scatter_tdesc_attr and sg_map must both round-trip
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>, #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
 gpu.func @test_prefetch_vc(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>


        


More information about the Mlir-commits mailing list