[Mlir-commits] [mlir] [MLIR][GEN] Define GEN dialect (PR #87757)
Victor Perez
llvmlistbot at llvm.org
Fri Apr 5 01:50:51 PDT 2024
https://github.com/victor-eds created https://github.com/llvm/llvm-project/pull/87757
Define GEN dialect with the following operations:
- `gen.local_id` to query a work-item's local id
- `gen.work_group_id` to query the work-group's id
- `gen.work_group_size` to query the work-group's size
- `gen.num_work_groups` to query the number of work-groups
- `gen.barrier` to insert a barrier
- `gen.sub_group_shuffle` for different subgroup shuffles
Also define two conversion passes:
- `-convert-gen-to-spirv`
- `-convert-gen-to-llvm`
>From 0c5aefcf204cd6108bd1867081e05ee4db1dd5b0 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Fri, 5 Apr 2024 09:42:36 +0100
Subject: [PATCH] [MLIR][GEN] Define GEN dialect
Define GEN dialect with the following operations:
- `gen.local_id` to query a work-item's local id
- `gen.work_group_id` to query the work-group's id
- `gen.work_group_size` to query the work-group's size
- `gen.num_work_groups` to query the number of work-groups
- `gen.barrier` to insert a barrier
- `gen.sub_group_shuffle` for different subgroup shuffles
Also define two conversion passes:
- `-convert-gen-to-spirv`
- `-convert-gen-to-llvm`
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
.../mlir/Conversion/GENToLLVM/GENToLLVMPass.h | 32 ++
.../mlir/Conversion/GENToSPIRV/GENToSPIRV.h | 31 ++
mlir/include/mlir/Conversion/Passes.h | 2 +
mlir/include/mlir/Conversion/Passes.td | 23 ++
mlir/include/mlir/Dialect/CMakeLists.txt | 1 +
mlir/include/mlir/Dialect/GEN/CMakeLists.txt | 1 +
.../mlir/Dialect/GEN/IR/CMakeLists.txt | 13 +
.../mlir/Dialect/GEN/IR/GENAttrDefs.td | 26 ++
mlir/include/mlir/Dialect/GEN/IR/GENDialect.h | 38 +++
.../include/mlir/Dialect/GEN/IR/GENDialect.td | 44 +++
mlir/include/mlir/Dialect/GEN/IR/GENOps.h | 26 ++
mlir/include/mlir/Dialect/GEN/IR/GENOps.td | 130 ++++++++
mlir/include/mlir/Dialect/GEN/IR/GENTraits.h | 30 ++
mlir/include/mlir/InitAllDialects.h | 2 +
mlir/lib/Conversion/CMakeLists.txt | 2 +
mlir/lib/Conversion/GENToLLVM/CMakeLists.txt | 14 +
.../Conversion/GENToLLVM/GENToLLVMPass.cpp | 280 ++++++++++++++++++
mlir/lib/Conversion/GENToSPIRV/CMakeLists.txt | 17 ++
mlir/lib/Conversion/GENToSPIRV/GENToSPIRV.cpp | 131 ++++++++
mlir/lib/Dialect/CMakeLists.txt | 1 +
mlir/lib/Dialect/GEN/CMakeLists.txt | 1 +
mlir/lib/Dialect/GEN/IR/CMakeLists.txt | 17 ++
mlir/lib/Dialect/GEN/IR/GENDialect.cpp | 42 +++
mlir/lib/Dialect/GEN/IR/GENOps.cpp | 17 ++
mlir/lib/Dialect/GEN/IR/GENTraits.cpp | 24 ++
.../Conversion/GENToLLVM/gen-to-llvm.mlir | 85 ++++++
.../Conversion/GENToSPIRV/gen-to-spirv.mlir | 28 ++
mlir/test/Dialect/GEN/gen.mlir | 54 ++++
mlir/test/Dialect/GEN/invalid.mlir | 17 ++
29 files changed, 1129 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/GENToLLVM/GENToLLVMPass.h
create mode 100644 mlir/include/mlir/Conversion/GENToSPIRV/GENToSPIRV.h
create mode 100644 mlir/include/mlir/Dialect/GEN/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td
create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENDialect.h
create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENDialect.td
create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENOps.h
create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENOps.td
create mode 100644 mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
create mode 100644 mlir/lib/Conversion/GENToLLVM/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/GENToLLVM/GENToLLVMPass.cpp
create mode 100644 mlir/lib/Conversion/GENToSPIRV/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/GENToSPIRV/GENToSPIRV.cpp
create mode 100644 mlir/lib/Dialect/GEN/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/GEN/IR/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/GEN/IR/GENDialect.cpp
create mode 100644 mlir/lib/Dialect/GEN/IR/GENOps.cpp
create mode 100644 mlir/lib/Dialect/GEN/IR/GENTraits.cpp
create mode 100644 mlir/test/Conversion/GENToLLVM/gen-to-llvm.mlir
create mode 100644 mlir/test/Conversion/GENToSPIRV/gen-to-spirv.mlir
create mode 100644 mlir/test/Dialect/GEN/gen.mlir
create mode 100644 mlir/test/Dialect/GEN/invalid.mlir
diff --git a/mlir/include/mlir/Conversion/GENToLLVM/GENToLLVMPass.h b/mlir/include/mlir/Conversion/GENToLLVM/GENToLLVMPass.h
new file mode 100644
index 00000000000000..1788a10a7843fb
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GENToLLVM/GENToLLVMPass.h
@@ -0,0 +1,32 @@
+//===- GENToLLVMPass.h - GEN to LLVM dialect conversion ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GENTOLLVM_GENTOLLVMPASS_H
+#define MLIR_CONVERSION_GENTOLLVM_GENTOLLVMPASS_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTGENTOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace GEN {
+/// Populate \p patterns with the patterns that lower GEN dialect operations to
+/// LLVM dialect calls to device builtin functions, using \p converter to
+/// convert result types.
+void populateGENToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                         RewritePatternSet &patterns);
+
+/// Create an instance of the `-convert-gen-to-llvm` pass.
+std::unique_ptr<Pass> createConvertGENToLLVM();
+} // namespace GEN
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GENTOLLVM_GENTOLLVMPASS_H
diff --git a/mlir/include/mlir/Conversion/GENToSPIRV/GENToSPIRV.h b/mlir/include/mlir/Conversion/GENToSPIRV/GENToSPIRV.h
new file mode 100644
index 00000000000000..224d2dbc3ca3f8
--- /dev/null
+++ b/mlir/include/mlir/Conversion/GENToSPIRV/GENToSPIRV.h
@@ -0,0 +1,31 @@
+//===- GENToSPIRV.h - Convert GEN to SPIRV dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_GENTOSPIRV_GENTOSPIRV_H
+#define MLIR_CONVERSION_GENTOSPIRV_GENTOSPIRV_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+
+class SPIRVTypeConverter;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTGENTOSPIRV
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace GEN {
+/// Populate \p patterns with the patterns that lower GEN dialect operations to
+/// the SPIR-V dialect, using \p typeConverter for type conversions.
+void populateGENToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+                                RewritePatternSet &patterns);
+
+/// Create an instance of the `-convert-gen-to-spirv` pass.
+std::unique_ptr<OperationPass<>> createConvertGENToSPIRVPass();
+} // namespace GEN
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_GENTOSPIRV_GENTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 2179ae18ac074b..8bfab7bbfa9a1b 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -33,6 +33,8 @@
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
+#include "mlir/Conversion/GENToLLVM/GENToLLVMPass.h"
+#include "mlir/Conversion/GENToSPIRV/GENToSPIRV.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab95..1f7728b9105570 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -453,6 +453,29 @@ def ConvertFuncToSPIRV : Pass<"convert-func-to-spirv"> {
];
}
+//===----------------------------------------------------------------------===//
+// GENToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertGENToLLVM : Pass<"convert-gen-to-llvm", "ModuleOp"> {
+  let summary = "Convert the GEN dialect to the LLVM dialect";
+  let description = [{
+    This pass converts GEN dialect operations to LLVM dialect operations.
+  }];
+  let constructor = "mlir::GEN::createConvertGENToLLVM()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
+}
+
+//===----------------------------------------------------------------------===//
+// GENToSPIRV
+//===----------------------------------------------------------------------===//
+
+def ConvertGENToSPIRV : Pass<"convert-gen-to-spirv"> {
+  let summary = "Convert the GEN dialect to the SPIR-V dialect";
+  let description = [{
+    This pass converts GEN dialect operations to SPIR-V dialect operations.
+  }];
+  let constructor = "mlir::GEN::createConvertGENToSPIRVPass()";
+  let dependentDialects = ["spirv::SPIRVDialect"];
+}
+
//===----------------------------------------------------------------------===//
// GPUCommon
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 2da79011fa26a3..a775532996e1e5 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(ControlFlow)
add_subdirectory(DLTI)
add_subdirectory(EmitC)
add_subdirectory(Func)
+add_subdirectory(GEN)
add_subdirectory(GPU)
add_subdirectory(Index)
add_subdirectory(IRDL)
diff --git a/mlir/include/mlir/Dialect/GEN/CMakeLists.txt b/mlir/include/mlir/Dialect/GEN/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..b441e52a143487
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Dialect/op declarations and definitions generated from GENOps.td, plus
+# generated documentation for the dialect and its operations.
+add_mlir_dialect(GENOps gen)
+add_mlir_doc(GENDialect GENDialect Dialects/ -gen-dialect-doc)
+add_mlir_doc(GENOps GENOps Dialects/ -gen-op-doc)
+
+# Enum declarations/definitions reachable from GENOps.td (e.g. ShflKind, which
+# GENOps.td pulls in by including GENAttrDefs.td).
+set(LLVM_TARGET_DEFINITIONS GENOps.td)
+mlir_tablegen(GENOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(GENOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRGENOpsEnumsIncGen)
+
+# NOTE(review): GENAttrDefs.td currently defines only an I32EnumAttr and no
+# AttrDef, so these attrdef outputs are presumably empty; confirm they are
+# needed.
+set(LLVM_TARGET_DEFINITIONS GENAttrDefs.td)
+mlir_tablegen(GENOpsAttrDefs.h.inc -gen-attrdef-decls)
+mlir_tablegen(GENOpsAttrDefs.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRGENOpsAttrDefsIncGen)
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td b/mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td
new file mode 100644
index 00000000000000..ef9de1cf581634
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENAttrDefs.td
@@ -0,0 +1,26 @@
+//===-- GENAttrDefs.td - GEN dialect attributes def. file --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GEN_ATTRDEFS
+#define GEN_ATTRDEFS
+
+include "mlir/IR/EnumAttr.td"
+
+/// Enum attribute of the different shuffle kinds.
+/// Based on https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_non_uniform_instructions
+/// Each case selects how `gen.sub_group_shuffle` computes the source
+/// sub-group local ID from the current invocation's ID (`id`) and the mask.
+def GEN_ShflKindAttr : I32EnumAttr<"ShflKind", "GEN shuffle kind",
+ [
+ I32EnumAttrCase<"XOR", 0, "xor">, // id xor mask
+ I32EnumAttrCase<"UP", 1, "up">, // id - mask
+ I32EnumAttrCase<"DOWN", 2, "down">, // id + mask
+ I32EnumAttrCase<"IDX", 3, "idx"> // mask itself
+ ]> {
+ let cppNamespace = "::mlir::GEN";
+}
+
+#endif // GEN_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENDialect.h b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.h
new file mode 100644
index 00000000000000..eee2e7e8b440ad
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.h
@@ -0,0 +1,38 @@
+//===- GENDialect.h - MLIR GEN dialect --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the GEN dialect in MLIR, containing Intel GEN operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GEN_IR_GENDIALECT_H
+#define MLIR_DIALECT_GEN_IR_GENDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/GEN/IR/GENOpsDialect.h.inc"
+
+namespace mlir {
+namespace GEN {
+
+/// GEN memory space identifiers following the SPIR-V storage class convention
+/// https://github.com/KhronosGroup/SPIRV-LLVM-Translator/blob/main/docs/SPIRVRepresentationInLLVM.rst#address-spaces
+///
+enum class GENStorageClass {
+  Function = 0,        // OpenCL workitem address space
+  CrossWorkgroup = 1,  // OpenCL Global memory
+  UniformConstant = 2, // OpenCL Constant memory
+  Workgroup = 3,       // OpenCL Local memory
+  Generic = 4          // OpenCL Generic memory
+};
+
+} // namespace GEN
+} // namespace mlir
+
+#endif // MLIR_DIALECT_GEN_IR_GENDIALECT_H
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENDialect.td b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.td
new file mode 100644
index 00000000000000..3ba6248decb7a4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENDialect.td
@@ -0,0 +1,44 @@
+//===-- GENDialect.td - GEN dialect op definition file -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GEN_DIALECT
+#define GEN_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+// Definition of the `gen` dialect. The generated C++ class lives in
+// `::mlir::GEN`.
+def GEN_Dialect : Dialect {
+ let name = "gen";
+ let cppNamespace = "::mlir::GEN";
+ let summary = "The GEN dialect.";
+
+ let description = [{
+ GEN is a dialect for representing operations on Intel GPUs.
+ }];
+
+ // C++ helpers added to the generated dialect class. They only return the
+ // names of `gen.`-prefixed kernel annotation attributes; no attribute is
+ // defined here.
+ let extraClassDeclaration = [{
+ /// Get the name of the attribute used to annotate max work group size
+ /// required for kernels.
+ static StringRef getMaxWorkGroupSizeAttrName() {
+ return "gen.max_work_group_size";
+ }
+
+ /// Get the name of the attribute used to annotate exact work group size
+ /// required for kernels.
+ static StringRef getReqdWorkGroupSizeAttrName() {
+ return "gen.reqd_work_group_size";
+ }
+
+ /// Get the name for the attribute used to annotate the exact sub group
+ /// size required for kernels.
+ static StringRef getReqdSubGroupSizeAttrName() {
+ return "gen.intel_reqd_sub_group_size";
+ }
+ }];
+}
+
+#endif // GEN_DIALECT
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENOps.h b/mlir/include/mlir/Dialect/GEN/IR/GENOps.h
new file mode 100644
index 00000000000000..be81cd111de7e3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENOps.h
@@ -0,0 +1,26 @@
+//===--- GENOps.h - GEN Dialect Operations ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GEN_IR_GENOPS_H
+#define MLIR_DIALECT_GEN_IR_GENOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/GEN/IR/GENOpsEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOpsAttrDefs.h.inc"
+
+#include "mlir/Dialect/GEN/IR/GENTraits.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOps.h.inc"
+
+#endif // MLIR_DIALECT_GEN_IR_GENOPS_H
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENOps.td b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
new file mode 100644
index 00000000000000..f3328f17899685
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENOps.td
@@ -0,0 +1,130 @@
+//===-- GENOps.td - GEN IR dialect op definition file ------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the GEN IR operation definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GEN_OPS
+#define GEN_OPS
+
+include "mlir/Dialect/GEN/IR/GENDialect.td"
+include "mlir/Dialect/GEN/IR/GENAttrDefs.td"
+include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+
+//===----------------------------------------------------------------------===//
+// GEN op definitions
+//===----------------------------------------------------------------------===//
+
+// Base class for all GEN dialect operations: an `LLVM_OpBase` (from
+// LLVMOpBase.td) bound to the GEN dialect.
+class GEN_Op<string mnemonic, list<Trait> traits = []> :
+ LLVM_OpBase<GEN_Dialect, mnemonic, traits> {
+}
+
+//===----------------------------------------------------------------------===//
+// ND-Range Operations
+//===----------------------------------------------------------------------===//
+
+// Native trait implemented by `mlir::OpTrait::GEN3DNDRange` (GENTraits.h),
+// whose verifier forwards to `detail::verifyGEN3DNDRange`.
+// NOTE(review): presumably that verifier checks `$dim` is a valid dimension
+// (0-2); confirm against GENTraits.cpp.
+def GEN3DNDRange : NativeOpTrait<"GEN3DNDRange">;
+
+// Base class for pure ops that query a per-dimension ND-range value: takes an
+// i32 dimension operand `$dim` and returns an `index`.
+class GEN_3DNDRangeOp<string mnemonic, list<Trait> traits = []>
+ : GEN_Op<mnemonic, [GEN3DNDRange, Pure] # traits>,
+ Arguments<(ins I32:$dim)>,
+ Results<(outs Index:$res)> {
+ let assemblyFormat = "$dim attr-dict";
+}
+
+// Lowered by -convert-gen-to-llvm to a call to `_Z12get_local_idj`.
+def GEN_LocalIdOp : GEN_3DNDRangeOp<"local_id"> {
+ let summary = "Query a work-item's local id.";
+ let description = [{
+ Query the work-item's position in its work-group, i.e., its local id, in a
+ given dimension.
+ ```mlir
+ %local_id = gen.local_id %dim
+ ```
+ }];
+}
+
+// Lowered by -convert-gen-to-llvm to a call to `_Z12get_group_idj`.
+def GEN_WorkGroupIdOp : GEN_3DNDRangeOp<"work_group_id"> {
+ let summary = "Query a work-item's work-group id.";
+ let description = [{
+ Query the work-item's work-group id in a given dimension.
+ ```mlir
+ %work_group_id = gen.work_group_id %dim
+ ```
+ }];
+}
+
+// Lowered by -convert-gen-to-llvm to a call to `_Z14get_local_sizej`.
+def GEN_WorkGroupSizeOp : GEN_3DNDRangeOp<"work_group_size"> {
+ let summary = "Query the work-group size.";
+ let description = [{
+ Query the work-item's work-group size in a given dimension.
+ ```mlir
+ %work_group_size = gen.work_group_size %dim
+ ```
+ }];
+}
+
+// Lowered by -convert-gen-to-llvm to a call to `_Z14get_num_groupsj`.
+def GEN_NumWorkGroupsOp : GEN_3DNDRangeOp<"num_work_groups"> {
+ let summary = "Query the number of work-groups in the ND-range.";
+ let description = [{
+ Query the number of work-groups in the ND-range in a given dimension.
+ ```mlir
+ %wg_number = gen.num_work_groups %dim
+ ```
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Synchronization
+//===----------------------------------------------------------------------===//
+
+def GEN_BarrierOp : GEN_Op<"barrier"> {
+  let summary = "Workgroup barrier";
+
+  let description = [{
+    The `gen.barrier` operation performs a workgroup barrier and ensures all
+    outstanding memory transactions using local or global memory are complete.
+
+    Example:
+    ```mlir
+    gen.barrier
+    ```
+  }];
+
+  let arguments = (ins);
+  let results = (outs);
+  let assemblyFormat = "attr-dict";
+}
+
+def IntegerOrFloatType : AnyTypeOf<[AnySignlessInteger, AnyFloat]>;
+
+def GEN_SubGroupShuffleOp : GEN_Op<"sub_group_shuffle", [
+    TypesMatchWith<"result and value have the same type",
+                   "res", "value", "$_self">]>,
+  Results<(outs IntegerOrFloatType:$res)>,
+  Arguments<(ins IntegerOrFloatType:$value,
+                 I32:$mask,
+                 GEN_ShflKindAttr:$kind)> {
+  let summary = "Subgroup shuffle";
+  let description = [{
+    The `gen.sub_group_shuffle` operation is invoked by different work items
+    with different values, given by $value. Different work items have different
+    subgroup local IDs. The shuffle kind, $kind, is given to determine how to
+    calculate the associated subgroup local ID. It returns the associated
+    $value for the work item with subgroup local ID equal to:
+    - $kind == xor, the current invocation’s subgroup local ID xor’ed with $mask.
+    - $kind == up, the current invocation’s subgroup local ID - $mask.
+    - $kind == down, the current invocation’s subgroup local ID + $mask.
+    - $kind == idx, the subgroup local ID $mask.
+
+    Example:
+    ```mlir
+    %res = gen.sub_group_shuffle xor %value, %mask : f32
+    ```
+  }];
+
+  let assemblyFormat = [{
+    $kind $value `,` $mask attr-dict `:` type($res)
+  }];
+}
+
+#endif // GEN_OPS
diff --git a/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h b/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
new file mode 100644
index 00000000000000..3b7754887749bf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GEN/IR/GENTraits.h
@@ -0,0 +1,30 @@
+//===--- GENTraits.h - GEN Dialect Traits -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GEN_IR_GENTRAITS_H
+#define MLIR_DIALECT_GEN_IR_GENTRAITS_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace detail {
+/// Verifier shared by every op carrying the `GEN3DNDRange` trait.
+/// NOTE(review): presumably defined in lib/Dialect/GEN/IR/GENTraits.cpp;
+/// its exact checks are not visible here.
+LogicalResult verifyGEN3DNDRange(Operation *op);
+} // namespace detail
+
+/// Native op trait used by the GEN ND-range query ops (`GEN3DNDRange` in
+/// GENOps.td). Its `verifyTrait` hook forwards to
+/// `detail::verifyGEN3DNDRange`.
+template <typename ConcreteType>
+class GEN3DNDRange : public TraitBase<ConcreteType, GEN3DNDRange> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ return detail::verifyGEN3DNDRange(op);
+ }
+};
+} // namespace OpTrait
+} // namespace mlir
+
+#endif // MLIR_DIALECT_GEN_IR_GENTRAITS_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index c558dc53cc7fac..a940d134024165 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -36,6 +36,7 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
@@ -116,6 +117,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
DLTIDialect,
emitc::EmitCDialect,
func::FuncDialect,
+ GEN::GENDialect,
gpu::GPUDialect,
index::IndexDialect,
irdl::IRDLDialect,
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 41ab7046b91ce3..0d59ce960090c4 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -22,6 +22,8 @@ add_subdirectory(ConvertToLLVM)
add_subdirectory(FuncToEmitC)
add_subdirectory(FuncToLLVM)
add_subdirectory(FuncToSPIRV)
+add_subdirectory(GENToLLVM)
+add_subdirectory(GENToSPIRV)
add_subdirectory(GPUCommon)
add_subdirectory(GPUToNVVM)
add_subdirectory(GPUToROCDL)
diff --git a/mlir/lib/Conversion/GENToLLVM/CMakeLists.txt b/mlir/lib/Conversion/GENToLLVM/CMakeLists.txt
new file mode 100644
index 00000000000000..4ec773ccf46dd0
--- /dev/null
+++ b/mlir/lib/Conversion/GENToLLVM/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Conversion library implementing -convert-gen-to-llvm. It DEPENDS on the
+# conversion pass tablegen target because GENToLLVMPass.cpp includes the
+# generated Passes.h.inc.
+add_mlir_conversion_library(MLIRGENToLLVM
+ GENToLLVMPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/GENToLLVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRGENDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+)
diff --git a/mlir/lib/Conversion/GENToLLVM/GENToLLVMPass.cpp b/mlir/lib/Conversion/GENToLLVM/GENToLLVMPass.cpp
new file mode 100644
index 00000000000000..244817d8cf00c6
--- /dev/null
+++ b/mlir/lib/Conversion/GENToLLVM/GENToLLVMPass.cpp
@@ -0,0 +1,280 @@
+//===- GENToLLVMPass.cpp - MLIR GEN to LLVM dialect conversion ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GENToLLVM/GENToLLVMPass.h"
+
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
+#include "mlir/Dialect/GEN/IR/GENOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTGENTOLLVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Helper Functions
+//===----------------------------------------------------------------------===//
+
+/// Emit, at the rewriter's current insertion point, a call to the device
+/// function \p funcName taking \p args.
+///
+/// The callee is looked up in the `ModuleOp` enclosing the rewriter's current
+/// block. If absent, an `llvm.func` declaration of type
+/// `retType(argTypes...)` with the SPIR_FUNC calling convention is created at
+/// the start of the module body. When \p convergent is set, `"convergent"` is
+/// added as a `passthrough` entry on the newly created declaration and on the
+/// call.
+/// NOTE(review): an existing symbol is `cast` to `LLVM::LLVMFuncOp` without
+/// checking its kind or signature, so a non-LLVM function with the same name
+/// would assert. Both the declaration and the call use `UnknownLoc`, so the
+/// lowered op's location is not propagated.
+static LLVM::CallOp createDeviceFunctionCall(
+ ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
+ ArrayRef<Type> argTypes, ArrayRef<Value> args, bool convergent = false) {
+ auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
+ MLIRContext *context = rewriter.getContext();
+ Location loc = UnknownLoc::get(context);
+ auto convergentAttr =
+ rewriter.getArrayAttr(StringAttr::get(context, "convergent"));
+
+ // Reuse an existing declaration or create one at the top of the module.
+ auto getOrCreateFunction = [&](StringRef funcName) {
+ Operation *funcOp = moduleOp.lookupSymbol(funcName);
+ if (funcOp)
+ return cast<LLVM::LLVMFuncOp>(funcOp);
+
+ auto funcType = LLVM::LLVMFunctionType::get(retType, argTypes);
+ // Restore the caller's insertion point once the declaration is created.
+ ConversionPatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(moduleOp.getBody());
+ auto func = rewriter.create<LLVM::LLVMFuncOp>(loc, funcName, funcType);
+ func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+ if (convergent)
+ func.setPassthroughAttr(convergentAttr);
+
+ return func;
+ };
+
+ LLVM::LLVMFuncOp funcOp = getOrCreateFunction(funcName);
+ auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
+ if (convergent)
+ callOp->setAttr("passthrough", convergentAttr);
+
+ return callOp;
+}
+
+/// Create a convergent call to the device builtin implementing a sub-group
+/// shuffle of \p value by \p mask with the given \p kind.
+///
+/// The callee name is Itanium-mangled by hand: the base name selected by
+/// \p kind, then the mangling of \p value's type (`Dh`, `f`, `d` for f16, f32,
+/// f64; `c`, `s`, `i`, `l` for i8, i16, i32, i64), then `j` for the
+/// `unsigned int` mask. Any other value type is a programming error: callers
+/// must reject it beforehand, because an incompletely mangled name would
+/// silently target the wrong builtin.
+static LLVM::CallOp createSubGroupShuffle(ConversionPatternRewriter &rewriter,
+                                          Value value, Value mask,
+                                          GEN::ShflKind kind) {
+  assert(isa<IntegerType>(mask.getType()) &&
+         cast<IntegerType>(mask.getType()).isInteger(32) &&
+         "Expecting mask type to be i32");
+
+  std::string fnName = "";
+  switch (kind) {
+  case GEN::ShflKind::XOR:
+    fnName = "_Z21sub_group_shuffle_xor";
+    break;
+  case GEN::ShflKind::UP:
+    fnName = "_Z20sub_group_shuffle_up";
+    break;
+  case GEN::ShflKind::DOWN:
+    fnName = "_Z22sub_group_shuffle_down";
+    break;
+  case GEN::ShflKind::IDX:
+    fnName = "_Z17sub_group_shuffle";
+    break;
+  }
+
+  TypeSwitch<Type>(value.getType())
+      .Case<Float16Type>([&](auto) { fnName += "Dh"; })
+      .Case<Float32Type>([&](auto) { fnName += "f"; })
+      .Case<Float64Type>([&](auto) { fnName += "d"; })
+      .Case<IntegerType>([&](auto ty) {
+        switch (ty.getWidth()) {
+        case 8:
+          fnName += "c";
+          break;
+        case 16:
+          fnName += "s";
+          break;
+        case 32:
+          fnName += "i";
+          break;
+        case 64:
+          fnName += "l";
+          break;
+        default:
+          llvm_unreachable("unhandled integer type");
+        }
+      })
+      // Other float types (e.g. bf16) have no mangling here; never fall
+      // through and emit a name without the value-type component.
+      .Default([](Type) { llvm_unreachable("unhandled value type"); });
+
+  fnName += "j";
+
+  return createDeviceFunctionCall(rewriter, fnName, value.getType(),
+                                  {value.getType(), mask.getType()},
+                                  {value, mask}, true /*convergent*/);
+}
+
+/// Materialize the 32-bit integer \p v as an `llvm.mlir.constant` at \p loc.
+static Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v) {
+  IntegerType int32Type = rewriter.getIntegerType(32);
+  IntegerAttr valueAttr = IntegerAttr::get(int32Type, v);
+  return rewriter.create<LLVM::ConstantOp>(loc, int32Type, valueAttr);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ND-range Ops Lowerings
+//===----------------------------------------------------------------------===//
+
+/// Shared lowering for the ND-range query ops: replaces the op with a
+/// (non-convergent) call `builtinName(i32 dim)` whose result type is the
+/// type-converted result type of the op.
+class GEN3DNDRangeLoweringBase : public ConvertToLLVMPattern {
+public:
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ assert(op->getNumOperands() == 1 && "Expecting a single operand");
+ // NOTE(review): the converted type is not checked for null; this relies
+ // on the `index` result always being convertible.
+ Type resType = typeConverter->convertType(op->getResult(0).getType());
+ LLVM::CallOp callOp = createDeviceFunctionCall(
+ rewriter, builtinName, resType, rewriter.getI32Type(), operands[0]);
+ rewriter.replaceOp(op, callOp);
+ return success();
+ }
+
+protected:
+ /// \p builtinName is the mangled callee; \p rootOpName selects which op
+ /// this pattern matches.
+ GEN3DNDRangeLoweringBase(StringRef builtinName, StringRef rootOpName,
+ const LLVMTypeConverter &typeConverter,
+ PatternBenefit benefit)
+ : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
+ typeConverter, benefit),
+ builtinName(builtinName) {}
+
+private:
+ // Non-owning; the derived pattern passes a string literal from
+ // getBuiltinName, so the referenced storage outlives the pattern.
+ StringRef builtinName;
+};
+
+/// Mangled name of the device builtin each ND-range op lowers to. Only the
+/// explicit specializations below are defined; instantiating it for any other
+/// op type fails to build.
+/// NOTE(review): the primary template is declared `constexpr`, but the
+/// specializations are not; it is only called at runtime, so consider
+/// dropping `constexpr` for consistency.
+template <typename SourceOp>
+constexpr StringRef getBuiltinName();
+
+// get_local_id(unsigned int)
+template <>
+StringRef getBuiltinName<GEN::LocalIdOp>() {
+ return "_Z12get_local_idj";
+}
+
+// get_group_id(unsigned int)
+template <>
+StringRef getBuiltinName<GEN::WorkGroupIdOp>() {
+ return "_Z12get_group_idj";
+}
+
+// get_local_size(unsigned int)
+template <>
+StringRef getBuiltinName<GEN::WorkGroupSizeOp>() {
+ return "_Z14get_local_sizej";
+}
+
+// get_num_groups(unsigned int)
+template <>
+StringRef getBuiltinName<GEN::NumWorkGroupsOp>() {
+ return "_Z14get_num_groupsj";
+}
+
+/// ND-range lowering specialized for \p SourceOp: matches ops named
+/// `SourceOp::getOperationName()` and calls `getBuiltinName<SourceOp>()`.
+template <typename SourceOp>
+struct GEN3DNDRangeLowering : public GEN3DNDRangeLoweringBase {
+  GEN3DNDRangeLowering(const LLVMTypeConverter &converter,
+                       PatternBenefit benefit = 1)
+      : GEN3DNDRangeLoweringBase(
+            /*builtinName=*/getBuiltinName<SourceOp>(),
+            /*rootOpName=*/SourceOp::getOperationName(),
+            /*typeConverter=*/converter,
+            /*benefit=*/benefit) {}
+};
+
+//===----------------------------------------------------------------------===//
+// Synchronization Ops Lowerings
+//===----------------------------------------------------------------------===//
+
+/// Lower `gen.barrier` to a convergent call to `_Z7barrierj`
+/// (`barrier(unsigned int)`) taking a constant fence-flags argument.
+struct GENBarrierLowering : public ConvertOpToLLVMPattern<GEN::BarrierOp> {
+ using ConvertOpToLLVMPattern<GEN::BarrierOp>::ConvertOpToLLVMPattern;
+
+ // Fence flag values passed to the barrier builtin.
+ enum MemFence {
+ Local = 0x01,
+ Global = 0x02,
+ };
+
+ LogicalResult
+ matchAndRewrite(GEN::BarrierOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto retType = LLVM::LLVMVoidType::get(rewriter.getContext());
+ auto argType = rewriter.getIntegerType(32);
+ // NOTE(review): only the Local fence flag is passed, while the op's text in
+ // GENOps.td mentions local or global memory; confirm whether
+ // `Local | Global` is intended.
+ auto arg = createConstantI32(op->getLoc(), rewriter, MemFence::Local);
+ LLVM::CallOp callOp =
+ createDeviceFunctionCall(rewriter, "_Z7barrierj", {retType}, {argType},
+ {arg}, true /*convergent*/);
+ rewriter.replaceOp(op, callOp);
+ return success();
+ }
+};
+
+/// Lower `gen.sub_group_shuffle` to a convergent call to the matching
+/// `sub_group_shuffle*` device builtin.
+///
+/// Only value types with a known mangling are supported (f16, f32, f64, i8,
+/// i16, i32, i64). For any other type accepted by the op verifier (e.g. bf16,
+/// i1 or i128), the pattern reports a match failure instead of crashing or
+/// calling a builtin with a wrong mangled name.
+struct SubGroupShuffleLowering
+    : public ConvertOpToLLVMPattern<GEN::SubGroupShuffleOp> {
+  using ConvertOpToLLVMPattern<GEN::SubGroupShuffleOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(GEN::SubGroupShuffleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Use the adaptor's (already converted) operands, not the original ones,
+    // as required by the dialect conversion framework.
+    Value val = adaptor.getValue();
+    Value mask = adaptor.getMask();
+    if (!isSupportedShuffleType(val.getType()))
+      return rewriter.notifyMatchFailure(
+          op, "unsupported sub-group shuffle value type");
+
+    GEN::ShflKind kind = op.getKind();
+    LLVM::CallOp callOp = createSubGroupShuffle(rewriter, val, mask, kind);
+    rewriter.replaceOp(op, callOp);
+    return success();
+  }
+
+private:
+  /// Whether \p type has a builtin mangling in createSubGroupShuffle.
+  static bool isSupportedShuffleType(Type type) {
+    if (isa<Float16Type, Float32Type, Float64Type>(type))
+      return true;
+    if (auto intTy = dyn_cast<IntegerType>(type)) {
+      unsigned width = intTy.getWidth();
+      return width == 8 || width == 16 || width == 32 || width == 64;
+    }
+    return false;
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pass lowering every GEN dialect operation in the module to LLVM dialect
+/// calls to device builtins.
+struct ConvertGENToLLVM final
+    : public impl::ConvertGENToLLVMBase<ConvertGENToLLVM> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    LowerToLLVMOptions options(context);
+    LLVMTypeConverter converter(context, options);
+    LLVMConversionTarget target(*context);
+    // Mark GEN ops illegal so that any op no pattern can lower is reported as
+    // a legalization failure instead of being silently left in the IR.
+    target.addIllegalDialect<GEN::GENDialect>();
+
+    GEN::populateGENToLLVMConversionPatterns(converter, patterns);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population and Registration
+//===----------------------------------------------------------------------===//
+
+/// Register every GEN-to-LLVM lowering pattern, each built with \p converter.
+void mlir::GEN::populateGENToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  // ND-range queries.
+  patterns.add<GEN3DNDRangeLowering<GEN::LocalIdOp>,
+               GEN3DNDRangeLowering<GEN::WorkGroupIdOp>,
+               GEN3DNDRangeLowering<GEN::WorkGroupSizeOp>,
+               GEN3DNDRangeLowering<GEN::NumWorkGroupsOp>>(converter);
+  // Synchronization and sub-group operations.
+  patterns.add<GENBarrierLowering, SubGroupShuffleLowering>(converter);
+}
+
+/// Factory for the `-convert-gen-to-llvm` pass.
+std::unique_ptr<Pass> mlir::GEN::createConvertGENToLLVM() {
+  auto pass = std::make_unique<ConvertGENToLLVM>();
+  return pass;
+}
diff --git a/mlir/lib/Conversion/GENToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GENToSPIRV/CMakeLists.txt
new file mode 100644
index 00000000000000..3afb89ad958a64
--- /dev/null
+++ b/mlir/lib/Conversion/GENToSPIRV/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Conversion library providing the GEN-to-SPIR-V patterns and the
+# -convert-gen-to-spirv pass implemented in GENToSPIRV.cpp.
+add_mlir_conversion_library(MLIRGENToSPIRV
+ GENToSPIRV.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/GENToSPIRV
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRGENDialect
+ MLIRSPIRVConversion
+ MLIRSPIRVDialect
+ )
diff --git a/mlir/lib/Conversion/GENToSPIRV/GENToSPIRV.cpp b/mlir/lib/Conversion/GENToSPIRV/GENToSPIRV.cpp
new file mode 100644
index 00000000000000..37077c6f1cab8f
--- /dev/null
+++ b/mlir/lib/Conversion/GENToSPIRV/GENToSPIRV.cpp
@@ -0,0 +1,131 @@
+//===- GENToSPIRV.cpp - GEN to SPIRV dialect conversion -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GENToSPIRV/GENToSPIRV.h"
+
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
+#include "mlir/Dialect/GEN/IR/GENOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTGENTOSPIRV
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "gen-to-spirv-pattern"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ND-range Ops Lowerings
+//===----------------------------------------------------------------------===//
+
+/// Pattern to convert GEN3DNDRange operations to SPIR-V.
+///
+/// Convert:
+/// ```mlir
+/// %0 = gen.operation_name %dim
+/// ```
+/// To:
+/// ```mlir
+/// %__spirv_BuiltinName___addr = spirv.mlir.addressof
+/// @__spirv_BuiltInBuiltinName : !spirv.ptr<vector<3xIndexType>, Input>
+/// %__builtin_value = spirv.Load "Input" %__builtin__BuiltinName___addr :
+/// vector<3xIndexType>
+/// %0 = spirv.VectorExtractDynamic %__builtin_value[%dim] :
+/// vector<3xIndexType>, i32
+/// ```
+/// With `BuiltinName` the name of a SPIR-V builtin, and `IndexType`, `i32` for
+/// 32-bit targets and `i64` for 64-bit targets.
+namespace {
+/// Pattern converting a GEN 3D ND-range query operation to SPIR-V.
+///
+/// The SPIR-V builtin vector is loaded and the requested dimension is
+/// extracted from it dynamically.
+///
+/// Convert:
+/// ```mlir
+/// %0 = gen.operation_name %dim
+/// ```
+/// To:
+/// ```mlir
+/// %addr = spirv.mlir.addressof @__spirv_BuiltInBuiltinName
+///     : !spirv.ptr<vector<3xIndexType>, Input>
+/// %value = spirv.Load "Input" %addr : vector<3xIndexType>
+/// %0 = spirv.VectorExtractDynamic %value[%dim] : vector<3xIndexType>, i32
+/// ```
+/// Here `BuiltinName` is the name of a SPIR-V builtin. `IndexType` is `i32`
+/// for 32-bit targets and `i64` for 64-bit targets.
+class GEN3DNDRangeLoweringBase : public ConversionPattern {
+public:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(operands.size() == 1 && "Expecting a single operand");
+    // The builtin variable must be of type <3xi32> for 32-bit targets and
+    // <3xi64> for 64-bit targets.
+    Type builtinType = getTypeConverter<SPIRVTypeConverter>()->getIndexType();
+    // Globals are named `__spirv_BuiltIn<Name>` with no suffix.
+    constexpr StringLiteral spvBuiltinPrefix = "__spirv_BuiltIn";
+    constexpr StringLiteral spvBuiltinSuffix = "";
+    Value vector = spirv::getBuiltinVariableValue(
+        op, builtin, builtinType, rewriter, spvBuiltinPrefix, spvBuiltinSuffix);
+    rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(op, vector,
+                                                               operands[0]);
+    return success();
+  }
+
+protected:
+  GEN3DNDRangeLoweringBase(spirv::BuiltIn builtin,
+                           const TypeConverter &typeConverter, StringRef opName,
+                           PatternBenefit benefit, MLIRContext *context)
+      : ConversionPattern(typeConverter, opName, benefit, context),
+        builtin(builtin) {}
+
+private:
+  /// SPIR-V builtin variable this pattern reads.
+  spirv::BuiltIn builtin;
+};
+
+/// Binds `GEN3DNDRangeLoweringBase` to the GEN operation `SourceOp` and the
+/// SPIR-V builtin `Builtin` it is lowered to.
+template <typename SourceOp, spirv::BuiltIn Builtin>
+struct GEN3DNDRangeLowering : public GEN3DNDRangeLoweringBase {
+  GEN3DNDRangeLowering(const TypeConverter &typeConverter, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : GEN3DNDRangeLoweringBase(Builtin, typeConverter,
+                                 SourceOp::getOperationName(), benefit,
+                                 context) {}
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+
+/// Adds the GEN-to-SPIR-V lowering patterns to `patterns`. Each GEN ND-range
+/// query is mapped onto the SPIR-V builtin vector it reads.
+void mlir::GEN::populateGENToSPIRVPatterns(SPIRVTypeConverter &converter,
+                                           RewritePatternSet &patterns) {
+  using LocalIdLowering =
+      GEN3DNDRangeLowering<GEN::LocalIdOp, spirv::BuiltIn::LocalInvocationId>;
+  using WorkGroupIdLowering =
+      GEN3DNDRangeLowering<GEN::WorkGroupIdOp, spirv::BuiltIn::WorkgroupId>;
+  using WorkGroupSizeLowering =
+      GEN3DNDRangeLowering<GEN::WorkGroupSizeOp,
+                           spirv::BuiltIn::WorkgroupSize>;
+  using NumWorkGroupsLowering =
+      GEN3DNDRangeLowering<GEN::NumWorkGroupsOp,
+                           spirv::BuiltIn::NumWorkgroups>;
+
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LocalIdLowering, WorkGroupIdLowering, WorkGroupSizeLowering,
+               NumWorkGroupsLowering>(converter, context);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pass converting GEN operations to the SPIR-V dialect. It uses the SPIR-V
+/// target environment found on (or defaulted for) the root operation.
+struct ConvertGENToSPIRVPass
+    : public impl::ConvertGENToSPIRVBase<ConvertGENToSPIRVPass> {
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(root);
+
+    SPIRVTypeConverter typeConverter(targetEnv);
+    RewritePatternSet patterns(&getContext());
+    GEN::populateGENToSPIRVPatterns(typeConverter, patterns);
+
+    std::unique_ptr<SPIRVConversionTarget> target =
+        SPIRVConversionTarget::get(targetEnv);
+    // Any GEN op left after conversion makes the pass fail.
+    target->addIllegalDialect<GEN::GENDialect>();
+
+    if (failed(applyPartialConversion(root, *target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Creates the GEN-to-SPIR-V conversion pass (`-convert-gen-to-spirv`).
+std::unique_ptr<OperationPass<>> mlir::GEN::createConvertGENToSPIRVPass() {
+  std::unique_ptr<ConvertGENToSPIRVPass> pass =
+      std::make_unique<ConvertGENToSPIRVPass>();
+  return pass;
+}
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index b1ba5a3bc8817d..72e11b4eb33dfa 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(ControlFlow)
add_subdirectory(DLTI)
add_subdirectory(EmitC)
add_subdirectory(Func)
+add_subdirectory(GEN)
add_subdirectory(GPU)
add_subdirectory(Index)
add_subdirectory(IRDL)
diff --git a/mlir/lib/Dialect/GEN/CMakeLists.txt b/mlir/lib/Dialect/GEN/CMakeLists.txt
new file mode 100644
index 00000000000000..f33061b2d87cff
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/GEN/IR/CMakeLists.txt b/mlir/lib/Dialect/GEN/IR/CMakeLists.txt
new file mode 100644
index 00000000000000..160c6eb2c53bd5
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/CMakeLists.txt
@@ -0,0 +1,17 @@
+# GEN dialect IR library: dialect registration, operations and traits. It
+# depends on the TableGen-generated op, enum and attribute definitions.
+add_mlir_dialect_library(MLIRGENDialect
+ GENDialect.cpp
+ GENOps.cpp
+ GENTraits.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GEN
+
+ DEPENDS
+ MLIRGENOpsIncGen
+ MLIRGENOpsEnumsIncGen
+ MLIRGENOpsAttrDefsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRLLVMDialect
+)
diff --git a/mlir/lib/Dialect/GEN/IR/GENDialect.cpp b/mlir/lib/Dialect/GEN/IR/GENDialect.cpp
new file mode 100644
index 00000000000000..351e5a32a67505
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/GENDialect.cpp
@@ -0,0 +1,42 @@
+//===- GENDialect.cpp - MLIR GEN Dialect implementation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
+#include "mlir/Dialect/GEN/IR/GENOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+using namespace mlir::GEN;
+
+#include "mlir/Dialect/GEN/IR/GENOpsDialect.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// GEN dialect.
+//===----------------------------------------------------------------------===//
+
+/// Registers every TableGen-generated GEN operation and attribute with the
+/// dialect.
+void GENDialect::initialize() {
+ // GET_OP_LIST makes GENOps.cpp.inc expand to the op classes, which are used
+ // here as the template argument list.
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/GEN/IR/GENOps.cpp.inc"
+ >();
+ // GET_ATTRDEF_LIST likewise expands to the attribute classes.
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/GEN/IR/GENOpsAttrDefs.cpp.inc"
+ >();
+}
diff --git a/mlir/lib/Dialect/GEN/IR/GENOps.cpp b/mlir/lib/Dialect/GEN/IR/GENOps.cpp
new file mode 100644
index 00000000000000..f7827e0520fb27
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/GENOps.cpp
@@ -0,0 +1,17 @@
+//===- GENOps.cpp - GEN dialect operations --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GEN/IR/GENOps.h"
+#include "mlir/Dialect/GEN/IR/GENDialect.h"
+#include "mlir/IR/Builders.h"
+
+#include "mlir/Dialect/GEN/IR/GENOpsEnums.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOpsAttrDefs.cpp.inc"
+#define GET_OP_CLASSES
+#include "mlir/Dialect/GEN/IR/GENOps.cpp.inc"
diff --git a/mlir/lib/Dialect/GEN/IR/GENTraits.cpp b/mlir/lib/Dialect/GEN/IR/GENTraits.cpp
new file mode 100644
index 00000000000000..d1a01675c3d08a
--- /dev/null
+++ b/mlir/lib/Dialect/GEN/IR/GENTraits.cpp
@@ -0,0 +1,24 @@
+//===- GENTraits.cpp - GEN dialect traits ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GEN/IR/GENTraits.h"
+
+#include "mlir/IR/Matchers.h"
+
+using namespace mlir;
+
+/// Verifies the dimension operand (operand 0) of a 3D ND-range op.
+///
+/// When the operand is a constant integer, it must lie in [0, 3). Non-constant
+/// operands cannot be checked statically and are accepted.
+LogicalResult mlir::OpTrait::detail::verifyGEN3DNDRange(Operation *op) {
+  llvm::APInt dim;
+  if (!matchPattern(op->getOperand(0), m_ConstantInt(&dim)))
+    return success();
+
+  // Only dimensions 0, 1 and 2 exist in a 3D ND-range.
+  if (dim.slt(0) || dim.sge(3))
+    return op->emitOpError()
+           << "input dimension must be in the range [0, 3). Got "
+           << dim.getSExtValue();
+
+  return success();
+}
diff --git a/mlir/test/Conversion/GENToLLVM/gen-to-llvm.mlir b/mlir/test/Conversion/GENToLLVM/gen-to-llvm.mlir
new file mode 100644
index 00000000000000..16679487032212
--- /dev/null
+++ b/mlir/test/Conversion/GENToLLVM/gen-to-llvm.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt -convert-gen-to-llvm -split-input-file %s | FileCheck %s
+
+// Each ND-range query op is lowered to a call to the matching mangled get_*
+// builtin, with the dimension operand forwarded unchanged.
+llvm.func @gen_nd_range(%dim: i32) {
+ // CHECK-LABEL: gen_nd_range
+ // CHECK-SAME: (%[[DIM:.*]]: i32)
+ // CHECK: llvm.call @_Z12get_local_idj(%[[DIM]]) : (i32) -> i64
+ %0 = gen.local_id %dim
+ // CHECK: llvm.call @_Z12get_group_idj(%[[DIM]]) : (i32) -> i64
+ %1 = gen.work_group_id %dim
+ // CHECK: llvm.call @_Z14get_local_sizej(%[[DIM]]) : (i32) -> i64
+ %2 = gen.work_group_size %dim
+ // CHECK: llvm.call @_Z14get_num_groupsj(%[[DIM]]) : (i32) -> i64
+ %3 = gen.num_work_groups %dim
+ llvm.return
+}
+
+// -----
+
+// CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {passthrough = ["convergent"]}
+
+// The barrier op is lowered to a convergent call to the mangled barrier
+// builtin, passing a constant i32 1.
+llvm.func @gen.barrier() {
+ // CHECK-LABEL: gen.barrier
+ // CHECK: [[CST:%.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.call @_Z7barrierj([[CST]]) {passthrough = ["convergent"]} : (i32) -> ()
+ gen.barrier
+ llvm.return
+}
+
+// -----
+
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xordj(f64, i32) -> f64 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorfj(f32, i32) -> f32 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorDhj(f16, i32) -> f16 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorlj(i64, i32) -> i64 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorsj(i16, i32) -> i16 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorcj(i8, i32) -> i8 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z17sub_group_shuffleij(i32, i32) -> i32 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z22sub_group_shuffle_downij(i32, i32) -> i32 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z20sub_group_shuffle_upij(i32, i32) -> i32 attributes {passthrough = ["convergent"]}
+// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorij(i32, i32) -> i32 attributes {passthrough = ["convergent"]}
+
+// Covers every shuffle kind on i32, plus the xor kind on each other element
+// type. Each case becomes a convergent call to the mangled builtin for that
+// kind and type.
+llvm.func @gen.sub_group_shuffle() {
+ // CHECK-LABEL: gen.sub_group_shuffle
+ %0 = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: llvm.call @_Z21sub_group_shuffle_xorij([[ZERO]], [[ZERO]]) {passthrough = ["convergent"]} : (i32, i32) -> i32
+ // CHECK: llvm.call @_Z20sub_group_shuffle_upij([[ZERO]], [[ZERO]]) {passthrough = ["convergent"]} : (i32, i32) -> i32
+ // CHECK: llvm.call @_Z22sub_group_shuffle_downij([[ZERO]], [[ZERO]]) {passthrough = ["convergent"]} : (i32, i32) -> i32
+ // CHECK: llvm.call @_Z17sub_group_shuffleij([[ZERO]], [[ZERO]]) {passthrough = ["convergent"]} : (i32, i32) -> i32
+ %1 = gen.sub_group_shuffle xor %0, %0 : i32
+ %2 = gen.sub_group_shuffle up %0, %0 : i32
+ %3 = gen.sub_group_shuffle down %0, %0 : i32
+ %4 = gen.sub_group_shuffle idx %0, %0 : i32
+
+ // CHECK: [[ZERO1:%.*]] = llvm.mlir.constant(0 : i8) : i8
+ // CHECK: llvm.call @_Z21sub_group_shuffle_xorcj([[ZERO1]], [[ZERO]]) {passthrough = ["convergent"]} : (i8, i32) -> i8
+ %5 = llvm.mlir.constant(0 : i8) : i8
+ %6 = gen.sub_group_shuffle xor %5, %0 : i8
+
+ // CHECK: [[ZERO2:%.*]] = llvm.mlir.constant(0 : i16) : i16
+ // CHECK: llvm.call @_Z21sub_group_shuffle_xorsj([[ZERO2]], [[ZERO]]) {passthrough = ["convergent"]} : (i16, i32) -> i16
+ %7 = llvm.mlir.constant(0 : i16) : i16
+ %8 = gen.sub_group_shuffle xor %7, %0 : i16
+
+ // CHECK: [[ZERO3:%.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: llvm.call @_Z21sub_group_shuffle_xorlj([[ZERO3]], [[ZERO]]) {passthrough = ["convergent"]} : (i64, i32) -> i64
+ %9 = llvm.mlir.constant(0 : i64) : i64
+ %10 = gen.sub_group_shuffle xor %9, %0 : i64
+
+ // CHECK: [[ZERO4:%.*]] = llvm.mlir.constant(0.000000e+00 : f16) : f16
+ // CHECK: llvm.call @_Z21sub_group_shuffle_xorDhj([[ZERO4]], [[ZERO]]) {passthrough = ["convergent"]} : (f16, i32) -> f16
+ %11 = llvm.mlir.constant(0.0 : f16) : f16
+ %12 = gen.sub_group_shuffle xor %11, %0 : f16
+
+ // CHECK: [[ZERO5:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+ // CHECK: llvm.call @_Z21sub_group_shuffle_xorfj([[ZERO5]], [[ZERO]]) {passthrough = ["convergent"]} : (f32, i32) -> f32
+ %13 = llvm.mlir.constant(0.0 : f32) : f32
+ %14 = gen.sub_group_shuffle xor %13, %0 : f32
+
+ // CHECK: [[ZERO6:%.*]] = llvm.mlir.constant(0.000000e+00 : f64) : f64
+ // CHECK: llvm.call @_Z21sub_group_shuffle_xordj([[ZERO6]], [[ZERO]]) {passthrough = ["convergent"]} : (f64, i32) -> f64
+ %15 = llvm.mlir.constant(0.0 : f64) : f64
+ %16 = gen.sub_group_shuffle xor %15, %0 : f64
+ llvm.return
+}
diff --git a/mlir/test/Conversion/GENToSPIRV/gen-to-spirv.mlir b/mlir/test/Conversion/GENToSPIRV/gen-to-spirv.mlir
new file mode 100644
index 00000000000000..68f6b8efcdadf6
--- /dev/null
+++ b/mlir/test/Conversion/GENToSPIRV/gen-to-spirv.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -convert-gen-to-spirv -split-input-file %s | FileCheck %s
+
+// CHECK-DAG: spirv.GlobalVariable @__spirv_BuiltInNumWorkgroups built_in("NumWorkgroups") : !spirv.ptr<vector<3xi32>, Input>
+// CHECK-DAG: spirv.GlobalVariable @__spirv_BuiltInWorkgroupSize built_in("WorkgroupSize") : !spirv.ptr<vector<3xi32>, Input>
+// CHECK-DAG: spirv.GlobalVariable @__spirv_BuiltInWorkgroupId built_in("WorkgroupId") : !spirv.ptr<vector<3xi32>, Input>
+// CHECK-DAG: spirv.GlobalVariable @__spirv_BuiltInLocalInvocationId built_in("LocalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
+
+// Each ND-range query loads its SPIR-V builtin vector and extracts the
+// requested dimension with spirv.VectorExtractDynamic.
+// CHECK-LABEL: func.func @gen_nd_range(
+// CHECK-SAME: %[[VAL_0:.*]]: i32) {
+func.func @gen_nd_range(%dim: i32) {
+// CHECK: %[[VAL_1:.*]] = spirv.mlir.addressof @__spirv_BuiltInLocalInvocationId : !spirv.ptr<vector<3xi32>, Input>
+// CHECK: %[[VAL_2:.*]] = spirv.Load "Input" %[[VAL_1]] : vector<3xi32>
+// CHECK: %[[VAL_3:.*]] = spirv.VectorExtractDynamic %[[VAL_2]]{{\[}}%[[VAL_0]]] : vector<3xi32>, i32
+ %0 = gen.local_id %dim
+// CHECK: %[[VAL_4:.*]] = spirv.mlir.addressof @__spirv_BuiltInWorkgroupId : !spirv.ptr<vector<3xi32>, Input>
+// CHECK: %[[VAL_5:.*]] = spirv.Load "Input" %[[VAL_4]] : vector<3xi32>
+// CHECK: %[[VAL_6:.*]] = spirv.VectorExtractDynamic %[[VAL_5]]{{\[}}%[[VAL_0]]] : vector<3xi32>, i32
+ %1 = gen.work_group_id %dim
+// CHECK: %[[VAL_7:.*]] = spirv.mlir.addressof @__spirv_BuiltInWorkgroupSize : !spirv.ptr<vector<3xi32>, Input>
+// CHECK: %[[VAL_8:.*]] = spirv.Load "Input" %[[VAL_7]] : vector<3xi32>
+// CHECK: %[[VAL_9:.*]] = spirv.VectorExtractDynamic %[[VAL_8]]{{\[}}%[[VAL_0]]] : vector<3xi32>, i32
+ %2 = gen.work_group_size %dim
+// CHECK: %[[VAL_10:.*]] = spirv.mlir.addressof @__spirv_BuiltInNumWorkgroups : !spirv.ptr<vector<3xi32>, Input>
+// CHECK: %[[VAL_11:.*]] = spirv.Load "Input" %[[VAL_10]] : vector<3xi32>
+// CHECK: %[[VAL_12:.*]] = spirv.VectorExtractDynamic %[[VAL_11]]{{\[}}%[[VAL_0]]] : vector<3xi32>, i32
+ %3 = gen.num_work_groups %dim
+ func.return
+}
diff --git a/mlir/test/Dialect/GEN/gen.mlir b/mlir/test/Dialect/GEN/gen.mlir
new file mode 100644
index 00000000000000..e4d3dc76d4b560
--- /dev/null
+++ b/mlir/test/Dialect/GEN/gen.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// Round-trips the ND-range query ops through the parser and printer.
+llvm.func @gen_nd_range(%dim: i32) {
+ // CHECK-LABEL: gen_nd_range
+ // CHECK-SAME: (%[[DIM:.*]]: i32)
+ // CHECK: gen.local_id %[[DIM]]
+ %0 = gen.local_id %dim
+ // CHECK: gen.work_group_id %[[DIM]]
+ %1 = gen.work_group_id %dim
+ // CHECK: gen.work_group_size %[[DIM]]
+ %2 = gen.work_group_size %dim
+ // CHECK: gen.num_work_groups %[[DIM]]
+ %3 = gen.num_work_groups %dim
+ llvm.return
+}
+
+// Round-trips the operand-less barrier op.
+llvm.func @gen.barrier() {
+ // CHECK-LABEL: gen.barrier
+ // CHECK: gen.barrier
+ gen.barrier
+ llvm.return
+}
+
+// Round-trips every shuffle kind on i32, plus the xor kind on each other
+// supported element type. Operands are matched through captured FileCheck
+// variables rather than hard-coded SSA numbers, so the test does not depend
+// on the printer's value numbering.
+llvm.func @gen.sub_group_shuffle() {
+  // CHECK-LABEL: gen.sub_group_shuffle
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: gen.sub_group_shuffle xor %[[ZERO]], %[[ZERO]] : i32
+  %1 = gen.sub_group_shuffle xor %0, %0 : i32
+  // CHECK: gen.sub_group_shuffle up %[[ZERO]], %[[ZERO]] : i32
+  %2 = gen.sub_group_shuffle up %0, %0 : i32
+  // CHECK: gen.sub_group_shuffle down %[[ZERO]], %[[ZERO]] : i32
+  %3 = gen.sub_group_shuffle down %0, %0 : i32
+  // CHECK: gen.sub_group_shuffle idx %[[ZERO]], %[[ZERO]] : i32
+  %4 = gen.sub_group_shuffle idx %0, %0 : i32
+  // CHECK: %[[ZERO_I8:.*]] = llvm.mlir.constant(0 : i8) : i8
+  %5 = llvm.mlir.constant(0 : i8) : i8
+  // CHECK: gen.sub_group_shuffle xor %[[ZERO_I8]], %[[ZERO]] : i8
+  %6 = gen.sub_group_shuffle xor %5, %0 : i8
+  // CHECK: %[[ZERO_I16:.*]] = llvm.mlir.constant(0 : i16) : i16
+  %7 = llvm.mlir.constant(0 : i16) : i16
+  // CHECK: gen.sub_group_shuffle xor %[[ZERO_I16]], %[[ZERO]] : i16
+  %8 = gen.sub_group_shuffle xor %7, %0 : i16
+  // CHECK: %[[ZERO_I64:.*]] = llvm.mlir.constant(0 : i64) : i64
+  %9 = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: gen.sub_group_shuffle xor %[[ZERO_I64]], %[[ZERO]] : i64
+  %10 = gen.sub_group_shuffle xor %9, %0 : i64
+  // CHECK: %[[ZERO_F16:.*]] = llvm.mlir.constant(0.000000e+00 : f16) : f16
+  %11 = llvm.mlir.constant(0.0 : f16) : f16
+  // CHECK: gen.sub_group_shuffle xor %[[ZERO_F16]], %[[ZERO]] : f16
+  %12 = gen.sub_group_shuffle xor %11, %0 : f16
+  // CHECK: %[[ZERO_F32:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+  %13 = llvm.mlir.constant(0.0 : f32) : f32
+  // CHECK: gen.sub_group_shuffle xor %[[ZERO_F32]], %[[ZERO]] : f32
+  %14 = gen.sub_group_shuffle xor %13, %0 : f32
+  // CHECK: %[[ZERO_F64:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : f64
+  %15 = llvm.mlir.constant(0.0 : f64) : f64
+  // CHECK: gen.sub_group_shuffle xor %[[ZERO_F64]], %[[ZERO]] : f64
+  %16 = gen.sub_group_shuffle xor %15, %0 : f64
+  llvm.return
+}
diff --git a/mlir/test/Dialect/GEN/invalid.mlir b/mlir/test/Dialect/GEN/invalid.mlir
new file mode 100644
index 00000000000000..1285f8af216647
--- /dev/null
+++ b/mlir/test/Dialect/GEN/invalid.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -split-input-file %s -verify-diagnostics
+
+// A constant dimension below the valid range [0, 3) is rejected by the verifier.
+func.func @test_3d_nd_range_bounds_low() {
+ %c-1 = arith.constant -1 : i32
+ // expected-error @below {{'gen.local_id' op input dimension must be in the range [0, 3). Got -1}}
+ %0 = gen.local_id %c-1
+ func.return
+}
+
+// -----
+
+// A constant dimension equal to the exclusive upper bound 3 is rejected.
+func.func @test_3d_nd_range_bounds_high() {
+ %c3 = arith.constant 3 : i32
+ // expected-error @below {{'gen.work_group_id' op input dimension must be in the range [0, 3). Got 3}}
+ %0 = gen.work_group_id %c3
+ func.return
+}
More information about the Mlir-commits
mailing list