[Mlir-commits] [mlir] [MLIR][GEN] Add GEN dialect (PR #88734)

Victor Perez llvmlistbot at llvm.org
Tue Apr 16 03:00:39 PDT 2024


https://github.com/victor-eds updated https://github.com/llvm/llvm-project/pull/88734

>From 4969ef9cb11ebca8bc3b5e17a481fa62ffe7e101 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 15 Apr 2024 11:08:04 +0100
Subject: [PATCH 1/4] [MLIR][GEN] Add GEN dialect

Add GEN dialect to represent operations on Intel GPUs. GEN will offer
six initial operations:

- `gen.local_id`: query a work-item's local id
- `gen.work_group_id`: query the id of a work-item's work-group
- `gen.work_group_size`: query the size of a work-item's work-group
- `gen.num_work_groups`: query the number of work-groups
- `gen.barrier`: work-group barrier
- `gen.sub_group_shuffle`: sub-group shuffle

Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
 mlir/include/mlir/Dialect/CMakeLists.txt      |   1 +
 mlir/include/mlir/Dialect/GEN/CMakeLists.txt  |   1 +
 .../mlir/Dialect/GEN/IR/CMakeLists.txt        |  13 ++
 .../mlir/Dialect/GEN/IR/GENAttrDefs.td        |  26 ++++
 mlir/include/mlir/Dialect/GEN/IR/GENDialect.h |  20 +++
 .../include/mlir/Dialect/GEN/IR/GENDialect.td |  24 ++++
 mlir/include/mlir/Dialect/GEN/IR/GENOps.h     |  26 ++++
 mlir/include/mlir/Dialect/GEN/IR/GENOps.td    | 133 ++++++++++++++++++
 mlir/include/mlir/Dialect/GEN/IR/GENTraits.h  |  32 +++++
 mlir/include/mlir/InitAllDialects.h           |   2 +
 mlir/lib/Dialect/CMakeLists.txt               |   1 +
 mlir/lib/Dialect/GEN/CMakeLists.txt           |   1 +
 mlir/lib/Dialect/GEN/IR/CMakeLists.txt        |  16 +++
 mlir/lib/Dialect/GEN/IR/GENDialect.cpp        |  34 +++++
 mlir/lib/Dialect/GEN/IR/GENOps.cpp            |  24 ++++
 mlir/lib/Dialect/GEN/IR/GENTraits.cpp         |  24 ++++
 mlir/test/Dialect/GEN/invalid.mlir            |  17 +++
 mlir/test/Dialect/GEN/ops.mlir                |  26 ++++
 18 files changed, 421 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/GEN/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENDialect.h
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENDialect.td
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENOps.h
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENOps.td
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
 create mode 100644 mlir/lib/Dialect/GEN/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/GEN/IR/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/GEN/IR/GENDialect.cpp
 create mode 100644 mlir/lib/Dialect/GEN/IR/GENOps.cpp
 create mode 100644 mlir/lib/Dialect/GEN/IR/GENTraits.cpp
 create mode 100644 mlir/test/Dialect/GEN/invalid.mlir
 create mode 100644 mlir/test/Dialect/GEN/ops.mlir

diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 2da79011fa26a3..a775532996e1e5 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(ControlFlow)
 add_subdirectory(DLTI)
 add_subdirectory(EmitC)
 add_subdirectory(Func)
+add_subdirectory(GEN)
 add_subdirectory(GPU)
 add_subdirectory(Index)
 add_subdirectory(IRDL)
diff --git a/mlir/include/mlir/Dialect/GEN/CMakeLists.txt b/mlir/include/mlir/Dialect/GEN/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..b441e52a143487
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect(GENOps gen)
+add_mlir_doc(GENDialect GENDialect Dialects/ -gen-dialect-doc)
+add_mlir_doc(GENOps GENOps Dialects/ -gen-op-doc)
+
+set(LLVM_TARGET_DEFINITIONS GENOps.td)
+mlir_tablegen(GENOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(GENOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRGENOpsEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS GENAttrDefs.td)
+mlir_tablegen(GENOpsAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(GENOpsAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRGENOpsAttrDefsIncGen)
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td b/mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td
new file mode 100644
index 00000000000000..ef9de1cf581634
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td
@@ -0,0 +1,26 @@
+//===-- GENAttrDefs.td - GEN dialect attributes def. file --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GEN_ATTRDEFS
+#define GEN_ATTRDEFS
+
+include "mlir/IR/EnumAttr.td"
+
+/// Enum attribute of the different shuffle kinds.
+/// Based on https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_non_uniform_instructions
+def GEN_ShflKindAttr : I32EnumAttr<"ShflKind", "GEN shuffle kind",
+  [
+    I32EnumAttrCase<"XOR",  0, "xor">,
+    I32EnumAttrCase<"UP",   1, "up">,
+    I32EnumAttrCase<"DOWN", 2, "down">,
+    I32EnumAttrCase<"IDX",  3, "idx">
+  ]> {
+  let cppNamespace = "::mlir::GEN";
+}
+
+#endif // GEN_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENDialect.h b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.h
new file mode 100644
index 00000000000000..abf08bf229d138
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.h
@@ -0,0 +1,20 @@
+//===- GENDialect.h - MLIR GEN dialect --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the GEN dialect in MLIR, containing Intel GEN operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GEN_IR_GENDIALECT_H
+#define MLIR_DIALECT_GEN_IR_GENDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/GEN/IR/GENOpsDialect.h.inc"
+
+#endif // MLIR_DIALECT_GEN_IR_GENDIALECT_H
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENDialect.td b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.td
new file mode 100644
index 00000000000000..738f6cede89d16
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.td
@@ -0,0 +1,24 @@
+//===-- GENDialect.td - GEN dialect op definition file -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GEN_DIALECT
+#define GEN_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+def GEN_Dialect : Dialect {
+  let name = "gen";
+  let cppNamespace = "::mlir::GEN";
+  let summary = "The GEN dialect.";
+
+  let description = [{
+    GEN is a dialect for representing operations on Intel GPUs.
+  }];
+}
+
+#endif // GEN_DIALECT
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENOps.h b/mlir/include/mlir/Dialect/GEN/IR/GENOps.h
new file mode 100644
index 00000000000000..be81cd111de7e3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENOps.h
@@ -0,0 +1,26 @@
+//===--- GENOps.h - GEN Dialect Operations ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GEN_IR_GENOPS_H
+#define MLIR_DIALECT_GEN_IR_GENOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/GEN/IR/GENOpsEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOpsAttrDefs.h.inc"
+
+#include "mlir/Dialect/GEN/IR/GENTraits.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOps.h.inc"
+
+#endif // MLIR_DIALECT_GEN_IR_GENOPS_H
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENOps.td b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
new file mode 100644
index 00000000000000..afe72fae722bef
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
@@ -0,0 +1,133 @@
+//===-- GENOps.td - GEN IR dialect op definition file ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the GEN IR operation definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GEN_OPS
+#define GEN_OPS
+
+include "mlir/Dialect/GEN/IR/GENDialect.td"
+include "mlir/Dialect/GEN/IR/GENAttrDefs.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+
+//===----------------------------------------------------------------------===//
+// GEN op definitions
+//===----------------------------------------------------------------------===//
+
+class GEN_Op<string mnemonic, list<Trait> traits = []> :
+    Op<GEN_Dialect, mnemonic, traits>;
+
+class GENOpTrait<string name, list<Trait> traits = [],
+                 code extraOpDeclaration = [{}],
+                 code extraOpDefinition = [{}]>
+    : NativeOpTrait<name, traits, extraOpDeclaration, extraOpDefinition> {
+  let cppNamespace = "::mlir::OpTrait::GEN";
+}
+
+//===----------------------------------------------------------------------===//
+// ND-Range Operations
+//===----------------------------------------------------------------------===//
+
+def GEN3DNDRange : GENOpTrait<"GEN3DNDRange">;
+
+class GEN_3DNDRangeOp<string mnemonic, list<Trait> traits = []>
+    : GEN_Op<mnemonic, [GEN3DNDRange, Pure] # traits>,
+      Arguments<(ins I32:$dim)>,
+      Results<(outs Index:$res)> {
+  let assemblyFormat = "$dim attr-dict";
+}
+
+def GEN_LocalIdOp : GEN_3DNDRangeOp<"local_id"> {
+  let summary = "Query a work-item's local id.";
+  let description = [{
+    Query the work-item's position in its work-group, i.e., its local id, in a
+    given dimension.
+    ```mlir
+    %local_id = gen.local_id %dim
+    ```
+  }];
+}
+
+def GEN_WorkGroupIdOp : GEN_3DNDRangeOp<"work_group_id"> {
+  let summary = "Query the id of a work-item's work-group.";
+  let description = [{
+    Query the id of a work-item's work-group in a given dimension.
+    ```mlir
+    %work_group_id = gen.work_group_id %dim
+    ```
+  }];
+}
+
+def GEN_WorkGroupSizeOp : GEN_3DNDRangeOp<"work_group_size"> {
+  let summary = "Query the work-group size.";
+  let description = [{
+    Query the work-item's work-group size in a given dimension.
+    ```mlir
+    %work_group_size = gen.work_group_size %dim
+    ```
+  }];
+}
+
+def GEN_NumWorkGroupsOp : GEN_3DNDRangeOp<"num_work_groups"> {
+  let summary = "Query the number of work-groups in the ND-range.";
+  let description = [{
+    Query the number of work-groups in the ND-range in a given dimension.
+    ```mlir
+    %wg_number = gen.num_work_groups %dim
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Synchronization
+//===----------------------------------------------------------------------===//
+
+def GEN_BarrierOp : GEN_Op<"barrier"> {
+  let summary = "Work-group barrier";
+
+  string baseDescription = [{
+    The `gen.barrier` operation performs a work-group barrier and ensures all
+    outstanding memory transaction using local or global memory are complete.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def IntegerOrFloatType : AnyTypeOf<[AnySignlessInteger, AnyFloat]>;
+
+def GEN_SubGroupShuffleOp : GEN_Op<"sub_group_shuffle", [
+      TypesMatchWith<"result and value have the same type",
+                     "res", "value", "$_self">]>,
+  Results<(outs IntegerOrFloatType:$res)>,
+  Arguments<(ins IntegerOrFloatType:$value,
+                 I32:$mask,
+                 GEN_ShflKindAttr:$kind)> {
+  let summary = "Sub-group shuffle";
+  string baseDescription = [{
+    The `gen.sub_group_shuffle` operation is invoked by different work items
+    with different values, given by $value. Different work items have different
+    sub-group local IDs. The shuffle kind, $kind, is given to determine how to
+    calculate the associated sub-group local ID. It returns the associated
+    $value for the work item with sub-group local ID equal to:
+    - $kind == xor, the current invocation’s sub-group local ID xor’ed with $mask.
+    - $kind == up, the current invocation’s sub-group local ID - $mask.
+    - $kind == down, the current invocation’s sub-group local ID + $mask.
+    - $kind == idx, the sub-group local ID $mask.
+  }];
+
+  let assemblyFormat = [{
+    $kind $value `,` $mask attr-dict `:` type($res)
+  }];
+}
+
+#endif // GEN_OPS
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h b/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
new file mode 100644
index 00000000000000..11ae8014f8ad7f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
@@ -0,0 +1,32 @@
+//===--- GENTraits.h - GEN Dialect Traits -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GEN_IR_GENTRAITS_H
+#define MLIR_DIALECT_GEN_IR_GENTRAITS_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace GEN {
+namespace detail {
+LogicalResult verifyGEN3DNDRange(Operation *op);
+} // namespace detail
+
+template <typename ConcreteType>
+class GEN3DNDRange : public TraitBase<ConcreteType, GEN3DNDRange> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return detail::verifyGEN3DNDRange(op);
+  }
+};
+} // namespace GEN
+} // namespace OpTrait
+} // namespace mlir
+
+#endif // MLIR_DIALECT_GEN_IR_GENTRAITS_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c558dc53cc7fac..a940d134024165 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -36,6 +36,7 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
@@ -116,6 +117,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   DLTIDialect,
                   emitc::EmitCDialect,
                   func::FuncDialect,
+                  GEN::GENDialect,
                   gpu::GPUDialect,
                   index::IndexDialect,
                   irdl::IRDLDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index b1ba5a3bc8817d..72e11b4eb33dfa 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(ControlFlow)
 add_subdirectory(DLTI)
 add_subdirectory(EmitC)
 add_subdirectory(Func)
+add_subdirectory(GEN)
 add_subdirectory(GPU)
 add_subdirectory(Index)
 add_subdirectory(IRDL)
diff --git a/mlir/lib/Dialect/GEN/CMakeLists.txt b/mlir/lib/Dialect/GEN/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/GEN/IR/CMakeLists.txt b/mlir/lib/Dialect/GEN/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..8b402987d5008f
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(MLIRGENDialect
+  GENDialect.cpp
+  GENOps.cpp
+  GENTraits.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GEN
+
+  DEPENDS
+  MLIRGENOpsIncGen
+  MLIRGENOpsEnumsIncGen
+  MLIRGENOpsAttrDefsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+)
diff --git a/mlir/lib/Dialect/GEN/IR/GENDialect.cpp b/mlir/lib/Dialect/GEN/IR/GENDialect.cpp
new file mode 100644
index 00000000000000..801bd499a4256c
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/GENDialect.cpp
@@ -0,0 +1,34 @@
+//===- GENDialect.cpp - MLIR GEN Dialect implementation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
+
+#include "mlir/Dialect/GEN/IR/GENOps.h"
+#include "mlir/IR/DialectImplementation.h"
+
+using namespace mlir;
+using namespace mlir::GEN;
+
+#include "mlir/Dialect/GEN/IR/GENOpsDialect.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOpsAttrDefs.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// GEN dialect.
+//===----------------------------------------------------------------------===//
+
+void GENDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/GEN/IR/GENOps.cpp.inc"
+      >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/GEN/IR/GENOpsAttrDefs.cpp.inc"
+      >();
+}
diff --git a/mlir/lib/Dialect/GEN/IR/GENOps.cpp b/mlir/lib/Dialect/GEN/IR/GENOps.cpp
new file mode 100644
index 00000000000000..da8a5075cc56fb
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/GENOps.cpp
@@ -0,0 +1,24 @@
+//===- GENOps.cpp - GEN dialect operations --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GEN/IR/GENOps.h"
+
+#include "mlir/IR/Builders.h"
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// TableGen'd enum attribute definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GEN/IR/GENOpsEnums.cpp.inc"
diff --git a/mlir/lib/Dialect/GEN/IR/GENTraits.cpp b/mlir/lib/Dialect/GEN/IR/GENTraits.cpp
new file mode 100644
index 00000000000000..6115aabff289c1
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/GENTraits.cpp
@@ -0,0 +1,24 @@
+//===- GENTraits.cpp - GEN dialect traits ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GEN/IR/GENTraits.h"
+
+#include "mlir/IR/Matchers.h"
+
+using namespace mlir;
+
+LogicalResult mlir::OpTrait::GEN::detail::verifyGEN3DNDRange(Operation *op) {
+  llvm::APInt value;
+  if (matchPattern(op->getOperand(0), m_ConstantInt(&value)) &&
+      !(/*value in [0, 3)*/ value.sge(0) && value.slt(3))) {
+    return op->emitOpError()
+           << "input dimension must be in the range [0, 3). Got "
+           << value.getSExtValue();
+  }
+  return success();
+}
diff --git a/mlir/test/Dialect/GEN/invalid.mlir b/mlir/test/Dialect/GEN/invalid.mlir
new file mode 100644
index 00000000000000..1285f8af216647
--- /dev/null
+++ b/mlir/test/Dialect/GEN/invalid.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -split-input-file %s -verify-diagnostics
+
+func.func @test_3d_nd_range_bounds_low() {
+  %c-1 = arith.constant -1 : i32
+  // expected-error @below {{'gen.local_id' op input dimension must be in the range [0, 3). Got -1}}
+  %0 = gen.local_id %c-1
+  func.return
+}
+
+// -----
+
+func.func @test_3d_nd_range_bounds_high() {
+  %c3 = arith.constant 3 : i32
+  // expected-error @below {{'gen.work_group_id' op input dimension must be in the range [0, 3). Got 3}}
+  %0 = gen.work_group_id %c3
+  func.return
+}
diff --git a/mlir/test/Dialect/GEN/ops.mlir b/mlir/test/Dialect/GEN/ops.mlir
new file mode 100644
index 00000000000000..a5a1191cd10698
--- /dev/null
+++ b/mlir/test/Dialect/GEN/ops.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: test_nd_range
+func.func @test_nd_range(%dim: i32) {
+  %0 = gen.local_id %dim
+  %1 = gen.work_group_id %dim
+  %2 = gen.work_group_size %dim
+  %3 = gen.num_work_groups %dim
+  return
+}
+
+// CHECK-LABEL: test_barrier
+func.func @test_barrier() {
+  gen.barrier
+  return
+}
+
+// CHECK-LABEL: test_sub_group_shuffle
+func.func @test_sub_group_shuffle(%arg0: i32, %arg1: i64, %arg2: f32, %arg3: f64, %arg4: i32) {
+  %0 = gen.sub_group_shuffle xor %arg0, %arg4 : i32
+  %1 = gen.sub_group_shuffle up %arg1, %arg4 : i64
+  %2 = gen.sub_group_shuffle down %arg2, %arg4 : f32
+  %3 = gen.sub_group_shuffle idx %arg3, %arg4 : f64
+  return
+}

>From aafd18d4ec8dc9eff5482a40a6ae4e51d4cd83f2 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 15 Apr 2024 15:29:50 +0100
Subject: [PATCH 2/4] Use `description`

---
 mlir/include/mlir/Dialect/GEN/IR/GENOps.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENOps.td b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
index afe72fae722bef..ed097e6cdc0b23 100644
--- a/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
@@ -95,7 +95,7 @@ def GEN_NumWorkGroupsOp : GEN_3DNDRangeOp<"num_work_groups"> {
 def GEN_BarrierOp : GEN_Op<"barrier"> {
   let summary = "Work-group barrier";
 
-  string baseDescription = [{
+  let description = [{
     The `gen.barrier` operation performs a work-group barrier and ensures all
     outstanding memory transaction using local or global memory are complete.
   }];
@@ -113,7 +113,7 @@ def GEN_SubGroupShuffleOp : GEN_Op<"sub_group_shuffle", [
                  I32:$mask,
                  GEN_ShflKindAttr:$kind)> {
   let summary = "Sub-group shuffle";
-  string baseDescription = [{
+  let description = [{
     The `gen.sub_group_shuffle` operation is invoked by different work items
     with different values, given by $value. Different work items have different
     sub-group local IDs. The shuffle kind, $kind, is given to determine how to

>From fc8f883adbfc09832a61af609c1fb329744f774b Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 15 Apr 2024 15:37:53 +0100
Subject: [PATCH 3/4] Modify `gen.sub_group_shuffle` traits list

---
 mlir/include/mlir/Dialect/GEN/IR/GENOps.td | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENOps.td b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
index ed097e6cdc0b23..c452d9f4627276 100644
--- a/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
@@ -105,9 +105,8 @@ def GEN_BarrierOp : GEN_Op<"barrier"> {
 
 def IntegerOrFloatType : AnyTypeOf<[AnySignlessInteger, AnyFloat]>;
 
-def GEN_SubGroupShuffleOp : GEN_Op<"sub_group_shuffle", [
-      TypesMatchWith<"result and value have the same type",
-                     "res", "value", "$_self">]>,
+def GEN_SubGroupShuffleOp
+    : GEN_Op<"sub_group_shuffle", [Pure, AllTypesMatch<["res", "value"]>]>,
   Results<(outs IntegerOrFloatType:$res)>,
   Arguments<(ins IntegerOrFloatType:$value,
                  I32:$mask,

>From 9d5d34de90c36c05a3c8d11fd334f8384052b15d Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Tue, 16 Apr 2024 11:00:20 +0100
Subject: [PATCH 4/4] Address comments

---
 mlir/include/mlir/Dialect/GEN/IR/GENOps.h    |  2 -
 mlir/include/mlir/Dialect/GEN/IR/GENOps.td   | 98 +++++++++++++-------
 mlir/include/mlir/Dialect/GEN/IR/GENTraits.h | 32 -------
 mlir/lib/Dialect/GEN/IR/CMakeLists.txt       |  1 -
 mlir/lib/Dialect/GEN/IR/GENTraits.cpp        | 24 -----
 mlir/test/Dialect/GEN/invalid.mlir           | 17 ----
 6 files changed, 66 insertions(+), 108 deletions(-)
 delete mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
 delete mode 100644 mlir/lib/Dialect/GEN/IR/GENTraits.cpp
 delete mode 100644 mlir/test/Dialect/GEN/invalid.mlir

diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENOps.h b/mlir/include/mlir/Dialect/GEN/IR/GENOps.h
index be81cd111de7e3..c109e436949550 100644
--- a/mlir/include/mlir/Dialect/GEN/IR/GENOps.h
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENOps.h
@@ -18,8 +18,6 @@
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/GEN/IR/GENOpsAttrDefs.h.inc"
 
-#include "mlir/Dialect/GEN/IR/GENTraits.h"
-
 #define GET_OP_CLASSES
 #include "mlir/Dialect/GEN/IR/GENOps.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENOps.td b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
index c452d9f4627276..92433b1699cf04 100644
--- a/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
@@ -27,23 +27,15 @@ include "mlir/IR/OpAsmInterface.td"
 class GEN_Op<string mnemonic, list<Trait> traits = []> :
     Op<GEN_Dialect, mnemonic, traits>;
 
-class GENOpTrait<string name, list<Trait> traits = [],
-                 code extraOpDeclaration = [{}],
-                 code extraOpDefinition = [{}]>
-    : NativeOpTrait<name, traits, extraOpDeclaration, extraOpDefinition> {
-  let cppNamespace = "::mlir::OpTrait::GEN";
-}
-
 //===----------------------------------------------------------------------===//
 // ND-Range Operations
 //===----------------------------------------------------------------------===//
 
-def GEN3DNDRange : GENOpTrait<"GEN3DNDRange">;
-
 class GEN_3DNDRangeOp<string mnemonic, list<Trait> traits = []>
-    : GEN_Op<mnemonic, [GEN3DNDRange, Pure] # traits>,
-      Arguments<(ins I32:$dim)>,
-      Results<(outs Index:$res)> {
+    : GEN_Op<mnemonic, [Pure] # traits> {
+  let arguments = (ins I32:$dim);
+  let results = (outs Index:$res);
+
   let assemblyFormat = "$dim attr-dict";
 }
 
@@ -51,7 +43,10 @@ def GEN_LocalIdOp : GEN_3DNDRangeOp<"local_id"> {
   let summary = "Query a work-item's local id.";
   let description = [{
     Query the work-item's position in its work-group, i.e., its local id, in a
-    given dimension.
+    given dimension `dim`, which must be either 0, 1 or 2. Behavior is undefined
+    for invalid `dim` values.
+
+    Example:
     ```mlir
     %local_id = gen.local_id %dim
     ```
@@ -61,7 +56,10 @@ def GEN_LocalIdOp : GEN_3DNDRangeOp<"local_id"> {
 def GEN_WorkGroupIdOp : GEN_3DNDRangeOp<"work_group_id"> {
   let summary = "Query the id of a work-item's work-group.";
   let description = [{
-    Query the id of a work-item's work-group in a given dimension.
+    Query the id of a work-item's work-group in a given dimension `dim`, which
+    must be either 0, 1 or 2. Behavior is undefined for invalid `dim` values.
+
+    Example:
     ```mlir
     %work_group_id = gen.work_group_id %dim
     ```
@@ -71,7 +69,10 @@ def GEN_WorkGroupIdOp : GEN_3DNDRangeOp<"work_group_id"> {
 def GEN_WorkGroupSizeOp : GEN_3DNDRangeOp<"work_group_size"> {
   let summary = "Query the work-group size.";
   let description = [{
-    Query the work-item's work-group size in a given dimension.
+    Query the work-item's work-group size in a given dimension `dim`, which must
+    be either 0, 1 or 2. Behavior is undefined for invalid `dim` values.
+
+    Example:
     ```mlir
     %work_group_size = gen.work_group_size %dim
     ```
@@ -81,7 +82,11 @@ def GEN_WorkGroupSizeOp : GEN_3DNDRangeOp<"work_group_size"> {
 def GEN_NumWorkGroupsOp : GEN_3DNDRangeOp<"num_work_groups"> {
   let summary = "Query the number of work-groups in the ND-range.";
   let description = [{
-    Query the number of work-groups in the ND-range in a given dimension.
+    Query the number of work-groups in the ND-range in a given dimension `dim`,
+    which must be either 0, 1 or 2. Behavior is undefined for invalid `dim`
+    values.
+
+    Example:
     ```mlir
     %wg_number = gen.num_work_groups %dim
     ```
@@ -93,37 +98,66 @@ def GEN_NumWorkGroupsOp : GEN_3DNDRangeOp<"num_work_groups"> {
 //===----------------------------------------------------------------------===//
 
 def GEN_BarrierOp : GEN_Op<"barrier"> {
-  let summary = "Work-group barrier";
+  let summary = "Synchronizes all work-items of a work-group.";
 
   let description = [{
-    The `gen.barrier` operation performs a work-group barrier and ensures all
-    outstanding memory transaction using local or global memory are complete.
+    Wait for all work-items in a given work-group to reach this execution point.
+    All memory accesses made by these work-items prior to the operation are
+    visible to all work-items in the work-group.
+
+    Behavior is undefined unless either all or none of the work-items in the
+    work-group reach this execution point.
+
+    Example:
+    ```mlir
+    gen.barrier
+    ```
   }];
 
   let assemblyFormat = "attr-dict";
 }
 
-def IntegerOrFloatType : AnyTypeOf<[AnySignlessInteger, AnyFloat]>;
+def ShuffleValueType
+    : AnyTypeOf<[SignlessIntOfWidths<[8, 16, 32, 64]>, FloatOfWidths<[16, 32, 64]>]>;
 
 def GEN_SubGroupShuffleOp
-    : GEN_Op<"sub_group_shuffle", [Pure, AllTypesMatch<["res", "value"]>]>,
-  Results<(outs IntegerOrFloatType:$res)>,
-  Arguments<(ins IntegerOrFloatType:$value,
-                 I32:$mask,
-                 GEN_ShflKindAttr:$kind)> {
+    : GEN_Op<"sub_group_shuffle", [Pure, AllTypesMatch<["res", "value"]>]> {
   let summary = "Sub-group shuffle";
   let description = [{
     The `gen.sub_group_shuffle` operation is invoked by different work items
-    with different values, given by $value. Different work items have different
-    sub-group local IDs. The shuffle kind, $kind, is given to determine how to
+    with different values, given by `value`. Different work items have different
+    sub-group local IDs. The shuffle kind, `kind`, is given to determine how to
     calculate the associated sub-group local ID. It returns the associated
-    $value for the work item with sub-group local ID equal to:
-    - $kind == xor, the current invocation’s sub-group local ID xor’ed with $mask.
-    - $kind == up, the current invocation’s sub-group local ID - $mask.
-    - $kind == down, the current invocation’s sub-group local ID + $mask.
-    - $kind == idx, the sub-group local ID $mask.
+    `value` for the work item with sub-group local ID equal to:
+    - `kind` == xor, the current invocation's sub-group local ID xor'ed with `mask`.
+    - `kind` == up, the current invocation's sub-group local ID - `mask`.
+    - `kind` == down, the current invocation's sub-group local ID + `mask`.
+    - `kind` == idx, the sub-group local ID `mask`.
+
+    `value` and `res` types must match and can be any of: `i8`, `i16`, `i32`,
+    `i64`, `f16`, `f32` or `f64`.
+
+    Example:
+    ```mlir
+    // xor shuffle
+    %0 = gen.sub_group_shuffle xor %arg0, %arg4 : i32
+
+    // up shuffle
+    %1 = gen.sub_group_shuffle up %arg1, %arg4 : i64
+
+    // down shuffle
+    %2 = gen.sub_group_shuffle down %arg2, %arg4 : f32
+
+    // idx shuffle
+    %3 = gen.sub_group_shuffle idx %arg3, %arg4 : f64
+    ```
   }];
 
+  let arguments = (ins ShuffleValueType:$value,
+                       I32:$mask,
+                       GEN_ShflKindAttr:$kind);
+  let results = (outs ShuffleValueType:$res);
+
   let assemblyFormat = [{
     $kind $value `,` $mask attr-dict `:` type($res)
   }];
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h b/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
deleted file mode 100644
index 11ae8014f8ad7f..00000000000000
--- a/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
+++ /dev/null
@@ -1,32 +0,0 @@
-//===--- GENTraits.h - GEN Dialect Traits -----------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_GEN_IR_GENTRAITS_H
-#define MLIR_DIALECT_GEN_IR_GENTRAITS_H
-
-#include "mlir/IR/OpDefinition.h"
-
-namespace mlir {
-namespace OpTrait {
-namespace GEN {
-namespace detail {
-LogicalResult verifyGEN3DNDRange(Operation *op);
-} // namespace detail
-
-template <typename ConcreteType>
-class GEN3DNDRange : public TraitBase<ConcreteType, GEN3DNDRange> {
-public:
-  static LogicalResult verifyTrait(Operation *op) {
-    return detail::verifyGEN3DNDRange(op);
-  }
-};
-} // namespace GEN
-} // namespace OpTrait
-} // namespace mlir
-
-#endif // MLIR_DIALECT_GEN_IR_GENTRAITS_H
diff --git a/mlir/lib/Dialect/GEN/IR/CMakeLists.txt b/mlir/lib/Dialect/GEN/IR/CMakeLists.txt
index 8b402987d5008f..40369d548823ae 100644
--- a/mlir/lib/Dialect/GEN/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/GEN/IR/CMakeLists.txt
@@ -1,7 +1,6 @@
 add_mlir_dialect_library(MLIRGENDialect
   GENDialect.cpp
   GENOps.cpp
-  GENTraits.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GEN
diff --git a/mlir/lib/Dialect/GEN/IR/GENTraits.cpp b/mlir/lib/Dialect/GEN/IR/GENTraits.cpp
deleted file mode 100644
index 6115aabff289c1..00000000000000
--- a/mlir/lib/Dialect/GEN/IR/GENTraits.cpp
+++ /dev/null
@@ -1,24 +0,0 @@
-//===- GENTraits.cpp - GEN dialect traits ---------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/GEN/IR/GENTraits.h"
-
-#include "mlir/IR/Matchers.h"
-
-using namespace mlir;
-
-LogicalResult mlir::OpTrait::GEN::detail::verifyGEN3DNDRange(Operation *op) {
-  llvm::APInt value;
-  if (matchPattern(op->getOperand(0), m_ConstantInt(&value)) &&
-      !(/*value in [0, 3)*/ value.sge(0) && value.slt(3))) {
-    return op->emitOpError()
-           << "input dimension must be in the range [0, 3). Got "
-           << value.getSExtValue();
-  }
-  return success();
-}
diff --git a/mlir/test/Dialect/GEN/invalid.mlir b/mlir/test/Dialect/GEN/invalid.mlir
deleted file mode 100644
index 1285f8af216647..00000000000000
--- a/mlir/test/Dialect/GEN/invalid.mlir
+++ /dev/null
@@ -1,17 +0,0 @@
-// RUN: mlir-opt -split-input-file %s -verify-diagnostics
-
-func.func @test_3d_nd_range_bounds_low() {
-  %c-1 = arith.constant -1 : i32
-  // expected-error @below {{'gen.local_id' op input dimension must be in the range [0, 3). Got -1}}
-  %0 = gen.local_id %c-1
-  func.return
-}
-
-// -----
-
-func.func @test_3d_nd_range_bounds_high() {
-  %c3 = arith.constant 3 : i32
-  // expected-error @below {{'gen.work_group_id' op input dimension must be in the range [0, 3). Got 3}}
-  %0 = gen.work_group_id %c3
-  func.return
-}



More information about the Mlir-commits mailing list