[Mlir-commits] [mlir] [MLIR][GEN] Define GEN dialect (PR #87757)

Victor Perez llvmlistbot at llvm.org
Fri Apr 5 01:50:51 PDT 2024


https://github.com/victor-eds created https://github.com/llvm/llvm-project/pull/87757

Define GEN dialect with the following operations:

- `gen.local_id` to query a work-item's local id
- `gen.work_group_id` to query the work-group's id
- `gen.work_group_size` to query the work-group's size
- `gen.num_work_groups` to query the number of work-groups
- `gen.barrier` to insert a barrier
- `gen.sub_group_shuffle` for different subgroup shuffles

Also define two conversion passes:

- `-convert-gen-to-spirv`
- `-convert-gen-to-llvm`

>From 0c5aefcf204cd6108bd1867081e05ee4db1dd5b0 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Fri, 5 Apr 2024 09:42:36 +0100
Subject: [PATCH] [MLIR][GEN] Define GEN dialect

Define GEN dialect with the following operations:

- `gen.local_id` to query a work-item's local id
- `gen.work_group_id` to query the work-group's id
- `gen.work_group_size` to query the work-group's size
- `gen.num_work_groups` to query the number of work-groups
- `gen.barrier` to insert a barrier
- `gen.sub_group_shuffle` for different subgroup shuffles

Also define two conversion passes:

- `-convert-gen-to-spirv`
- `-convert-gen-to-llvm`

Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
 .../mlir/Conversion/GENToLLVM/GENToLLVMPass.h |  32 ++
 .../mlir/Conversion/GENToSPIRV/GENToSPIRV.h   |  31 ++
 mlir/include/mlir/Conversion/Passes.h         |   2 +
 mlir/include/mlir/Conversion/Passes.td        |  23 ++
 mlir/include/mlir/Dialect/CMakeLists.txt      |   1 +
 mlir/include/mlir/Dialect/GEN/CMakeLists.txt  |   1 +
 .../mlir/Dialect/GEN/IR/CMakeLists.txt        |  13 +
 .../mlir/Dialect/GEN/IR/GENAttrDefs.td        |  26 ++
 mlir/include/mlir/Dialect/GEN/IR/GENDialect.h |  38 +++
 .../include/mlir/Dialect/GEN/IR/GENDialect.td |  44 +++
 mlir/include/mlir/Dialect/GEN/IR/GENOps.h     |  26 ++
 mlir/include/mlir/Dialect/GEN/IR/GENOps.td    | 130 ++++++++
 mlir/include/mlir/Dialect/GEN/IR/GENTraits.h  |  30 ++
 mlir/include/mlir/InitAllDialects.h           |   2 +
 mlir/lib/Conversion/CMakeLists.txt            |   2 +
 mlir/lib/Conversion/GENToLLVM/CMakeLists.txt  |  14 +
 .../Conversion/GENToLLVM/GENToLLVMPass.cpp    | 280 ++++++++++++++++++
 mlir/lib/Conversion/GENToSPIRV/CMakeLists.txt |  17 ++
 mlir/lib/Conversion/GENToSPIRV/GENToSPIRV.cpp | 131 ++++++++
 mlir/lib/Dialect/CMakeLists.txt               |   1 +
 mlir/lib/Dialect/GEN/CMakeLists.txt           |   1 +
 mlir/lib/Dialect/GEN/IR/CMakeLists.txt        |  17 ++
 mlir/lib/Dialect/GEN/IR/GENDialect.cpp        |  42 +++
 mlir/lib/Dialect/GEN/IR/GENOps.cpp            |  17 ++
 mlir/lib/Dialect/GEN/IR/GENTraits.cpp         |  24 ++
 .../Conversion/GENToLLVM/gen-to-llvm.mlir     |  85 ++++++
 .../Conversion/GENToSPIRV/gen-to-spirv.mlir   |  28 ++
 mlir/test/Dialect/GEN/gen.mlir                |  54 ++++
 mlir/test/Dialect/GEN/invalid.mlir            |  17 ++
 29 files changed, 1129 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/GENToLLVM/GENToLLVMPass.h
 create mode 100644 mlir/include/mlir/Conversion/GENToSPIRV/GENToSPIRV.h
 create mode 100644 mlir/include/mlir/Dialect/GEN/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENDialect.h
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENDialect.td
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENOps.h
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENOps.td
 create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
 create mode 100644 mlir/lib/Conversion/GENToLLVM/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/GENToLLVM/GENToLLVMPass.cpp
 create mode 100644 mlir/lib/Conversion/GENToSPIRV/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/GENToSPIRV/GENToSPIRV.cpp
 create mode 100644 mlir/lib/Dialect/GEN/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/GEN/IR/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/GEN/IR/GENDialect.cpp
 create mode 100644 mlir/lib/Dialect/GEN/IR/GENOps.cpp
 create mode 100644 mlir/lib/Dialect/GEN/IR/GENTraits.cpp
 create mode 100644 mlir/test/Conversion/GENToLLVM/gen-to-llvm.mlir
 create mode 100644 mlir/test/Conversion/GENToSPIRV/gen-to-spirv.mlir
 create mode 100644 mlir/test/Dialect/GEN/gen.mlir
 create mode 100644 mlir/test/Dialect/GEN/invalid.mlir

diff --git a/mlir/include/mlir/Conversion/GENToLLVM/GENToLLVMPass.h b/mlir/include/mlir/Conversion/GENToLLVM/GENToLLVMPass.h
new file mode 100644
index 00000000000000..1788a10a7843fb
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GENToLLVM/GENToLLVMPass.h
@@ -0,0 +1,32 @@
+//===- GENToLLVMPass.h - GEN to LLVM dialect conversion ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GENTOLLVM_GENTOLLVMPASS_H
+#define MLIR_CONVERSION_GENTOLLVM_GENTOLLVMPASS_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+
+// Forward declarations of the types used by the API below.
+class Pass;
+class RewritePatternSet;
+class LLVMTypeConverter;
+
+// Tablegen-generated declarations for the `convert-gen-to-llvm` pass.
+#define GEN_PASS_DECL_CONVERTGENTOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace GEN {
+
+/// Populates \p patterns with the GEN-to-LLVM conversion patterns, using
+/// \p converter for type conversion.
+void populateGENToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         RewritePatternSet &patterns);
+
+/// Creates an instance of the `convert-gen-to-llvm` pass.
+std::unique_ptr<Pass> createConvertGENToLLVM();
+
+} // namespace GEN
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GENTOLLVM_GENTOLLVMPASS_H
diff --git a/mlir/include/mlir/Conversion/GENToSPIRV/GENToSPIRV.h b/mlir/include/mlir/Conversion/GENToSPIRV/GENToSPIRV.h
new file mode 100644
index 00000000000000..224d2dbc3ca3f8
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GENToSPIRV/GENToSPIRV.h
@@ -0,0 +1,31 @@
+//===- GENToSPIRV.h - Convert GEN to SPIRV dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GENTOSPIRV_GENTOSPIRV_H
+#define MLIR_CONVERSION_GENTOSPIRV_GENTOSPIRV_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+
+// Forward declarations of the types used by the API below.
+class RewritePatternSet;
+class SPIRVTypeConverter;
+
+// Tablegen-generated declarations for the `convert-gen-to-spirv` pass.
+#define GEN_PASS_DECL_CONVERTGENTOSPIRV
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace GEN {
+
+/// Populates \p patterns with the GEN-to-SPIR-V conversion patterns, using
+/// \p typeConverter for type conversion.
+void populateGENToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+                                RewritePatternSet &patterns);
+
+/// Creates an instance of the `convert-gen-to-spirv` pass.
+std::unique_ptr<OperationPass<>> createConvertGENToSPIRVPass();
+
+} // namespace GEN
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GENTOSPIRV_GENTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 2179ae18ac074b..8bfab7bbfa9a1b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -33,6 +33,8 @@
 #include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
+#include "mlir/Conversion/GENToLLVM/GENToLLVMPass.h"
+#include "mlir/Conversion/GENToSPIRV/GENToSPIRV.h"
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab95..1f7728b9105570 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -453,6 +453,29 @@ def ConvertFuncToSPIRV : Pass<"convert-func-to-spirv"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// GENToLLVM
+//===----------------------------------------------------------------------===//
+
+// Anchored on ModuleOp: the lowering declares the OpenCL builtins it calls as
+// functions at module scope.
+def ConvertGENToLLVM : Pass<"convert-gen-to-llvm", "ModuleOp"> {
+  let constructor = "mlir::GEN::createConvertGENToLLVM()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
+
+  let summary = "Convert the GEN dialect to the LLVM dialect";
+  let description = [{
+    This pass converts GEN dialect operations to LLVM dialect operations.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// GENToSPIRV
+//===----------------------------------------------------------------------===//
+
+def ConvertGENToSPIRV : Pass<"convert-gen-to-spirv"> {
+  let summary = "Convert GEN dialect to SPIR-V dialect";
+  // Provide a description like the sibling GEN conversion pass so the
+  // generated pass documentation is not empty.
+  let description = [{
+    This pass converts GEN dialect operations to SPIR-V dialect operations.
+  }];
+  let constructor = "mlir::GEN::createConvertGENToSPIRVPass()";
+  let dependentDialects = ["spirv::SPIRVDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // GPUCommon
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 2da79011fa26a3..a775532996e1e5 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(ControlFlow)
 add_subdirectory(DLTI)
 add_subdirectory(EmitC)
 add_subdirectory(Func)
+add_subdirectory(GEN)
 add_subdirectory(GPU)
 add_subdirectory(Index)
 add_subdirectory(IRDL)
diff --git a/mlir/include/mlir/Dialect/GEN/CMakeLists.txt b/mlir/include/mlir/Dialect/GEN/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..b441e52a143487
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Dialect and operation classes generated from GENOps.td (e.g. GENOps.h.inc,
+# GENOpsDialect.h.inc), plus the documentation for the dialect and its ops.
+add_mlir_dialect(GENOps gen)
+add_mlir_doc(GENDialect GENDialect Dialects/ -gen-dialect-doc)
+add_mlir_doc(GENOps GENOps Dialects/ -gen-op-doc)
+
+# Enum declarations/definitions (e.g. ShflKind). GENOps.td includes
+# GENAttrDefs.td, where the enum is defined. Each mlir_tablegen call reads the
+# current LLVM_TARGET_DEFINITIONS, so the order of these statements matters.
+set(LLVM_TARGET_DEFINITIONS GENOps.td)
+mlir_tablegen(GENOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(GENOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRGENOpsEnumsIncGen)
+
+# Attribute definitions from GENAttrDefs.td.
+set(LLVM_TARGET_DEFINITIONS GENAttrDefs.td)
+mlir_tablegen(GENOpsAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(GENOpsAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRGENOpsAttrDefsIncGen)
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td b/mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td
new file mode 100644
index 00000000000000..ef9de1cf581634
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td
@@ -0,0 +1,26 @@
+//===-- GENAttrDefs.td - GEN dialect attributes def. file --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GEN_ATTRDEFS
+#define GEN_ATTRDEFS
+
+include "mlir/IR/EnumAttr.td"
+
+// Individual shuffle kinds: C++ enumerator name, stored value, string form.
+def GEN_ShflKindXOR  : I32EnumAttrCase<"XOR",  0, "xor">;
+def GEN_ShflKindUP   : I32EnumAttrCase<"UP",   1, "up">;
+def GEN_ShflKindDOWN : I32EnumAttrCase<"DOWN", 2, "down">;
+def GEN_ShflKindIDX  : I32EnumAttrCase<"IDX",  3, "idx">;
+
+/// Enum attribute of the different shuffle kinds.
+/// Based on https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_non_uniform_instructions
+def GEN_ShflKindAttr : I32EnumAttr<"ShflKind", "GEN shuffle kind",
+    [GEN_ShflKindXOR, GEN_ShflKindUP, GEN_ShflKindDOWN, GEN_ShflKindIDX]> {
+  let cppNamespace = "::mlir::GEN";
+}
+
+#endif // GEN_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENDialect.h b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.h
new file mode 100644
index 00000000000000..eee2e7e8b440ad
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.h
@@ -0,0 +1,38 @@
+//===- GENDialect.h - MLIR GEN dialect --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the GEN dialect in MLIR, containing Intel GEN operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GEN_IR_GENDIALECT_H
+#define MLIR_DIALECT_GEN_IR_GENDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/GEN/IR/GENOpsDialect.h.inc"
+
+namespace mlir::GEN {
+
+/// Memory space identifiers of the GEN dialect. The numbering follows the
+/// SPIR-V storage class convention described in
+/// https://github.com/KhronosGroup/SPIRV-LLVM-Translator/blob/main/docs/SPIRVRepresentationInLLVM.rst#address-spaces
+enum class GENStorageClass : int {
+  Function = 0,        ///< OpenCL work-item address space.
+  CrossWorkgroup = 1,  ///< OpenCL global memory.
+  UniformConstant = 2, ///< OpenCL constant memory.
+  Workgroup = 3,       ///< OpenCL local memory.
+  Generic = 4,         ///< OpenCL generic memory.
+};
+
+} // namespace mlir::GEN
+
+#endif // MLIR_DIALECT_GEN_IR_GENDIALECT_H
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENDialect.td b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.td
new file mode 100644
index 00000000000000..3ba6248decb7a4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.td
@@ -0,0 +1,44 @@
+//===-- GENDialect.td - GEN dialect op definition file -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GEN_DIALECT
+#define GEN_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+def GEN_Dialect : Dialect {
+  let name = "gen";
+  let summary = "The GEN dialect.";
+  let description = [{
+    GEN is a dialect for representing operations on Intel GPUs.
+  }];
+  let cppNamespace = "::mlir::GEN";
+
+  // Static accessors for the names of the attributes used to annotate kernel
+  // work-group and sub-group size requirements.
+  let extraClassDeclaration = [{
+    /// Returns the name of the attribute annotating the maximum work-group
+    /// size required by a kernel.
+    static ::llvm::StringRef getMaxWorkGroupSizeAttrName() {
+      return "gen.max_work_group_size";
+    }
+
+    /// Returns the name of the attribute annotating the exact work-group size
+    /// required by a kernel.
+    static ::llvm::StringRef getReqdWorkGroupSizeAttrName() {
+      return "gen.reqd_work_group_size";
+    }
+
+    /// Returns the name of the attribute annotating the exact sub-group size
+    /// required by a kernel.
+    static ::llvm::StringRef getReqdSubGroupSizeAttrName() {
+      return "gen.intel_reqd_sub_group_size";
+    }
+  }];
+}
+
+#endif // GEN_DIALECT
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENOps.h b/mlir/include/mlir/Dialect/GEN/IR/GENOps.h
new file mode 100644
index 00000000000000..be81cd111de7e3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENOps.h
@@ -0,0 +1,26 @@
+//===--- GENOps.h - GEN Dialect Operations ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GEN_IR_GENOPS_H
+#define MLIR_DIALECT_GEN_IR_GENOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/GEN/IR/GENOpsEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOpsAttrDefs.h.inc"
+
+#include "mlir/Dialect/GEN/IR/GENTraits.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOps.h.inc"
+
+#endif // MLIR_DIALECT_GEN_IR_GENOPS_H
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENOps.td b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
new file mode 100644
index 00000000000000..f3328f17899685
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
@@ -0,0 +1,130 @@
+//===-- GENOps.td - GEN IR dialect op definition file ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the GEN IR operation definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GEN_OPS
+#define GEN_OPS
+
+include "mlir/Dialect/GEN/IR/GENDialect.td"
+include "mlir/Dialect/GEN/IR/GENAttrDefs.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+
+//===----------------------------------------------------------------------===//
+// GEN op definitions
+//===----------------------------------------------------------------------===//
+
+// Base class of every GEN operation; derives from LLVM_OpBase (included from
+// LLVMOpBase.td) with the GEN dialect as owner.
+class GEN_Op<string mnemonic, list<Trait> traits = []>
+    : LLVM_OpBase<GEN_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// ND-Range Operations
+//===----------------------------------------------------------------------===//
+
+// Native C++ trait `mlir::OpTrait::GEN3DNDRange` (GENTraits.h); its
+// verification logic is implemented out of line.
+def GEN3DNDRange : NativeOpTrait<"GEN3DNDRange">;
+
+// Base class of the ND-range queries: one i32 `$dim` operand selecting the
+// dimension and an `index` result. These ops are side-effect free (Pure).
+class GEN_3DNDRangeOp<string mnemonic, list<Trait> traits = []>
+    : GEN_Op<mnemonic, [GEN3DNDRange, Pure] # traits>,
+      Arguments<(ins I32:$dim)>,
+      Results<(outs Index:$res)> {
+  let assemblyFormat = "$dim attr-dict";
+}
+
+// Position of the work-item inside its work-group.
+def GEN_LocalIdOp : GEN_3DNDRangeOp<"local_id"> {
+  let description = [{
+    Query the work-item's position in its work-group, i.e., its local id, in a
+    given dimension.
+    ```mlir
+    %local_id = gen.local_id %dim
+    ```
+  }];
+  let summary = "Query a work-item's local id.";
+}
+
+// Id of the work-group the work-item belongs to.
+def GEN_WorkGroupIdOp : GEN_3DNDRangeOp<"work_group_id"> {
+  let description = [{
+    Query the work-item's work-group id in a given dimension.
+    ```mlir
+    %work_group_id = gen.work_group_id %dim
+    ```
+  }];
+  let summary = "Query a work-item's work-group id.";
+}
+
+// Size of the work-group.
+def GEN_WorkGroupSizeOp : GEN_3DNDRangeOp<"work_group_size"> {
+  let description = [{
+    Query the work-item's work-group size in a given dimension.
+    ```mlir
+    %work_group_size = gen.work_group_size %dim
+    ```
+  }];
+  let summary = "Query the work-group size.";
+}
+
+// Number of work-groups in the ND-range.
+def GEN_NumWorkGroupsOp : GEN_3DNDRangeOp<"num_work_groups"> {
+  let description = [{
+    Query the number of work-groups in the ND-range in a given dimension.
+    ```mlir
+    %wg_number = gen.num_work_groups %dim
+    ```
+  }];
+  let summary = "Query the number of work-groups in the ND-range.";
+}
+
+//===----------------------------------------------------------------------===//
+// Synchronization
+//===----------------------------------------------------------------------===//
+
+def GEN_BarrierOp : GEN_Op<"barrier"> {
+  let summary = "Workgroup barrier";
+
+  // `description` is the ODS field used for the generated documentation; a
+  // custom `baseDescription` string is not picked up on its own.
+  // NOTE(review): the LLVM lowering passes only the local-memory fence flag to
+  // the OpenCL barrier builtin; confirm it matches the semantics below.
+  let description = [{
+    The `gen.barrier` operation performs a workgroup barrier and ensures all
+    outstanding memory transactions using local or global memory are complete.
+  }];
+
+  let arguments = (ins);
+  let results = (outs);
+  let assemblyFormat = "attr-dict";
+}
+
+def IntegerOrFloatType : AnyTypeOf<[AnySignlessInteger, AnyFloat]>;
+
+def GEN_SubGroupShuffleOp : GEN_Op<"sub_group_shuffle", [
+      TypesMatchWith<"result and value have the same type",
+                     "res", "value", "$_self">]>,
+  Results<(outs IntegerOrFloatType:$res)>,
+  Arguments<(ins IntegerOrFloatType:$value,
+                 I32:$mask,
+                 GEN_ShflKindAttr:$kind)> {
+  let summary = "Subgroup shuffle";
+  // `description` is the ODS field used for the generated documentation; a
+  // custom `baseDescription` string is not picked up on its own.
+  let description = [{
+    The `gen.sub_group_shuffle` operation is invoked by different work items
+    with different values, given by $value. Different work items have different
+    subgroup local IDs. The shuffle kind, $kind, is given to determine how to
+    calculate the associated subgroup local ID. It returns the associated
+    $value for the work item with subgroup local ID equal to:
+    - $kind == xor, the current invocation's subgroup local ID xor'ed with $mask.
+    - $kind == up, the current invocation's subgroup local ID - $mask.
+    - $kind == down, the current invocation's subgroup local ID + $mask.
+    - $kind == idx, the subgroup local ID $mask.
+
+    Example:
+    ```mlir
+    %res = gen.sub_group_shuffle xor %value, %mask : i32
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $kind $value `,` $mask attr-dict `:` type($res)
+  }];
+}
+
+#endif // GEN_OPS
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h b/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
new file mode 100644
index 00000000000000..3b7754887749bf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
@@ -0,0 +1,30 @@
+//===--- GENTraits.h - GEN Dialect Traits -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GEN_IR_GENTRAITS_H
+#define MLIR_DIALECT_GEN_IR_GENTRAITS_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir::OpTrait {
+
+namespace detail {
+/// Out-of-line verifier shared by every operation carrying `GEN3DNDRange`.
+LogicalResult verifyGEN3DNDRange(Operation *op);
+} // namespace detail
+
+/// Trait attached to the GEN ND-range query operations; verification is
+/// delegated to `detail::verifyGEN3DNDRange`.
+template <typename ConcreteType>
+struct GEN3DNDRange : TraitBase<ConcreteType, GEN3DNDRange> {
+  static LogicalResult verifyTrait(Operation *op) {
+    return detail::verifyGEN3DNDRange(op);
+  }
+};
+
+} // namespace mlir::OpTrait
+
+#endif // MLIR_DIALECT_GEN_IR_GENTRAITS_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c558dc53cc7fac..a940d134024165 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -36,6 +36,7 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
@@ -116,6 +117,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   DLTIDialect,
                   emitc::EmitCDialect,
                   func::FuncDialect,
+                  GEN::GENDialect,
                   gpu::GPUDialect,
                   index::IndexDialect,
                   irdl::IRDLDialect,
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 41ab7046b91ce3..0d59ce960090c4 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -22,6 +22,8 @@ add_subdirectory(ConvertToLLVM)
 add_subdirectory(FuncToEmitC)
 add_subdirectory(FuncToLLVM)
 add_subdirectory(FuncToSPIRV)
+add_subdirectory(GENToLLVM)
+add_subdirectory(GENToSPIRV)
 add_subdirectory(GPUCommon)
 add_subdirectory(GPUToNVVM)
 add_subdirectory(GPUToROCDL)
diff --git a/mlir/lib/Conversion/GENToLLVM/CMakeLists.txt b/mlir/lib/Conversion/GENToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000000..4ec773ccf46dd0
--- /dev/null
+++ b/mlir/lib/Conversion/GENToLLVM/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_conversion_library(MLIRGENToLLVM
+  GENToLLVMPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/GENToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  # GENToLLVMPass.cpp directly uses the pass infrastructure and the dialect
+  # conversion driver, so link them explicitly rather than transitively.
+  LINK_LIBS PUBLIC
+  MLIRGENDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRPass
+  MLIRTransformUtils
+)
diff --git a/mlir/lib/Conversion/GENToLLVM/GENToLLVMPass.cpp b/mlir/lib/Conversion/GENToLLVM/GENToLLVMPass.cpp
new file mode 100644
index 00000000000000..244817d8cf00c6
--- /dev/null
+++ b/mlir/lib/Conversion/GENToLLVM/GENToLLVMPass.cpp
@@ -0,0 +1,280 @@
+//===- GENToLLVMPass.cpp - MLIR GEN to LLVM dialect conversion ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GENToLLVM/GENToLLVMPass.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
+#include "mlir/Dialect/GEN/IR/GENOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTGENTOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Helper Functions
+//===----------------------------------------------------------------------===//
+
+/// Emits a call to the device function \p funcName, first declaring it at the
+/// start of the enclosing module if no symbol with that name exists.
+///
+/// \p retType and \p argTypes give the signature of a newly created
+/// declaration, which uses the SPIR_FUNC calling convention; \p args are the
+/// call operands. When \p convergent is set, both the declaration and the call
+/// get a `convergent` passthrough attribute.
+static LLVM::CallOp createDeviceFunctionCall(
+    ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
+    ArrayRef<Type> argTypes, ArrayRef<Value> args, bool convergent = false) {
+  auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
+  assert(moduleOp && "expected the insertion point to be nested in a module");
+  MLIRContext *context = rewriter.getContext();
+  Location loc = UnknownLoc::get(context);
+  auto convergentAttr =
+      rewriter.getArrayAttr(StringAttr::get(context, "convergent"));
+
+  // Reuse an existing declaration, or create one at the top of the module.
+  auto getOrCreateFunction = [&](StringRef name) {
+    if (Operation *existing = moduleOp.lookupSymbol(name))
+      return cast<LLVM::LLVMFuncOp>(existing);
+
+    auto funcType = LLVM::LLVMFunctionType::get(retType, argTypes);
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+    auto func = rewriter.create<LLVM::LLVMFuncOp>(loc, name, funcType);
+    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+    if (convergent)
+      func.setPassthroughAttr(convergentAttr);
+    return func;
+  };
+
+  LLVM::LLVMFuncOp funcOp = getOrCreateFunction(funcName);
+  auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
+  // The call site must use the callee's calling convention: in LLVM IR a call
+  // whose calling convention differs from the callee's is undefined behavior,
+  // so a default (C) call to a spir_func declaration is not a valid lowering.
+  callOp.setCConv(funcOp.getCConv());
+  if (convergent)
+    callOp->setAttr("passthrough", convergentAttr);
+
+  return callOp;
+}
+
+/// Emits a convergent call to the OpenCL `sub_group_shuffle*` builtin selected
+/// by \p kind, shuffling \p value according to the i32 \p mask.
+///
+/// The callee name is Itanium-mangled: the base name chosen by \p kind, the
+/// mangled type of \p value, then `j` (unsigned int) for the mask. Value types
+/// with no mangling handled here abort with a fatal error. Such types are
+/// accepted by the op verifier, so this path is reachable; failing loudly is
+/// better than silently emitting a call with a malformed name.
+static LLVM::CallOp createSubGroupShuffle(ConversionPatternRewriter &rewriter,
+                                          Value value, Value mask,
+                                          GEN::ShflKind kind) {
+  assert(mask.getType().isInteger(32) && "Expecting mask type to be i32");
+
+  std::string fnName;
+  switch (kind) {
+  case GEN::ShflKind::XOR:
+    fnName = "_Z21sub_group_shuffle_xor";
+    break;
+  case GEN::ShflKind::UP:
+    fnName = "_Z20sub_group_shuffle_up";
+    break;
+  case GEN::ShflKind::DOWN:
+    fnName = "_Z22sub_group_shuffle_down";
+    break;
+  case GEN::ShflKind::IDX:
+    fnName = "_Z17sub_group_shuffle";
+    break;
+  }
+
+  // Append the Itanium mangling of the shuffled value's type.
+  TypeSwitch<Type>(value.getType())
+      .Case<Float16Type>([&](auto) { fnName += "Dh"; })
+      .Case<Float32Type>([&](auto) { fnName += "f"; })
+      .Case<Float64Type>([&](auto) { fnName += "d"; })
+      .Case<IntegerType>([&](auto ty) {
+        switch (ty.getWidth()) {
+        case 8:
+          fnName += "c";
+          break;
+        case 16:
+          fnName += "s";
+          break;
+        case 32:
+          fnName += "i";
+          break;
+        case 64:
+          fnName += "l";
+          break;
+        default:
+          llvm::report_fatal_error("unhandled integer type");
+        }
+      })
+      .Default([](Type) {
+        // e.g. bf16: no mangling is implemented for it above.
+        llvm::report_fatal_error("unhandled floating-point type");
+      });
+
+  // Mangling of the `unsigned int` mask parameter.
+  fnName += "j";
+
+  return createDeviceFunctionCall(rewriter, fnName, value.getType(),
+                                  {value.getType(), mask.getType()},
+                                  {value, mask}, true /*convergent*/);
+}
+
+/// Materializes the i32 value \p v as an LLVM dialect constant at \p loc.
+static Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) {
+  IntegerType i32Type = rewriter.getI32Type();
+  IntegerAttr valueAttr = IntegerAttr::get(i32Type, v);
+  return rewriter.create<LLVM::ConstantOp>(loc, i32Type, valueAttr);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ND-range Ops Lowerings
+//===----------------------------------------------------------------------===//
+
+/// Shared lowering of the GEN ND-range query operations: the op is replaced
+/// by a call to the OpenCL builtin \p builtinName, taking the i32 dimension
+/// operand and returning the converted result type.
+class GEN3DNDRangeLoweringBase : public ConvertToLLVMPattern {
+public:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(op->getNumOperands() == 1 && "Expecting a single operand");
+    Type resType = typeConverter->convertType(op->getResult(0).getType());
+    // Do not build a call with a null result type if conversion fails.
+    if (!resType)
+      return rewriter.notifyMatchFailure(op, "failed to convert result type");
+    LLVM::CallOp callOp = createDeviceFunctionCall(
+        rewriter, builtinName, resType, rewriter.getI32Type(), operands[0]);
+    rewriter.replaceOp(op, callOp);
+    return success();
+  }
+
+protected:
+  GEN3DNDRangeLoweringBase(StringRef builtinName, StringRef rootOpName,
+                           const LLVMTypeConverter &typeConverter,
+                           PatternBenefit benefit)
+      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
+                             typeConverter, benefit),
+        builtinName(builtinName) {}
+
+private:
+  /// Mangled name of the builtin to call; expected to reference storage that
+  /// outlives the pattern (string literals in this file).
+  StringRef builtinName;
+};
+
+/// Mangled name of the OpenCL builtin implementing the ND-range query
+/// \p SourceOp. Only the explicit specializations below are defined.
+template <typename SourceOp>
+constexpr StringRef getBuiltinName();
+
+/// get_local_id(unsigned int)
+template <>
+constexpr StringRef getBuiltinName<GEN::LocalIdOp>() {
+  return "_Z12get_local_idj";
+}
+
+/// get_group_id(unsigned int)
+template <>
+constexpr StringRef getBuiltinName<GEN::WorkGroupIdOp>() {
+  return "_Z12get_group_idj";
+}
+
+/// get_local_size(unsigned int)
+template <>
+constexpr StringRef getBuiltinName<GEN::WorkGroupSizeOp>() {
+  return "_Z14get_local_sizej";
+}
+
+/// get_num_groups(unsigned int)
+template <>
+constexpr StringRef getBuiltinName<GEN::NumWorkGroupsOp>() {
+  return "_Z14get_num_groupsj";
+}
+
+/// Lowering pattern for the ND-range query \p SourceOp; the builtin to call is
+/// chosen at compile time through `getBuiltinName<SourceOp>()`.
+template <typename SourceOp>
+struct GEN3DNDRangeLowering : public GEN3DNDRangeLoweringBase {
+  GEN3DNDRangeLowering(const LLVMTypeConverter &converter,
+                       PatternBenefit benefit = 1)
+      : GEN3DNDRangeLoweringBase(
+            /*builtinName=*/getBuiltinName<SourceOp>(),
+            /*rootOpName=*/SourceOp::getOperationName(),
+            /*typeConverter=*/converter, /*benefit=*/benefit) {}
+};
+
+//===----------------------------------------------------------------------===//
+// Synchronization Ops Lowerings
+//===----------------------------------------------------------------------===//
+
+/// Lowers `gen.barrier` to a convergent call to the OpenCL `barrier` builtin.
+struct GENBarrierLowering : public ConvertOpToLLVMPattern<GEN::BarrierOp> {
+  using ConvertOpToLLVMPattern<GEN::BarrierOp>::ConvertOpToLLVMPattern;
+
+  /// Fence flags passed to `barrier` (values match OpenCL's
+  /// CLK_LOCAL_MEM_FENCE / CLK_GLOBAL_MEM_FENCE).
+  enum MemFence {
+    Local = 0x01,
+    Global = 0x02,
+  };
+
+  LogicalResult
+  matchAndRewrite(GEN::BarrierOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
+    Type i32Type = rewriter.getIntegerType(32);
+    // NOTE(review): only the local-memory fence is requested, while the op
+    // documentation mentions local or global memory; confirm the intent.
+    Value fenceFlags =
+        createConstantI32(op->getLoc(), rewriter, MemFence::Local);
+    LLVM::CallOp barrierCall =
+        createDeviceFunctionCall(rewriter, "_Z7barrierj", voidType, i32Type,
+                                 fenceFlags, /*convergent=*/true);
+    rewriter.replaceOp(op, barrierCall);
+    return success();
+  }
+};
+
+/// Lowers `gen.sub_group_shuffle` to a call to the matching OpenCL
+/// `sub_group_shuffle*` builtin.
+struct SubGroupShuffleLowering
+    : public ConvertOpToLLVMPattern<GEN::SubGroupShuffleOp> {
+  using ConvertOpToLLVMPattern<GEN::SubGroupShuffleOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(GEN::SubGroupShuffleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Use the adaptor's remapped operands: during dialect conversion the
+    // op's original operands may still refer to values being replaced.
+    Value val = adaptor.getValue();
+    Value mask = adaptor.getMask();
+    GEN::ShflKind kind = op.getKind();
+    LLVM::CallOp callOp = createSubGroupShuffle(rewriter, val, mask, kind);
+    rewriter.replaceOp(op, callOp);
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pass lowering GEN dialect operations to LLVM dialect calls to OpenCL
+/// builtins.
+struct ConvertGENToLLVM final
+    : public impl::ConvertGENToLLVMBase<ConvertGENToLLVM> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    LowerToLLVMOptions options(context);
+    LLVMTypeConverter converter(context, options);
+    LLVMConversionTarget target(*context);
+    // Every GEN op must be lowered: with the dialect illegal, a GEN op left
+    // unconverted makes the pass fail instead of being silently kept.
+    target.addIllegalDialect<GEN::GENDialect>();
+
+    GEN::populateGENToLLVMConversionPatterns(converter, patterns);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population and Registration
+//===----------------------------------------------------------------------===//
+
+/// Adds every GEN-to-LLVM lowering pattern to \p patterns.
+void mlir::GEN::populateGENToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  // ND-range queries.
+  patterns.add<GEN3DNDRangeLowering<GEN::LocalIdOp>,
+               GEN3DNDRangeLowering<GEN::WorkGroupIdOp>,
+               GEN3DNDRangeLowering<GEN::WorkGroupSizeOp>,
+               GEN3DNDRangeLowering<GEN::NumWorkGroupsOp>>(converter);
+  // Synchronization and sub-group operations.
+  patterns.add<GENBarrierLowering, SubGroupShuffleLowering>(converter);
+}
+
+/// Creates an instance of the `convert-gen-to-llvm` pass.
+std::unique_ptr<Pass> mlir::GEN::createConvertGENToLLVM() {
+  std::unique_ptr<Pass> pass = std::make_unique<ConvertGENToLLVM>();
+  return pass;
+}
diff --git a/mlir/lib/Conversion/GENToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GENToSPIRV/CMakeLists.txt
new file mode 100644
index 00000000000000..3afb89ad958a64
--- /dev/null
+++ b/mlir/lib/Conversion/GENToSPIRV/CMakeLists.txt
@@ -0,0 +1,17 @@
+# GEN -> SPIR-V conversion library. Direct dependencies are listed explicitly
+# so shared-library builds do not rely on transitive linkage.
+add_mlir_conversion_library(MLIRGENToSPIRV
+  GENToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/GENToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRGENDialect
+  # Pass base class generated from Passes.td.
+  MLIRPass
+  MLIRSPIRVConversion
+  MLIRSPIRVDialect
+  # applyPartialConversion / ConversionTarget.
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/GENToSPIRV/GENToSPIRV.cpp b/mlir/lib/Conversion/GENToSPIRV/GENToSPIRV.cpp
new file mode 100644
index 00000000000000..37077c6f1cab8f
--- /dev/null
+++ b/mlir/lib/Conversion/GENToSPIRV/GENToSPIRV.cpp
@@ -0,0 +1,131 @@
+//===- GENToSPIRV.cpp - GEN to SPIRV dialect conversion -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GENToSPIRV/GENToSPIRV.h"
+
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
+#include "mlir/Dialect/GEN/IR/GENOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTGENTOSPIRV
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "gen-to-spirv-pattern"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ND-range Ops Lowerings
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pattern to convert GEN3DNDRange operations to SPIR-V.
+///
+/// Convert:
+/// ```mlir
+/// %0 = gen.operation_name %dim
+/// ```
+/// To:
+/// ```mlir
+/// %addr = spirv.mlir.addressof @__spirv_BuiltIn<BuiltinName>
+///     : !spirv.ptr<vector<3x<IndexType>>, Input>
+/// %vec = spirv.Load "Input" %addr : vector<3x<IndexType>>
+/// %0 = spirv.VectorExtractDynamic %vec[%dim] : vector<3x<IndexType>>, i32
+/// ```
+/// where `<BuiltinName>` is the name of a SPIR-V builtin (e.g.
+/// `LocalInvocationId`) and `<IndexType>` is the index type of the
+/// SPIRVTypeConverter: `i32` by default, `i64` when the converter is
+/// configured for 64-bit indices.
+///
+/// Concrete patterns (GEN3DNDRangeLowering) only supply the source operation
+/// name and the builtin to read.
+class GEN3DNDRangeLoweringBase : public ConversionPattern {
+public:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(operands.size() == 1 && "Expecting a single operand");
+    // The builtin variable is a 3-element vector of the converter's index
+    // type: <3xi32> for 32-bit indices and <3xi64> for 64-bit indices.
+    Type builtinType = getTypeConverter<SPIRVTypeConverter>()->getIndexType();
+    constexpr StringLiteral spvBuiltinPrefix = "__spirv_BuiltIn";
+    constexpr StringLiteral spvBuiltinSuffix = "";
+    Value vector = spirv::getBuiltinVariableValue(
+        op, builtin, builtinType, rewriter, spvBuiltinPrefix, spvBuiltinSuffix);
+    // Select the component named by the (type-converted) dimension operand.
+    rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(op, vector,
+                                                               operands[0]);
+    return success();
+  }
+
+protected:
+  /// \p builtin is the SPIR-V builtin variable read by this pattern; the
+  /// remaining arguments are forwarded to ConversionPattern.
+  GEN3DNDRangeLoweringBase(spirv::BuiltIn builtin,
+                           const TypeConverter &typeConverter, StringRef opName,
+                           PatternBenefit benefit, MLIRContext *context)
+      : ConversionPattern(typeConverter, opName, benefit, context),
+        builtin(builtin) {}
+
+private:
+  /// SPIR-V builtin variable loaded by the rewrite.
+  spirv::BuiltIn builtin;
+};
+} // namespace
+
+namespace {
+/// Binds GEN3DNDRangeLoweringBase to a concrete GEN operation (`SourceOp`)
+/// and the SPIR-V builtin variable (`Builtin`) whose component it extracts.
+template <typename SourceOp, spirv::BuiltIn Builtin>
+struct GEN3DNDRangeLowering : public GEN3DNDRangeLoweringBase {
+  GEN3DNDRangeLowering(const TypeConverter &typeConverter, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : GEN3DNDRangeLoweringBase(Builtin, typeConverter,
+                                 SourceOp::getOperationName(), benefit,
+                                 context) {}
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+void mlir::GEN::populateGENToSPIRVPatterns(SPIRVTypeConverter &converter,
+                                           RewritePatternSet &patterns) {
+  // Each ND-range query reads one component of a 3-element SPIR-V builtin.
+  using LocalIdLowering =
+      GEN3DNDRangeLowering<GEN::LocalIdOp, spirv::BuiltIn::LocalInvocationId>;
+  using WorkGroupIdLowering =
+      GEN3DNDRangeLowering<GEN::WorkGroupIdOp, spirv::BuiltIn::WorkgroupId>;
+  using WorkGroupSizeLowering =
+      GEN3DNDRangeLowering<GEN::WorkGroupSizeOp, spirv::BuiltIn::WorkgroupSize>;
+  using NumWorkGroupsLowering =
+      GEN3DNDRangeLowering<GEN::NumWorkGroupsOp, spirv::BuiltIn::NumWorkgroups>;
+
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<LocalIdLowering, WorkGroupIdLowering, WorkGroupSizeLowering,
+               NumWorkGroupsLowering>(converter, ctx);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pass converting GEN operations to SPIR-V using the patterns registered by
+/// GEN::populateGENToSPIRVPatterns.
+///
+/// All GEN operations are marked illegal. Only the ND-range query operations
+/// currently have SPIR-V lowerings, so any other GEN operation (e.g. a
+/// barrier or sub-group shuffle) makes the pass fail.
+struct ConvertGENToSPIRVPass final
+    : public impl::ConvertGENToSPIRVBase<ConvertGENToSPIRVPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    // Use the SPIR-V target environment found for `op`, or the default one,
+    // to build both the conversion target and the type converter.
+    spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+    std::unique_ptr<SPIRVConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
+    SPIRVTypeConverter typeConverter(targetAttr);
+
+    // Fail hard when there are any remaining GEN ops.
+    target->addIllegalDialect<GEN::GENDialect>();
+
+    RewritePatternSet patterns(&getContext());
+    GEN::populateGENToSPIRVPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, *target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Creates a new instance of the GEN-to-SPIR-V conversion pass.
+std::unique_ptr<OperationPass<>> mlir::GEN::createConvertGENToSPIRVPass() {
+  std::unique_ptr<OperationPass<>> pass =
+      std::make_unique<ConvertGENToSPIRVPass>();
+  return pass;
+}
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index b1ba5a3bc8817d..72e11b4eb33dfa 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(ControlFlow)
 add_subdirectory(DLTI)
 add_subdirectory(EmitC)
 add_subdirectory(Func)
+add_subdirectory(GEN)
 add_subdirectory(GPU)
 add_subdirectory(Index)
 add_subdirectory(IRDL)
diff --git a/mlir/lib/Dialect/GEN/CMakeLists.txt b/mlir/lib/Dialect/GEN/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/GEN/IR/CMakeLists.txt b/mlir/lib/Dialect/GEN/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..160c6eb2c53bd5
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/CMakeLists.txt
@@ -0,0 +1,17 @@
+# GEN dialect IR library: dialect registration, operation definitions and the
+# GEN trait verifiers.
+add_mlir_dialect_library(MLIRGENDialect
+  GENDialect.cpp
+  GENOps.cpp
+  GENTraits.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GEN
+
+  # TableGen-generated op, enum and attribute definitions consumed by the
+  # sources above.
+  DEPENDS
+  MLIRGENOpsIncGen
+  MLIRGENOpsEnumsIncGen
+  MLIRGENOpsAttrDefsIncGen
+
+  # MLIRLLVMDialect: GENDialect.cpp includes the LLVM dialect header.
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+)
diff --git a/mlir/lib/Dialect/GEN/IR/GENDialect.cpp b/mlir/lib/Dialect/GEN/IR/GENDialect.cpp
new file mode 100644
index 00000000000000..351e5a32a67505
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/GENDialect.cpp
@@ -0,0 +1,42 @@
+//===- GENDialect.cpp - MLIR GEN Dialect implementation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
+#include "mlir/Dialect/GEN/IR/GENOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+using namespace mlir::GEN;
+
+#include "mlir/Dialect/GEN/IR/GENOpsDialect.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// GEN dialect.
+//===----------------------------------------------------------------------===//
+
+/// Registers the GEN operations and attributes with the dialect. Both type
+/// lists are expanded from the TableGen-generated .cpp.inc files selected by
+/// the GET_OP_LIST / GET_ATTRDEF_LIST macros.
+void GENDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/GEN/IR/GENOps.cpp.inc"
+      >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/GEN/IR/GENOpsAttrDefs.cpp.inc"
+      >();
+}
diff --git a/mlir/lib/Dialect/GEN/IR/GENOps.cpp b/mlir/lib/Dialect/GEN/IR/GENOps.cpp
new file mode 100644
index 00000000000000..f7827e0520fb27
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/GENOps.cpp
@@ -0,0 +1,17 @@
+//===- GENOps.cpp - GEN dialect operations --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GEN/IR/GENOps.h"
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
+#include "mlir/IR/Builders.h"
+
+#include "mlir/Dialect/GEN/IR/GENOpsEnums.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOpsAttrDefs.cpp.inc"
+#define GET_OP_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOps.cpp.inc"
diff --git a/mlir/lib/Dialect/GEN/IR/GENTraits.cpp b/mlir/lib/Dialect/GEN/IR/GENTraits.cpp
new file mode 100644
index 00000000000000..d1a01675c3d08a
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/GENTraits.cpp
@@ -0,0 +1,24 @@
+//===- GENTraits.cpp - GEN dialect traits ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GEN/IR/GENTraits.h"
+
+#include "mlir/IR/Matchers.h"
+
+using namespace mlir;
+
+/// Verifies that a constant dimension operand of a 3D ND-range operation
+/// selects one of the dimensions 0, 1 or 2. A dimension that is not a
+/// matchable integer constant cannot be checked statically and is accepted.
+LogicalResult mlir::OpTrait::detail::verifyGEN3DNDRange(Operation *op) {
+  llvm::APInt dim;
+  if (!matchPattern(op->getOperand(0), m_ConstantInt(&dim)))
+    return success();
+
+  const bool inRange = !dim.isNegative() && dim.slt(3);
+  if (inRange)
+    return success();
+
+  return op->emitOpError()
+         << "input dimension must be in the range [0, 3). Got "
+         << dim.getSExtValue();
+}
diff --git a/mlir/test/Conversion/GENToLLVM/gen-to-llvm.mlir b/mlir/test/Conversion/GENToLLVM/gen-to-llvm.mlir
new file mode 100644
index 00000000000000..16679487032212
--- /dev/null
+++ b/mlir/test/Conversion/GENToLLVM/gen-to-llvm.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt -convert-gen-to-llvm -split-input-file %s | FileCheck %s
+
+// ND-range queries lower to calls of mangled builtins (e.g. `_Z12get_local_idj`)
+// that take the i32 dimension operand unchanged and return an i64.
+llvm.func @gen_nd_range(%dim: i32) {
+  // CHECK-LABEL: gen_nd_range
+  // CHECK-SAME:              (%[[DIM:.*]]: i32)
+  // CHECK:         llvm.call @_Z12get_local_idj(%[[DIM]]) : (i32) -> i64
+  %0 = gen.local_id %dim
+  // CHECK:         llvm.call @_Z12get_group_idj(%[[DIM]]) : (i32) -> i64
+  %1 = gen.work_group_id %dim
+  // CHECK:         llvm.call @_Z14get_local_sizej(%[[DIM]]) : (i32) -> i64
+  %2 = gen.work_group_size %dim
+  // CHECK:         llvm.call @_Z14get_num_groupsj(%[[DIM]]) : (i32) -> i64
+  %3 = gen.num_work_groups %dim
+  llvm.return
+}
+
+// -----
+
+// The barrier lowers to a convergent call of `_Z7barrierj` with the constant
+// argument 1; the convergent callee declaration must also be emitted.
+// CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {passthrough = ["convergent"]}
+
+llvm.func @gen.barrier() {
+  // CHECK-LABEL: gen.barrier
+  // CHECK: [[CST:%.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.call @_Z7barrierj([[CST]]) {passthrough = ["convergent"]} : (i32) -> ()
+  gen.barrier
+  llvm.return
+}
+
+// -----
+
+// Each (shuffle kind, element type) pair maps onto its own mangled convergent
+// builtin. The callee declarations may appear in any order.
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xordj(f64, i32) -> f64 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorfj(f32, i32) -> f32 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorDhj(f16, i32) -> f16 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorlj(i64, i32) -> i64 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorsj(i16, i32) -> i16 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorcj(i8, i32) -> i8 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z17sub_group_shuffleij(i32, i32) -> i32 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z22sub_group_shuffle_downij(i32, i32) -> i32 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z20sub_group_shuffle_upij(i32, i32) -> i32 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorij(i32, i32) -> i32 attributes {passthrough = ["convergent"]}
+
+llvm.func @gen.sub_group_shuffle() {
+  // CHECK-LABEL: gen.sub_group_shuffle
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.call @_Z21sub_group_shuffle_xorij([[ZERO]], [[ZERO]]) {passthrough = ["convergent"]} : (i32, i32) -> i32
+  // CHECK: llvm.call @_Z20sub_group_shuffle_upij([[ZERO]], [[ZERO]]) {passthrough = ["convergent"]} : (i32, i32) -> i32
+  // CHECK: llvm.call @_Z22sub_group_shuffle_downij([[ZERO]], [[ZERO]]) {passthrough = ["convergent"]} : (i32, i32) -> i32
+  // CHECK: llvm.call @_Z17sub_group_shuffleij([[ZERO]], [[ZERO]]) {passthrough = ["convergent"]} : (i32, i32) -> i32
+  %1 = gen.sub_group_shuffle xor %0, %0 : i32
+  %2 = gen.sub_group_shuffle up %0, %0 : i32
+  %3 = gen.sub_group_shuffle down %0, %0 : i32
+  %4 = gen.sub_group_shuffle idx %0, %0 : i32
+
+  // CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i8) : i8
+  // CHECK: llvm.call @_Z21sub_group_shuffle_xorcj([[ZERO1]], [[ZERO]]) {passthrough = ["convergent"]} : (i8, i32) -> i8
+  %5 = llvm.mlir.constant(0 : i8) : i8
+  %6 = gen.sub_group_shuffle xor %5, %0 : i8
+
+  // CHECK: [[ZERO2:%.*]] = llvm.mlir.constant(0 : i16) : i16
+  // CHECK: llvm.call @_Z21sub_group_shuffle_xorsj([[ZERO2]], [[ZERO]]) {passthrough = ["convergent"]} : (i16, i32) -> i16
+  %7 = llvm.mlir.constant(0 : i16) : i16
+  %8 = gen.sub_group_shuffle xor %7, %0 : i16
+
+  // CHECK: [[ZERO3:%.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: llvm.call @_Z21sub_group_shuffle_xorlj([[ZERO3]], [[ZERO]]) {passthrough = ["convergent"]} : (i64, i32) -> i64
+  %9 = llvm.mlir.constant(0 : i64) : i64
+  %10 = gen.sub_group_shuffle xor %9, %0 : i64
+
+  // CHECK: [[ZERO4:%.*]] = llvm.mlir.constant(0.000000e+00 : f16) : f16
+  // CHECK: llvm.call @_Z21sub_group_shuffle_xorDhj([[ZERO4]], [[ZERO]]) {passthrough = ["convergent"]} : (f16, i32) -> f16
+  %11 = llvm.mlir.constant(0.0 : f16) : f16
+  %12 = gen.sub_group_shuffle xor %11, %0 : f16
+
+  // CHECK: [[ZERO5:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+  // CHECK: llvm.call @_Z21sub_group_shuffle_xorfj([[ZERO5]], [[ZERO]]) {passthrough = ["convergent"]} : (f32, i32) -> f32
+  %13 = llvm.mlir.constant(0.0 : f32) : f32
+  %14 = gen.sub_group_shuffle xor %13, %0 : f32
+
+  // CHECK: [[ZERO6:%.*]] = llvm.mlir.constant(0.000000e+00 : f64) : f64
+  // CHECK: llvm.call @_Z21sub_group_shuffle_xordj([[ZERO6]], [[ZERO]]) {passthrough = ["convergent"]} : (f64, i32) -> f64
+  %15 = llvm.mlir.constant(0.0 : f64) : f64
+  %16 = gen.sub_group_shuffle xor %15, %0 : f64
+  llvm.return
+}
diff --git a/mlir/test/Conversion/GENToSPIRV/gen-to-spirv.mlir b/mlir/test/Conversion/GENToSPIRV/gen-to-spirv.mlir
new file mode 100644
index 00000000000000..68f6b8efcdadf6
--- /dev/null
+++ b/mlir/test/Conversion/GENToSPIRV/gen-to-spirv.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -convert-gen-to-spirv -split-input-file %s | FileCheck %s
+
+// Each ND-range query loads its SPIR-V builtin vector and extracts the element
+// selected by %dim. No target environment is attached, so the default one is
+// used and the builtins are vector<3xi32>.
+// CHECK-DAG:     spirv.GlobalVariable @__spirv_BuiltInNumWorkgroups built_in("NumWorkgroups") : !spirv.ptr<vector<3xi32>, Input>
+// CHECK-DAG:     spirv.GlobalVariable @__spirv_BuiltInWorkgroupSize built_in("WorkgroupSize") : !spirv.ptr<vector<3xi32>, Input>
+// CHECK-DAG:     spirv.GlobalVariable @__spirv_BuiltInWorkgroupId built_in("WorkgroupId") : !spirv.ptr<vector<3xi32>, Input>
+// CHECK-DAG:     spirv.GlobalVariable @__spirv_BuiltInLocalInvocationId built_in("LocalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
+
+// CHECK-LABEL:   func.func @gen_nd_range(
+// CHECK-SAME:                            %[[VAL_0:.*]]: i32) {
+func.func @gen_nd_range(%dim: i32) {
+// CHECK:           %[[VAL_1:.*]] = spirv.mlir.addressof @__spirv_BuiltInLocalInvocationId : !spirv.ptr<vector<3xi32>, Input>
+// CHECK:           %[[VAL_2:.*]] = spirv.Load "Input" %[[VAL_1]] : vector<3xi32>
+// CHECK:           %[[VAL_3:.*]] = spirv.VectorExtractDynamic %[[VAL_2]]{{\[}}%[[VAL_0]]] : vector<3xi32>, i32
+  %0 = gen.local_id %dim
+// CHECK:           %[[VAL_4:.*]] = spirv.mlir.addressof @__spirv_BuiltInWorkgroupId : !spirv.ptr<vector<3xi32>, Input>
+// CHECK:           %[[VAL_5:.*]] = spirv.Load "Input" %[[VAL_4]] : vector<3xi32>
+// CHECK:           %[[VAL_6:.*]] = spirv.VectorExtractDynamic %[[VAL_5]]{{\[}}%[[VAL_0]]] : vector<3xi32>, i32
+  %1 = gen.work_group_id %dim
+// CHECK:           %[[VAL_7:.*]] = spirv.mlir.addressof @__spirv_BuiltInWorkgroupSize : !spirv.ptr<vector<3xi32>, Input>
+// CHECK:           %[[VAL_8:.*]] = spirv.Load "Input" %[[VAL_7]] : vector<3xi32>
+// CHECK:           %[[VAL_9:.*]] = spirv.VectorExtractDynamic %[[VAL_8]]{{\[}}%[[VAL_0]]] : vector<3xi32>, i32
+  %2 = gen.work_group_size %dim
+// CHECK:           %[[VAL_10:.*]] = spirv.mlir.addressof @__spirv_BuiltInNumWorkgroups : !spirv.ptr<vector<3xi32>, Input>
+// CHECK:           %[[VAL_11:.*]] = spirv.Load "Input" %[[VAL_10]] : vector<3xi32>
+// CHECK:           %[[VAL_12:.*]] = spirv.VectorExtractDynamic %[[VAL_11]]{{\[}}%[[VAL_0]]] : vector<3xi32>, i32
+  %3 = gen.num_work_groups %dim
+  func.return
+}
diff --git a/mlir/test/Dialect/GEN/gen.mlir b/mlir/test/Dialect/GEN/gen.mlir
new file mode 100644
index 00000000000000..e4d3dc76d4b560
--- /dev/null
+++ b/mlir/test/Dialect/GEN/gen.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// Round-trip: the ND-range query ops parse and print back with their
+// dimension operand.
+llvm.func @gen_nd_range(%dim: i32) {
+  // CHECK-LABEL: gen_nd_range
+  // CHECK-SAME:              (%[[DIM:.*]]: i32)
+  // CHECK: gen.local_id %[[DIM]]
+  %0 = gen.local_id %dim
+  // CHECK: gen.work_group_id %[[DIM]]
+  %1 = gen.work_group_id %dim
+  // CHECK: gen.work_group_size %[[DIM]]
+  %2 = gen.work_group_size %dim
+  // CHECK: gen.num_work_groups %[[DIM]]
+  %3 = gen.num_work_groups %dim
+  llvm.return
+}
+
+// Round-trip: the operand-less barrier op parses and prints back.
+llvm.func @gen.barrier() {
+  // CHECK-LABEL: gen.barrier
+  // CHECK: gen.barrier
+  gen.barrier
+  llvm.return
+}
+
+// Round-trip: every shuffle kind and every supported element type parses and
+// prints back. Values are matched through captured pattern variables rather
+// than hard-coded SSA names so the test does not depend on value numbering.
+llvm.func @gen.sub_group_shuffle() {
+  // CHECK-LABEL: gen.sub_group_shuffle
+  // CHECK: %[[I32:.*]] = llvm.mlir.constant(0 : i32) : i32
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: gen.sub_group_shuffle xor %[[I32]], %[[I32]] : i32
+  %1 = gen.sub_group_shuffle xor %0, %0 : i32
+  // CHECK: gen.sub_group_shuffle up %[[I32]], %[[I32]] : i32
+  %2 = gen.sub_group_shuffle up %0, %0 : i32
+  // CHECK: gen.sub_group_shuffle down %[[I32]], %[[I32]] : i32
+  %3 = gen.sub_group_shuffle down %0, %0 : i32
+  // CHECK: gen.sub_group_shuffle idx %[[I32]], %[[I32]] : i32
+  %4 = gen.sub_group_shuffle idx %0, %0 : i32
+  // CHECK: %[[I8:.*]] = llvm.mlir.constant(0 : i8) : i8
+  %5 = llvm.mlir.constant(0 : i8) : i8
+  // CHECK: gen.sub_group_shuffle xor %[[I8]], %[[I32]] : i8
+  %6 = gen.sub_group_shuffle xor %5, %0 : i8
+  // CHECK: %[[I16:.*]] = llvm.mlir.constant(0 : i16) : i16
+  %7 = llvm.mlir.constant(0 : i16) : i16
+  // CHECK: gen.sub_group_shuffle xor %[[I16]], %[[I32]] : i16
+  %8 = gen.sub_group_shuffle xor %7, %0 : i16
+  // CHECK: %[[I64:.*]] = llvm.mlir.constant(0 : i64) : i64
+  %9 = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: gen.sub_group_shuffle xor %[[I64]], %[[I32]] : i64
+  %10 = gen.sub_group_shuffle xor %9, %0 : i64
+  // CHECK: %[[F16:.*]] = llvm.mlir.constant(0.000000e+00 : f16) : f16
+  %11 = llvm.mlir.constant(0.0 : f16) : f16
+  // CHECK: gen.sub_group_shuffle xor %[[F16]], %[[I32]] : f16
+  %12 = gen.sub_group_shuffle xor %11, %0 : f16
+  // CHECK: %[[F32:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+  %13 = llvm.mlir.constant(0.0 : f32) : f32
+  // CHECK: gen.sub_group_shuffle xor %[[F32]], %[[I32]] : f32
+  %14 = gen.sub_group_shuffle xor %13, %0 : f32
+  // CHECK: %[[F64:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : f64
+  %15 = llvm.mlir.constant(0.0 : f64) : f64
+  // CHECK: gen.sub_group_shuffle xor %[[F64]], %[[I32]] : f64
+  %16 = gen.sub_group_shuffle xor %15, %0 : f64
+  llvm.return
+}
diff --git a/mlir/test/Dialect/GEN/invalid.mlir b/mlir/test/Dialect/GEN/invalid.mlir
new file mode 100644
index 00000000000000..1285f8af216647
--- /dev/null
+++ b/mlir/test/Dialect/GEN/invalid.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -split-input-file %s -verify-diagnostics
+
+// A constant dimension below the valid range [0, 3) is rejected by the
+// 3D ND-range verifier.
+func.func @test_3d_nd_range_bounds_low() {
+  %c-1 = arith.constant -1 : i32
+  // expected-error @below {{'gen.local_id' op input dimension must be in the range [0, 3). Got -1}}
+  %0 = gen.local_id %c-1
+  func.return
+}
+
+// -----
+
+// The upper bound is exclusive: dimension 3 is rejected by the 3D ND-range
+// verifier.
+func.func @test_3d_nd_range_bounds_high() {
+  %c3 = arith.constant 3 : i32
+  // expected-error @below {{'gen.work_group_id' op input dimension must be in the range [0, 3). Got 3}}
+  %0 = gen.work_group_id %c3
+  func.return
+}



More information about the Mlir-commits mailing list