[Mlir-commits] [mlir] [mlir][emitc] Restrict integer and float types (PR #85788)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 19 06:59:59 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-emitc
@llvm/pr-subscribers-mlir
Author: Tina Jung (TinaAMD)
<details>
<summary>Changes</summary>
Restrict which integer and floating-point types are valid in EmitC. This should cover the types that are supported in C++ and aligns with what the emitter currently supports.
The checks are implemented as functions rather than fully in TableGen, so that they can be re-used by conversions to EmitC.
---
Full diff: https://github.com/llvm/llvm-project/pull/85788.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.h (+4)
- (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+2-2)
- (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td (+6)
- (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+29)
- (modified) mlir/test/Dialect/EmitC/invalid_ops.mlir (+4-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 1f0df3cb336b12..f3d250fbdc2863 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -30,6 +30,10 @@
namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);
+/// Determines whether \p type is a valid integer type in EmitC.
+bool isValidEmitCIntegerType(mlir::Type type);
+/// Determines whether \p type is a valid floating-point type in EmitC.
+bool isValidEmitCFloatType(mlir::Type type);
} // namespace emitc
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 78bfd561171f50..8e6a8d48b0f744 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -51,8 +51,8 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
def CExpression : NativeOpTrait<"emitc::CExpression">;
// Types only used in binary arithmetic operations.
-def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
-def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
+def IntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Integer_Type, Index, EmitC_OpaqueType]>;
+def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Float_Type, IntegerIndexOrOpaqueType]>;
def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
let summary = "Addition operation";
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
index a2ba45a1f6a12b..ed51c6d05d567e 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
@@ -22,6 +22,12 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
// EmitC type definitions
//===----------------------------------------------------------------------===//
+def Valid_EmitC_Integer_Type : Type<CPred<"emitc::isValidEmitCIntegerType($_self)">,
+ "EmitC integer type">;
+
+def Valid_EmitC_Float_Type : Type<CPred<"emitc::isValidEmitCFloatType($_self)">,
+ "EmitC floating-point type">;
+
class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<EmitC_Dialect, name, traits> {
let mnemonic = typeMnemonic;
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e401a83bcb42e6..f6e2168e0bac7b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -54,6 +54,35 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
builder.create<emitc::YieldOp>(loc);
}
+bool mlir::emitc::isValidEmitCIntegerType(Type type) {
+ if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
+ switch (intType.getWidth()) {
+ case 1:
+ case 8:
+ case 16:
+ case 32:
+ case 64:
+ return true;
+ default:
+ return false;
+ }
+ }
+ return false;
+}
+
+bool mlir::emitc::isValidEmitCFloatType(Type type) {
+ if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
+ switch (floatType.getWidth()) {
+ case 32:
+ case 64:
+ return true;
+ default:
+ return false;
+ }
+ }
+ return false;
+}
+
/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 6294c853d99931..6cc833cf396462 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -170,7 +170,7 @@ func.func @add_float_pointer(%arg0: f32, %arg1: !emitc.ptr<f32>) {
// -----
func.func @div_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
- // expected-error @+1 {{'emitc.div' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor<i32>'}}
+ // expected-error @+1 {{'emitc.div' op operand #0 must be EmitC floating-point type or EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
%1 = "emitc.div" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}
@@ -178,7 +178,7 @@ func.func @div_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// -----
func.func @mul_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
- // expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor<i32>'}}
+ // expected-error @+1 {{'emitc.mul' op operand #0 must be EmitC floating-point type or EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
%1 = "emitc.mul" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}
@@ -186,7 +186,7 @@ func.func @mul_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// -----
func.func @rem_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
- // expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'tensor<i32>'}}
+ // expected-error @+1 {{'emitc.rem' op operand #0 must be EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
%1 = "emitc.rem" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}
@@ -194,7 +194,7 @@ func.func @rem_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// -----
func.func @rem_float(%arg0: f32, %arg1: f32) {
- // expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'f32'}}
+ // expected-error @+1 {{'emitc.rem' op operand #0 must be EmitC integer type or index or EmitC opaque type, but got 'f32'}}
%1 = "emitc.rem" (%arg0, %arg1) : (f32, f32) -> f32
return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/85788
More information about the Mlir-commits
mailing list