[Mlir-commits] [mlir] 429d792 - [mlir] Add support for generating dialect declarations via tablegen.
River Riddle
llvmlistbot at llvm.org
Sat Mar 14 20:59:55 PDT 2020
Author: River Riddle
Date: 2020-03-14T20:36:44-07:00
New Revision: 429d792f23f2e72628cae763667bca60d69853e7
URL: https://github.com/llvm/llvm-project/commit/429d792f23f2e72628cae763667bca60d69853e7
DIFF: https://github.com/llvm/llvm-project/commit/429d792f23f2e72628cae763667bca60d69853e7.diff
LOG: [mlir] Add support for generating dialect declarations via tablegen.
Summary: This adds a `-gen-dialect-decls` TableGen backend that generates the C++ class declarations for dialects from the existing 'Dialect' TableGen classes.
Differential Revision: https://reviews.llvm.org/D76185
Added:
mlir/include/mlir/Dialect/QuantOps/QuantOpsBase.td
mlir/tools/mlir-tblgen/DialectGen.cpp
Modified:
mlir/cmake/modules/AddMLIR.cmake
mlir/docs/CreatingADialect.md
mlir/include/mlir/Dialect/AffineOps/AffineOps.h
mlir/include/mlir/Dialect/AffineOps/AffineOps.td
mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt
mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h
mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
mlir/include/mlir/Dialect/GPU/CMakeLists.txt
mlir/include/mlir/Dialect/GPU/GPUDialect.h
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
mlir/include/mlir/Dialect/LoopOps/LoopOps.h
mlir/include/mlir/Dialect/LoopOps/LoopOps.td
mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
mlir/include/mlir/Dialect/QuantOps/QuantOps.h
mlir/include/mlir/Dialect/QuantOps/QuantOps.td
mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt
mlir/include/mlir/Dialect/VectorOps/VectorOps.h
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Attribute.h
mlir/include/mlir/TableGen/Dialect.h
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/TableGen/Attribute.cpp
mlir/lib/TableGen/Dialect.cpp
mlir/tools/mlir-tblgen/CMakeLists.txt
Removed:
mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
################################################################################
diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index 6354d9030367..4e06a0e743fe 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -28,10 +28,11 @@ function(whole_archive_link target)
endfunction(whole_archive_link)
# Declare a dialect in the include directory
-function(add_mlir_dialect dialect dialect_doc_filename)
+function(add_mlir_dialect dialect dialect_namespace dialect_doc_filename)
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
+ mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})
add_public_tablegen_target(MLIR${dialect}IncGen)
add_dependencies(mlir-headers MLIR${dialect}IncGen)
diff --git a/mlir/docs/CreatingADialect.md b/mlir/docs/CreatingADialect.md
index 3a987fd0c5d3..8c8d3a9bc0eb 100644
--- a/mlir/docs/CreatingADialect.md
+++ b/mlir/docs/CreatingADialect.md
@@ -39,7 +39,7 @@ is declared using add_mlir_dialect().
```cmake
-add_mlir_dialect(FooOps FooOps)
+add_mlir_dialect(FooOps foo FooOps)
```
diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
index 53a06fbf89f7..edae534f12ba 100644
--- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
+++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
@@ -36,17 +36,6 @@ class OpBuilder;
/// symbol.
bool isTopLevelValue(Value value);
-class AffineOpsDialect : public Dialect {
-public:
- AffineOpsDialect(MLIRContext *context);
- static StringRef getDialectNamespace() { return "affine"; }
-
- /// Materialize a single constant operation from a given attribute value with
- /// the desired resultant type.
- Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
- Location loc) override;
-};
-
/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
/// from a source memref to a destination memref. The source and destination
/// memref need not be of the same dimensionality, but need to have the same
@@ -504,6 +493,8 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
void fullyComposeAffineMapAndOperands(AffineMap *map,
SmallVectorImpl<Value> *operands);
+#include "mlir/Dialect/AffineOps/AffineOpsDialect.h.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AffineOps/AffineOps.h.inc"
diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
index 4b94cf2530a4..307860f1622b 100644
--- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
+++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
@@ -17,14 +17,15 @@ include "mlir/Dialect/AffineOps/AffineOpsBase.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffects.td"
-def Affine_Dialect : Dialect {
+def AffineOps_Dialect : Dialect {
let name = "affine";
let cppNamespace = "";
+ let hasConstantMaterializer = 1;
}
// Base class for Affine dialect ops.
class Affine_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<Affine_Dialect, mnemonic, traits> {
+ Op<AffineOps_Dialect, mnemonic, traits> {
// For every affine op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * LogicalResult verify(${C++ class of Op} op)
@@ -290,7 +291,7 @@ def AffineIfOp : Affine_Op<"if",
}
class AffineMinMaxOpBase<string mnemonic, list<OpTrait> traits = []> :
- Op<Affine_Dialect, mnemonic, traits> {
+ Op<AffineOps_Dialect, mnemonic, traits> {
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
let results = (outs Index);
diff --git a/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt b/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt
index 7339bcc9dcfd..155e066a4725 100644
--- a/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AffineOps/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(AffineOps AffineOps)
+add_mlir_dialect(AffineOps affine AffineOps)
diff --git a/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt b/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
index 484230778b3d..cff7f3ea1548 100644
--- a/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/FxpMathOps/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(FxpMathOps FxpMathOps)
+add_mlir_dialect(FxpMathOps fxpmath FxpMathOps)
diff --git a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h
index 17d320ab04b4..bb5ab0ecd001 100644
--- a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h
+++ b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.h
@@ -17,11 +17,7 @@
namespace mlir {
namespace fxpmath {
-/// Defines the 'FxpMathOps' dialect.
-class FxpMathOpsDialect : public Dialect {
-public:
- FxpMathOpsDialect(MLIRContext *context);
-};
+#include "mlir/Dialect/FxpMathOps/FxpMathOpsDialect.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h.inc"
diff --git a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
index 674318431ff6..7e6625d04ef4 100644
--- a/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
+++ b/mlir/include/mlir/Dialect/FxpMathOps/FxpMathOps.td
@@ -15,10 +15,10 @@
#define DIALECT_FXPMATHOPS_FXPMATH_OPS_
include "mlir/IR/OpBase.td"
-include "mlir/Dialect/QuantOps/QuantPredicates.td"
+include "mlir/Dialect/QuantOps/QuantOpsBase.td"
include "mlir/Interfaces/SideEffects.td"
-def fxpmath_Dialect : Dialect {
+def FxpMathOps_Dialect : Dialect {
let name = "fxpmath";
}
@@ -78,7 +78,7 @@ def fxpmath_CompareFnAttr : StrEnumAttr<"ComparisonFn",
//===----------------------------------------------------------------------===//
class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
- Op<fxpmath_Dialect, mnemonic, traits>;
+ Op<FxpMathOps_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Fixed-point (fxp) arithmetic ops used by kernels.
diff --git a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
index fd85b5bcfbfa..bb4c4f5d34c4 100644
--- a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(GPUOps GPUOps)
+add_mlir_dialect(GPUOps gpu GPUOps)
diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index c4a99d87f5fe..9570d8feaaf2 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -26,51 +26,6 @@ class FuncOp;
namespace gpu {
-/// The dialect containing GPU kernel launching operations and related
-/// facilities.
-class GPUDialect : public Dialect {
-public:
- /// Create the dialect in the given `context`.
- explicit GPUDialect(MLIRContext *context);
- /// Get dialect namespace.
- static StringRef getDialectNamespace() { return "gpu"; }
-
- /// Get the name of the attribute used to annotate the modules that contain
- /// kernel modules.
- static StringRef getContainerModuleAttrName() {
- return "gpu.container_module";
- }
-
- /// Get the canonical string name of the dialect.
- static StringRef getDialectName();
-
- /// Get the name of the attribute used to annotate external kernel functions.
- static StringRef getKernelFuncAttrName() { return "gpu.kernel"; }
-
- /// Get the name of the attribute used to annotate kernel modules.
- static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; }
-
- /// Returns whether the given function is a kernel function, i.e., has the
- /// 'gpu.kernel' attribute.
- static bool isKernel(Operation *op);
-
- /// Returns the number of workgroup (thread, block) dimensions supported in
- /// the GPU dialect.
- // TODO(zinenko,herhut): consider generalizing this.
- static unsigned getNumWorkgroupDimensions() { return 3; }
-
- /// Returns the numeric value used to identify the workgroup memory address
- /// space.
- static unsigned getWorkgroupAddressSpace() { return 3; }
-
- /// Returns the numeric value used to identify the private memory address
- /// space.
- static unsigned getPrivateAddressSpace() { return 5; }
-
- LogicalResult verifyOperationAttribute(Operation *op,
- NamedAttribute attr) override;
-};
-
/// Utility class for the GPU dialect to represent triples of `Value`s
/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
struct KernelDim3 {
@@ -79,6 +34,8 @@ struct KernelDim3 {
Value z;
};
+#include "mlir/Dialect/GPU/GPUOpsDialect.h.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.h.inc"
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index eda10d8c13b4..1ff0a8b57f5d 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -28,6 +28,39 @@ def IntLikeOrLLVMInt : TypeConstraint<
def GPU_Dialect : Dialect {
let name = "gpu";
+ let extraClassDeclaration = [{
+ /// Get the name of the attribute used to annotate the modules that contain
+ /// kernel modules.
+ static StringRef getContainerModuleAttrName() {
+ return "gpu.container_module";
+ }
+ /// Get the name of the attribute used to annotate external kernel
+ /// functions.
+ static StringRef getKernelFuncAttrName() { return "gpu.kernel"; }
+
+ /// Get the name of the attribute used to annotate kernel modules.
+ static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; }
+
+ /// Returns whether the given function is a kernel function, i.e., has the
+ /// 'gpu.kernel' attribute.
+ static bool isKernel(Operation *op);
+
+ /// Returns the number of workgroup (thread, block) dimensions supported in
+ /// the GPU dialect.
+ // TODO(zinenko,herhut): consider generalizing this.
+ static unsigned getNumWorkgroupDimensions() { return 3; }
+
+ /// Returns the numeric value used to identify the workgroup memory address
+ /// space.
+ static unsigned getWorkgroupAddressSpace() { return 3; }
+
+ /// Returns the numeric value used to identify the private memory address
+ /// space.
+ static unsigned getPrivateAddressSpace() { return 5; }
+
+ LogicalResult verifyOperationAttribute(Operation *op,
+ NamedAttribute attr) override;
+ }];
}
class GPU_Op<string mnemonic, list<OpTrait> traits = []> :
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 8db3a86142d9..796b4a68a2b1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -1,12 +1,13 @@
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
+mlir_tablegen(LLVMOpsDialect.h.inc -gen-dialect-decls)
mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRLLVMOpsIncGen)
-add_mlir_dialect(NVVMOps NVVMOps)
-add_mlir_dialect(ROCDLOps ROCDLOps)
+add_mlir_dialect(NVVMOps nvvm NVVMOps)
+add_mlir_dialect(ROCDLOps rocdl ROCDLOps)
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 405d93aa02fe..1ff2d6ae28bb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -201,32 +201,7 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOps.h.inc"
-class LLVMDialect : public Dialect {
-public:
- explicit LLVMDialect(MLIRContext *context);
- ~LLVMDialect();
- static StringRef getDialectNamespace() { return "llvm"; }
-
- llvm::LLVMContext &getLLVMContext();
- llvm::Module &getLLVMModule();
-
- /// Parse a type registered to this dialect.
- Type parseType(DialectAsmParser &parser) const override;
-
- /// Print a type registered to this dialect.
- void printType(Type type, DialectAsmPrinter &os) const override;
-
- /// Verify a region argument attribute registered to this dialect.
- /// Returns failure if the verification failed, success otherwise.
- LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx,
- unsigned argIdx,
- NamedAttribute argAttr) override;
-
-private:
- friend LLVMType;
-
- std::unique_ptr<detail::LLVMDialectImpl> impl;
-};
+#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.h.inc"
/// Create an LLVM global containing the string "value" at the module containing
/// surrounding the insertion point of builder. Obtain the address of that
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index f32161325853..b2d1e57c0f11 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -19,11 +19,28 @@ include "mlir/IR/OpBase.td"
def LLVM_Dialect : Dialect {
let name = "llvm";
let cppNamespace = "LLVM";
+ let extraClassDeclaration = [{
+ ~LLVMDialect();
+ llvm::LLVMContext &getLLVMContext();
+ llvm::Module &getLLVMModule();
+
+ /// Verify a region argument attribute registered to this dialect.
+ /// Returns failure if the verification failed, success otherwise.
+ LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIdx,
+ unsigned argIdx,
+ NamedAttribute argAttr) override;
+
+ private:
+ friend LLVMType;
+
+ std::unique_ptr<detail::LLVMDialectImpl> impl;
+ }];
}
// LLVM IR type wrapped in MLIR.
-def LLVM_Type : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
- "LLVM dialect type">;
+def LLVM_Type : DialectType<LLVM_Dialect,
+ CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
+ "LLVM dialect type">;
// Type constraint accepting only wrapped LLVM integer types.
def LLVMInt : TypeConstraint<
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 6a0b48a2c93a..7b5e1060b301 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -25,12 +25,7 @@ namespace NVVM {
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.h.inc"
-class NVVMDialect : public Dialect {
-public:
- explicit NVVMDialect(MLIRContext *context);
-
- static StringRef getDialectNamespace() { return "nvvm"; }
-};
+#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.h.inc"
} // namespace NVVM
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
index 1e3740c22a2e..6177101706b4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
@@ -33,12 +33,7 @@ namespace ROCDL {
#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/ROCDLOps.h.inc"
-class ROCDLDialect : public Dialect {
-public:
- explicit ROCDLDialect(MLIRContext *context);
-
- static StringRef getDialectNamespace() { return "rocdl"; }
-};
+#include "mlir/Dialect/LLVMIR/ROCDLOpsDialect.h.inc"
} // namespace ROCDL
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
index 6ae8b0a3e483..212e3022262d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
@@ -1,4 +1,4 @@
-add_mlir_dialect(LinalgOps LinalgDoc)
+add_mlir_dialect(LinalgOps linalg LinalgDoc)
set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td)
mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 154b875e0fbc..3b52b0e63918 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -34,6 +34,6 @@ def Linalg_Dialect : Dialect {
// Whether a type is a RangeType.
def LinalgIsRangeTypePred : CPred<"$_self.isa<RangeType>()">;
-def Range : Type<LinalgIsRangeTypePred, "range">;
+def Range : DialectType<Linalg_Dialect, LinalgIsRangeTypePred, "range">;
#endif // LINALG_BASE
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index 3a7c1ac8a831..6f93d2184709 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -21,17 +21,7 @@ enum LinalgTypes {
LAST_USED_LINALG_TYPE = Range,
};
-class LinalgDialect : public Dialect {
-public:
- explicit LinalgDialect(MLIRContext *context);
- static StringRef getDialectNamespace() { return "linalg"; }
-
- /// Parse a type registered to this dialect.
- Type parseType(DialectAsmParser &parser) const override;
-
- /// Print a type registered to this dialect.
- void printType(Type type, DialectAsmPrinter &os) const override;
-};
+#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc"
/// A RangeType represents a minimal range abstraction (min, max, step).
/// It is constructed by calling the linalg.range op with three values index of
diff --git a/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt b/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
index 0fda882d3f54..511c32d32a47 100644
--- a/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LoopOps/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(LoopOps LoopOps)
+add_mlir_dialect(LoopOps loop LoopOps)
diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h
index f1fe8d5c12b0..09b982d93373 100644
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h
@@ -25,11 +25,7 @@ namespace loop {
class TerminatorOp;
-class LoopOpsDialect : public Dialect {
-public:
- LoopOpsDialect(MLIRContext *context);
- static StringRef getDialectNamespace() { return "loop"; }
-};
+#include "mlir/Dialect/LoopOps/LoopOpsDialect.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/LoopOps/LoopOps.h.inc"
diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
index 462ec5ddb7f1..84765222ce4c 100644
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -16,14 +16,14 @@
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffects.td"
-def Loop_Dialect : Dialect {
+def LoopOps_Dialect : Dialect {
let name = "loop";
let cppNamespace = "";
}
// Base class for Loop dialect ops.
class Loop_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<Loop_Dialect, mnemonic, traits> {
+ Op<LoopOps_Dialect, mnemonic, traits> {
// For every standard op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * LogicalResult verify(${C++ class of Op} op)
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index d04e695708fe..0362a631dc2e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(OpenMPOps OpenMPOps)
+add_mlir_dialect(OpenMPOps omp OpenMPOps)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
index 764903be6161..6761b51b55b5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -22,13 +22,7 @@ namespace omp {
#define GET_OP_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc"
-class OpenMPDialect : public Dialect {
-public:
- explicit OpenMPDialect(MLIRContext *context);
-
- static StringRef getDialectNamespace() { return "omp"; }
-};
-
+#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.h.inc"
} // namespace omp
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt b/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
index 90a61c4c194f..87a9fd6a3066 100644
--- a/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/QuantOps/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(QuantOps QuantOps)
+add_mlir_dialect(QuantOps quant QuantOps)
diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.h b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h
index 56d1eb656a9a..802be083e30a 100644
--- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.h
+++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h
@@ -21,17 +21,7 @@
namespace mlir {
namespace quant {
-/// Defines the 'Quantization' dialect
-class QuantizationDialect : public Dialect {
-public:
- QuantizationDialect(MLIRContext *context);
-
- /// Parse a type registered to this dialect.
- Type parseType(DialectAsmParser &parser) const override;
-
- /// Print a type registered to this dialect.
- void printType(Type type, DialectAsmPrinter &os) const override;
-};
+#include "mlir/Dialect/QuantOps/QuantOpsDialect.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/QuantOps/QuantOps.h.inc"
diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
index 227ce33a26b4..92e1e1d813ed 100644
--- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
+++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
@@ -13,20 +13,15 @@
#ifndef DIALECT_QUANTOPS_QUANT_OPS_
#define DIALECT_QUANTOPS_QUANT_OPS_
-include "mlir/IR/OpBase.td"
-include "mlir/Dialect/QuantOps/QuantPredicates.td"
+include "mlir/Dialect/QuantOps/QuantOpsBase.td"
include "mlir/Interfaces/SideEffects.td"
-def quant_Dialect : Dialect {
- let name = "quant";
-}
-
//===----------------------------------------------------------------------===//
// Base classes
//===----------------------------------------------------------------------===//
class quant_Op<string mnemonic, list<OpTrait> traits> :
- Op<quant_Dialect, mnemonic, traits>;
+ Op<Quantization_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Quantization casts
diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td b/mlir/include/mlir/Dialect/QuantOps/QuantOpsBase.td
similarity index 84%
rename from mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
rename to mlir/include/mlir/Dialect/QuantOps/QuantOpsBase.td
index cd2e85fd985d..84efcc1c8fc0 100644
--- a/mlir/include/mlir/Dialect/QuantOps/QuantPredicates.td
+++ b/mlir/include/mlir/Dialect/QuantOps/QuantOpsBase.td
@@ -1,4 +1,4 @@
-//===- QuantPredicates.td - Predicates for dialect types ---*- tablegen -*-===//
+//===- QuantOpsBase.td - Quantization dialect base ---------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,8 +10,14 @@
//
//===----------------------------------------------------------------------===//
-#ifndef DIALECT_QUANTOPS_QUANT_PREDICATES_
-#define DIALECT_QUANTOPS_QUANT_PREDICATES_
+#ifndef DIALECT_QUANTOPS_QUANT_OPS_BASE_
+#define DIALECT_QUANTOPS_QUANT_OPS_BASE_
+
+include "mlir/IR/OpBase.td"
+
+def Quantization_Dialect : Dialect {
+ let name = "quant";
+}
//===----------------------------------------------------------------------===//
// Quantization type definitions
@@ -54,10 +60,12 @@ def quant_RealOrStorageValueType :
// An implementation of UniformQuantizedType.
def quant_UniformQuantizedType :
- Type<CPred<"$_self.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
+ DialectType<Quantization_Dialect,
+ CPred<"$_self.isa<UniformQuantizedType>()">,
+ "UniformQuantizedType">;
// Predicate for detecting a container or primitive of UniformQuantizedType.
def quant_UniformQuantizedValueType :
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
-#endif // DIALECT_QUANTOPS_QUANT_PREDICATES_
+#endif // DIALECT_QUANTOPS_QUANT_OPS_BASE_
diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
index fb1cf9f72089..8d297cc92117 100644
--- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
@@ -1,4 +1,4 @@
-add_mlir_dialect(SPIRVOps SPIRVOps)
+add_mlir_dialect(SPIRVOps spv SPIRVOps)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index aad8b9f2ec7e..4dc1886fae2d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -22,7 +22,7 @@ include "mlir/Dialect/SPIRV/SPIRVAvailability.td"
// SPIR-V dialect definitions
//===----------------------------------------------------------------------===//
-def SPV_Dialect : Dialect {
+def SPIRV_Dialect : Dialect {
let name = "spv";
let summary = "The SPIR-V dialect in MLIR.";
@@ -46,6 +46,43 @@ def SPV_Dialect : Dialect {
}];
let cppNamespace = "spirv";
+ let hasConstantMaterializer = 1;
+ let extraClassDeclaration = [{
+ //===------------------------------------------------------------------===//
+ // Type
+ //===------------------------------------------------------------------===//
+
+ /// Checks if the given `type` is valid in SPIR-V dialect.
+ static bool isValidType(Type type);
+
+ /// Checks if the given `scalar type` is valid in SPIR-V dialect.
+ static bool isValidScalarType(Type type);
+
+ //===------------------------------------------------------------------===//
+ // Attribute
+ //===------------------------------------------------------------------===//
+
+ /// Returns the attribute name to use when specifying decorations on results
+ /// of operations.
+ static std::string getAttributeName(Decoration decoration);
+
+ /// Provides a hook for verifying SPIR-V dialect attributes attached to the
+ /// given op.
+ LogicalResult verifyOperationAttribute(Operation *op,
+ NamedAttribute attribute) override;
+
+ /// Provides a hook for verifying SPIR-V dialect attributes attached to the
+ /// given op's region argument.
+ LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIndex,
+ unsigned argIndex,
+ NamedAttribute attribute) override;
+
+ /// Provides a hook for verifying SPIR-V dialect attributes attached to the
+ /// given op's region result.
+ LogicalResult verifyRegionResultAttribute(
+ Operation *op, unsigned regionIndex, unsigned resultIndex,
+ NamedAttribute attribute) override;
+ }];
}
//===----------------------------------------------------------------------===//
@@ -2953,7 +2990,8 @@ def SPV_SamplerUseAttr:
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
-def SPV_VerCapExtAttr : Attr<
+def SPV_VerCapExtAttr : DialectAttr<
+ SPIRV_Dialect,
CPred<"$_self.isa<::mlir::spirv::VerCapExtAttr>()">,
"version-capability-extension attribute"> {
let storageType = "::mlir::spirv::VerCapExtAttr";
@@ -2993,10 +3031,14 @@ def SPV_Vector : VectorOfLengthAndType<[2, 3, 4],
[SPV_Bool, SPV_Integer, SPV_Float]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
-def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
-def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">;
-def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">;
-def SPV_AnyStruct : Type<SPV_IsStructType, "any SPIR-V struct type">;
+def SPV_AnyPtr : DialectType<SPIRV_Dialect, SPV_IsPtrType,
+ "any SPIR-V pointer type">;
+def SPV_AnyArray : DialectType<SPIRV_Dialect, SPV_IsArrayType,
+ "any SPIR-V array type">;
+def SPV_AnyRTArray : DialectType<SPIRV_Dialect, SPV_IsRTArrayType,
+ "any SPIR-V runtime array type">;
+def SPV_AnyStruct : DialectType<SPIRV_Dialect, SPV_IsStructType,
+ "any SPIR-V struct type">;
def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
@@ -3264,7 +3306,7 @@ def SPV_OpcodeAttr :
// Base class for all SPIR-V ops.
class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<SPV_Dialect, mnemonic, !listconcat(traits, [
+ Op<SPIRV_Dialect, mnemonic, !listconcat(traits, [
// TODO(antiagainst): We don't need all of the following traits for
// every op; only the suitabble ones should be added automatically
// after ODS supports dialect-specific contents.
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
index a6d93d8d2862..2cffebec60ea 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
@@ -20,67 +20,7 @@ namespace spirv {
enum class Decoration : uint32_t;
-class SPIRVDialect : public Dialect {
-public:
- explicit SPIRVDialect(MLIRContext *context);
-
- static StringRef getDialectNamespace() { return "spv"; }
-
- //===--------------------------------------------------------------------===//
- // Type
- //===--------------------------------------------------------------------===//
-
- /// Checks if the given `type` is valid in SPIR-V dialect.
- static bool isValidType(Type type);
-
- /// Checks if the given `scalar type` is valid in SPIR-V dialect.
- static bool isValidScalarType(Type type);
-
- /// Parses a type registered to this dialect.
- Type parseType(DialectAsmParser &parser) const override;
-
- /// Prints a type registered to this dialect.
- void printType(Type type, DialectAsmPrinter &os) const override;
-
- //===--------------------------------------------------------------------===//
- // Attribute
- //===--------------------------------------------------------------------===//
-
- /// Returns the attribute name to use when specifying decorations on results
- /// of operations.
- static std::string getAttributeName(Decoration decoration);
-
- /// Parses an attribute registered to this dialect.
- Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
-
- /// Prints an attribute registered to this dialect.
- void printAttribute(Attribute, DialectAsmPrinter &printer) const override;
-
- /// Provides a hook for verifying SPIR-V dialect attributes attached to the
- /// given op.
- LogicalResult verifyOperationAttribute(Operation *op,
- NamedAttribute attribute) override;
-
- /// Provides a hook for verifying SPIR-V dialect attributes attached to the
- /// given op's region argument.
- LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIndex,
- unsigned argIndex,
- NamedAttribute attribute) override;
-
- /// Provides a hook for verifying SPIR-V dialect attributes attached to the
- /// given op's region result.
- LogicalResult verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
- unsigned resultIndex,
- NamedAttribute attribute) override;
-
- //===--------------------------------------------------------------------===//
- // Constant
- //===--------------------------------------------------------------------===//
-
- /// Provides a hook for materializing a constant to this dialect.
- Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
- Location loc) override;
-};
+#include "mlir/Dialect/SPIRV/SPIRVOpsDialect.h.inc"
} // end namespace spirv
} // end namespace mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
index 3c10caa601e3..a463f0e8da95 100644
--- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
+++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td
@@ -29,7 +29,7 @@ include "mlir/Dialect/SPIRV/SPIRVBase.td"
// 1) Descriptor Set.
// 2) Binding number.
// 3) Storage class.
-def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPV_Dialect, [
+def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPIRV_Dialect, [
StructFieldAttr<"descriptor_set", I32Attr>,
StructFieldAttr<"binding", I32Attr>,
StructFieldAttr<"storage_class", SPV_StorageClassAttr>
@@ -38,7 +38,7 @@ def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPV_Dialect, [
// For entry functions, this attribute specifies information related to entry
// points in the generated SPIR-V module:
// 1) WorkGroup Size.
-def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPV_Dialect, [
+def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPIRV_Dialect, [
StructFieldAttr<"local_size", I32ElementsAttr>
]>;
@@ -54,7 +54,7 @@ def SPV_CapabilityArrayAttr : TypedArrayAttrBase<
// See https://renderdoc.org/vkspec_chunked/chap36.html#limits for the complete
// list of limits and their explanation for the Vulkan API. The following ones
// are those affecting SPIR-V CodeGen.
-def SPV_ResourceLimitsAttr : StructAttr<"ResourceLimitsAttr", SPV_Dialect, [
+def SPV_ResourceLimitsAttr : StructAttr<"ResourceLimitsAttr", SPIRV_Dialect, [
StructFieldAttr<"max_compute_workgroup_invocations", I32Attr>,
StructFieldAttr<"max_compute_workgroup_size", I32ElementsAttr>
]>;
diff --git a/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
index 3d1adca9c2be..702ec621f486 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
@@ -1 +1 @@
-add_mlir_dialect(ShapeOps ShapeOps)
+add_mlir_dialect(ShapeOps shape ShapeOps)
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index 47234ae9b826..fe302e62e9c8 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -21,13 +21,6 @@
namespace mlir {
namespace shape {
-/// This dialect contains shape inference related operations and facilities.
-class ShapeDialect : public Dialect {
-public:
- /// Create the dialect in the given `context`.
- explicit ShapeDialect(MLIRContext *context);
-};
-
namespace ShapeTypes {
enum Kind {
Component = Type::FIRST_SHAPE_TYPE,
@@ -112,6 +105,8 @@ class ValueShapeType : public Type::TypeBase<ValueShapeType, Type> {
#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.h.inc"
+#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.h.inc"
+
} // namespace shape
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
index b6534797a065..9abc8430c16b 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
set(LLVM_TARGET_DEFINITIONS Ops.td)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
+mlir_tablegen(OpsDialect.h.inc -gen-dialect-decls)
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRStandardOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 76406c607233..06ceaf4b3f29 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -31,20 +31,11 @@ class Builder;
class FuncOp;
class OpBuilder;
-class StandardOpsDialect : public Dialect {
-public:
- StandardOpsDialect(MLIRContext *context);
- static StringRef getDialectNamespace() { return "std"; }
-
- /// Materialize a single constant operation from a given attribute value with
- /// the desired resultant type.
- Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
- Location loc) override;
-};
-
#define GET_OP_CLASSES
#include "mlir/Dialect/StandardOps/IR/Ops.h.inc"
+#include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc"
+
/// This is a refinement of the "constant" op for the case where it is
/// returning a float value of FloatType.
///
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index daf5da739c50..9ca5465f3787 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -18,14 +18,15 @@ include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffects.td"
-def Std_Dialect : Dialect {
+def StandardOps_Dialect : Dialect {
let name = "std";
let cppNamespace = "";
+ let hasConstantMaterializer = 1;
}
// Base class for Standard dialect ops.
class Std_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<Std_Dialect, mnemonic, traits> {
+ Op<StandardOps_Dialect, mnemonic, traits> {
// For every standard op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * LogicalResult verify(${C++ class of Op} op)
@@ -63,7 +64,7 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
// Base class for unary ops. Requires single operand and result. Individual
// classes will have `operand` accessor.
class UnaryOp<string mnemonic, list<OpTrait> traits = []> :
- Op<Std_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
+ Op<StandardOps_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
let results = (outs AnyType);
let printer = [{
return printStandardUnaryOp(this->getOperation(), p);
@@ -86,7 +87,7 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
// results to be of the same type, but does not constrain them to specific
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
- Op<Std_Dialect, mnemonic,
+ Op<StandardOps_Dialect, mnemonic,
!listconcat(traits, [NoSideEffect, SameOperandsAndResultType])> {
let results = (outs AnyType);
diff --git a/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt b/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt
index 5ce3168c5580..4977e117e7e0 100644
--- a/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt
@@ -1,4 +1,4 @@
-add_mlir_dialect(VectorOps VectorOps)
+add_mlir_dialect(VectorOps vector VectorOps)
set(LLVM_TARGET_DEFINITIONS VectorTransformPatterns.td)
mlir_tablegen(VectorTransformPatterns.h.inc -gen-rewriters)
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
index e32752fe6030..e13d480b7a9f 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h
@@ -24,18 +24,6 @@ class MLIRContext;
class OwningRewritePatternList;
namespace vector {
-/// Dialect for Ops on higher-dimensional vector types.
-class VectorOpsDialect : public Dialect {
-public:
- VectorOpsDialect(MLIRContext *context);
- static StringRef getDialectNamespace() { return "vector"; }
-
- /// Materialize a single constant operation from a given attribute value with
- /// the desired resultant type.
- Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
- Location loc) override;
-};
-
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
@@ -75,6 +63,8 @@ ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
#define GET_OP_CLASSES
#include "mlir/Dialect/VectorOps/VectorOps.h.inc"
+#include "mlir/Dialect/VectorOps/VectorOpsDialect.h.inc"
+
} // end namespace vector
} // end namespace mlir
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index ee7c1431429c..88b6e1e993e6 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -16,14 +16,15 @@
include "mlir/Dialect/AffineOps/AffineOpsBase.td"
include "mlir/Interfaces/SideEffects.td"
-def Vector_Dialect : Dialect {
+def VectorOps_Dialect : Dialect {
let name = "vector";
let cppNamespace = "vector";
+ let hasConstantMaterializer = 1;
}
// Base class for Vector dialect ops.
class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<Vector_Dialect, mnemonic, traits> {
+ Op<VectorOps_Dialect, mnemonic, traits> {
// For every vector op, there needs to be a:
// * void print(OpAsmPrinter &p, ${C++ class of Op} op)
// * LogicalResult verify(${C++ class of Op} op)
@@ -432,7 +433,7 @@ def Vector_ExtractSlicesOp :
}
def Vector_FMAOp :
- Op<Vector_Dialect, "fma", [NoSideEffect,
+ Op<VectorOps_Dialect, "fma", [NoSideEffect,
AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
Results<(outs AnyVector:$result)> {
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index fa890e22e823..159a3c5eae54 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -253,6 +253,13 @@ class Dialect {
// the generated files are included into the dialect, you may want to specify
// a full namespace path or a partial one.
string cppNamespace = name;
+
+ // An optional code block containing extra declarations to place in the
+ // dialect declaration.
+ code extraClassDeclaration = "";
+
+ // If this dialect overrides the hook for materializing constants.
+ bit hasConstantMaterializer = 0;
}
//===----------------------------------------------------------------------===//
@@ -753,6 +760,12 @@ class Attr<Pred condition, string descr = ""> :
Attr baseAttr = ?;
}
+// An attribute of a specific dialect.
+//
+// `d` is the dialect the attribute belongs to and is recorded in the `dialect`
+// field; `condition` and `descr` are forwarded unchanged to `Attr`.
+// mlir-tblgen's `-gen-dialect-decls` collects every `DialectAttr` record of a
+// dialect and, if there is at least one, declares the dialect's attribute
+// parse/print hooks.
+class DialectAttr<Dialect d, Pred condition, string descr = ""> :
+ Attr<condition, descr> {
+ // The dialect this attribute is registered to.
+ Dialect dialect = d;
+}
+
//===----------------------------------------------------------------------===//
// Attribute modifier definition
diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index dbc018a09323..f99939392e93 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -25,6 +25,7 @@ class Record;
namespace mlir {
namespace tblgen {
+class Dialect;
class Type;
// Wrapper class with helper methods for accessing attribute constraints defined
@@ -105,6 +106,9 @@ class Attribute : public AttrConstraint {
// Returns the code body for derived attribute. Aborts if this is not a
// derived attribute.
StringRef getDerivedCodeBody() const;
+
+ // Returns the dialect for the attribute if defined.
+ Dialect getDialect() const;
};
// Wrapper class providing helper methods for accessing MLIR constant attribute
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index eb03a6a96626..7cf5760b6817 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -32,6 +32,9 @@ class Dialect {
// Returns the C++ namespaces that ops of this dialect should be placed into.
StringRef getCppNamespace() const;
+ // Returns this dialect's C++ class name.
+ std::string getCppClassName() const;
+
// Returns the summary description of the dialect. Returns empty string if
// none.
StringRef getSummary() const;
@@ -39,6 +42,12 @@ class Dialect {
// Returns the description of the dialect. Returns empty string if none.
StringRef getDescription() const;
+ // Returns the dialect's extra class declaration code, or None if it is empty.
+ llvm::Optional<StringRef> getExtraClassDeclaration() const;
+
+ // Returns true if this dialect has a constant materializer.
+ bool hasConstantMaterializer() const;
+
// Returns whether two dialects are equal by checking the equality of the
// underlying record.
bool operator==(const Dialect &other) const;
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 547c06d5ab20..e484e25b348d 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -28,15 +28,13 @@ using namespace mlir::gpu;
// GPUDialect
//===----------------------------------------------------------------------===//
-StringRef GPUDialect::getDialectName() { return "gpu"; }
-
bool GPUDialect::isKernel(Operation *op) {
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
return static_cast<bool>(isKernelAttr);
}
GPUDialect::GPUDialect(MLIRContext *context)
- : Dialect(getDialectName(), context) {
+ : Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index b11438c2dc02..89dce1958991 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -132,6 +132,10 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const {
return def->getValueAsString("body");
}
+/// Returns the dialect this attribute belongs to, read from the record's
+/// `dialect` field. Only meaningful for records that define that field
+/// (e.g. `DialectAttr`); `getValueAsDef` reports an error otherwise.
+tblgen::Dialect tblgen::Attribute::getDialect() const {
+  const llvm::Record *dialectDef = def->getValueAsDef("dialect");
+  return Dialect(dialectDef);
+}
+
tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
assert(def->isSubClassOf("ConstantAttr") &&
"must be subclass of TableGen 'ConstantAttr' class");
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 4ba45520afeb..7e757eaeae4f 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -24,6 +24,13 @@ StringRef tblgen::Dialect::getCppNamespace() const {
return def->getValueAsString("cppNamespace");
}
+/// Returns the C++ class name of this dialect: the TableGen record name with
+/// every '_' dropped (e.g. `StandardOps_Dialect` becomes `StandardOpsDialect`).
+std::string tblgen::Dialect::getCppClassName() const {
+  std::string cppName;
+  for (char c : def->getName())
+    if (c != '_')
+      cppName.push_back(c);
+  return cppName;
+}
+
static StringRef getAsStringOrEmpty(const llvm::Record &record,
StringRef fieldName) {
if (auto valueInit = record.getValueInit(fieldName)) {
@@ -42,6 +49,15 @@ StringRef tblgen::Dialect::getDescription() const {
return getAsStringOrEmpty(*def, "description");
}
+/// Returns the dialect's `extraClassDeclaration` code block, or None when the
+/// field is left empty.
+llvm::Optional<StringRef> tblgen::Dialect::getExtraClassDeclaration() const {
+  StringRef extraDecl = def->getValueAsString("extraClassDeclaration");
+  if (extraDecl.empty())
+    return llvm::None;
+  return extraDecl;
+}
+
+/// Returns true when the dialect record sets `hasConstantMaterializer`, i.e.
+/// the dialect declares the constant materialization hook.
+bool tblgen::Dialect::hasConstantMaterializer() const {
+  const bool materializes = def->getValueAsBit("hasConstantMaterializer");
+  return materializes;
+}
+
bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index fb9ba6ef4062..b7628cff11f8 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -4,6 +4,7 @@ set(LLVM_LINK_COMPONENTS
)
add_tablegen(mlir-tblgen MLIR
+ DialectGen.cpp
EnumsGen.cpp
LLVMIRConversionGen.cpp
LLVMIRIntrinsicGen.cpp
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
new file mode 100644
index 000000000000..c0009d6e1231
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -0,0 +1,166 @@
+//===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// DialectGen uses the description of dialects to generate C++ declarations
+// for the dialect classes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/STLExtras.h"
+#include "mlir/Support/StringExtras.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/OpClass.h"
+#include "mlir/TableGen/OpInterfaces.h"
+#include "mlir/TableGen/OpTrait.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+// Tag debug output with this generator's own name; the previous value was
+// copied from the op definition generator.
+#define DEBUG_TYPE "mlir-tblgen-dialectgen"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+/// Command-line options shared by the `-gen-dialect-*` generators.
+static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
+
+/// Name of the dialect to generate for; needed when the input defines more
+/// than one dialect.
+/// NOTE(review): `CommaSeparated` is normally used with list options; on this
+/// single-valued option it looks unnecessary -- confirm it is intended.
+static llvm::cl::opt<std::string> selectedDialect(
+    "dialect", llvm::cl::desc("The dialect to gen for"),
+    llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
+
+/// Return a lazy range over the records in `records` whose `T` wrapper
+/// belongs to `dialect`.
+///
+/// `T` (e.g. `Attribute` or `Type`) must be constructible from a
+/// `const llvm::Record *` and provide `getDialect()`.
+///
+/// The dialect is captured by value because it only wraps the dialect's
+/// record. As a result, the returned range does not depend on the lifetime of
+/// the `dialect` argument. It still refers to the storage behind `records`.
+template <typename T>
+static auto filterForDialect(ArrayRef<llvm::Record *> records,
+                             const Dialect &dialect) {
+  return llvm::make_filter_range(records, [dialect](const llvm::Record *record) {
+    return T(record).getDialect() == dialect;
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Dialect declarations
+//===----------------------------------------------------------------------===//
+
+/// The code block for the start of a dialect class declaration. This is an
+/// `llvm::formatv` template:
+///
+/// {0}: The name of the dialect class.
+/// {1}: The dialect namespace.
+///
+/// Literal braces are escaped as `{{`. Unescaped, formatv only treats a `{`
+/// as literal when another `{` appears before the next `}`. That happened to
+/// hold here, but it breaks easily when the template is edited. The emitted
+/// text is unchanged.
+static const char *const dialectDeclBeginStr = R"(
+class {0} : public ::mlir::Dialect {{
+public:
+ explicit {0}(::mlir::MLIRContext *context);
+ static ::llvm::StringRef getDialectNamespace() {{ return "{1}"; }
+)";
+
+// Hook declarations that emitDialectDecl appends to the dialect class body.
+// They are written to the stream verbatim; they are not formatv templates.
+
+/// Attribute parse/print hooks, emitted when the dialect owns at least one
+/// `DialectAttr`.
+static constexpr const char *attrParserDecl = R"(
+ /// Parse an attribute registered to this dialect.
+ ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
+ ::mlir::Type type) const override;
+
+ /// Print an attribute registered to this dialect.
+ void printAttribute(::mlir::Attribute attr,
+ ::mlir::DialectAsmPrinter &os) const override;
+)";
+
+/// Type parse/print hooks, emitted when the dialect owns at least one
+/// `DialectType`.
+static constexpr const char *typeParserDecl = R"(
+ /// Parse a type registered to this dialect.
+ ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
+
+ /// Print a type registered to this dialect.
+ void printType(::mlir::Type type,
+ ::mlir::DialectAsmPrinter &os) const override;
+)";
+
+/// Constant materialization hook, emitted when the dialect sets
+/// `hasConstantMaterializer`.
+static constexpr const char *constantMaterializerDecl = R"(
+ /// Materialize a single constant operation from a given attribute value with
+ /// the desired resultant type.
+ ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
+ ::mlir::Attribute value,
+ ::mlir::Type type,
+ ::mlir::Location loc) override;
+)";
+
+/// Emit the C++ class declaration of `dialect` into `os`.
+///
+/// `dialectAttrs` and `dialectTypes` are the `DialectAttr` and `DialectType`
+/// records owned by `dialect`. A non-empty set makes the matching
+/// parse/print hook declarations part of the class.
+static void emitDialectDecl(
+    Dialect &dialect,
+    FunctionTraits<decltype(&filterForDialect<Attribute>)>::result_t
+        dialectAttrs,
+    FunctionTraits<decltype(&filterForDialect<Type>)>::result_t dialectTypes,
+    raw_ostream &os) {
+  // Class header: name, constructor and namespace accessor.
+  os << llvm::formatv(dialectDeclBeginStr, dialect.getCppClassName(),
+                      dialect.getName());
+
+  // Parser/printer hooks are only declared for dialects that own
+  // attributes or types.
+  const bool ownsAttrs = !dialectAttrs.empty();
+  const bool ownsTypes = !dialectTypes.empty();
+  if (ownsAttrs)
+    os << attrParserDecl;
+  if (ownsTypes)
+    os << typeParserDecl;
+
+  // Optional hooks, followed by any user-provided declarations.
+  if (dialect.hasConstantMaterializer())
+    os << constantMaterializerDecl;
+  llvm::Optional<StringRef> extraDecl = dialect.getExtraClassDeclaration();
+  if (extraDecl.hasValue())
+    os << extraDecl.getValue();
+
+  // Close the class body.
+  os << "};\n";
+}
+
+/// Select the dialect record to generate declarations for.
+///
+/// Without `-dialect`, `dialectDefs` must contain exactly one dialect, which
+/// is returned. With `-dialect`, the dialect whose name matches the option is
+/// returned. On failure, a diagnostic (newline-terminated) is printed to
+/// stderr and nullptr is returned.
+static const llvm::Record *
+selectDialectDef(ArrayRef<llvm::Record *> dialectDefs) {
+  if (selectedDialect.getNumOccurrences() == 0) {
+    if (dialectDefs.size() == 1)
+      return dialectDefs.front();
+    llvm::errs() << "when more than 1 dialect is present, one must be selected "
+                    "via '-dialect'\n";
+    return nullptr;
+  }
+
+  auto dialectIt = llvm::find_if(dialectDefs, [](const llvm::Record *def) {
+    return Dialect(def).getName() == selectedDialect;
+  });
+  if (dialectIt == dialectDefs.end()) {
+    llvm::errs() << "selected dialect with '-dialect' does not exist\n";
+    return nullptr;
+  }
+  return *dialectIt;
+}
+
+/// Emit the declaration of the selected dialect class into `os`.
+///
+/// An input without any dialect produces only the file header. Returns true
+/// on error (no dialect could be selected) and false otherwise.
+static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
+                             raw_ostream &os) {
+  emitSourceFileHeader("Dialect Declarations", os);
+
+  auto defs = recordKeeper.getAllDerivedDefinitions("Dialect");
+  if (defs.empty())
+    return false;
+
+  const llvm::Record *dialectDef = selectDialectDef(defs);
+  if (!dialectDef)
+    return true;
+
+  // Only attributes/types registered to the selected dialect decide which
+  // parse/print hooks get declared.
+  auto attrDefs = recordKeeper.getAllDerivedDefinitions("DialectAttr");
+  auto typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType");
+  Dialect dialect(dialectDef);
+  emitDialectDecl(dialect, filterForDialect<Attribute>(attrDefs, dialect),
+                  filterForDialect<Type>(typeDefs, dialect), os);
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Dialect registration hooks
+//===----------------------------------------------------------------------===//
+
+/// Registers `-gen-dialect-decls`, which emits the C++ declaration of the
+/// selected dialect class.
+static mlir::GenRegistration
+    genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
+                    emitDialectDecls);
More information about the Mlir-commits
mailing list