[Mlir-commits] [mlir] [mlir][emitc] Restrict types in EmitC (PR #88391)

Tina Jung llvmlistbot at llvm.org
Thu Apr 11 06:41:17 PDT 2024


https://github.com/TinaAMD created https://github.com/llvm/llvm-project/pull/88391

Restrict the types that are valid for EmitC operations, using what the emitter currently supports as the restriction. Define a utility function for valid types so that it can be used to restrict the operations in TableGen and can also be reused in dialect conversions.

>From 77bd22583db15d17d7935c0dc31d19d3829e259f Mon Sep 17 00:00:00 2001
From: Tina Jung <tina.maria.jung at xilinx.com>
Date: Thu, 11 Apr 2024 14:36:48 +0100
Subject: [PATCH] [mlir][emitc] Restrict types in EmitC

Restrict the types that are valid for EmitC operations, using what the emitter currently supports as the restriction. Define a utility function for valid types so that it can be used to restrict the operations in TableGen and can also be reused in dialect conversions.
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.h    |  3 ++
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   | 52 +++++++++----------
 .../mlir/Dialect/EmitC/IR/EmitCTypes.td       |  3 ++
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 38 ++++++++++++++
 mlir/test/Dialect/EmitC/invalid_types.mlir    | 48 +++++++++++++++++
 5 files changed, 118 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index c03915667db653..5d9531cd124154 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -31,6 +31,9 @@ namespace mlir {
 namespace emitc {
 void buildTerminatedBody(OpBuilder &builder, Location loc);
 
+/// Determines whether \p type is valid in EmitC.
+bool isSupportedEmitCType(mlir::Type type);
+
 /// Determines whether \p type is a valid integer type in EmitC.
 bool isSupportedIntegerType(mlir::Type type);
 
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index e611fd2f0f15c4..c1a1e77c34ab25 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -34,16 +34,16 @@ class EmitC_Op<string mnemonic, list<Trait> traits = []>
 // Base class for unary operations.
 class EmitC_UnaryOp<string mnemonic, list<Trait> traits = []> :
     EmitC_Op<mnemonic, traits> {
-  let arguments = (ins AnyType);
-  let results = (outs AnyType);
+  let arguments = (ins EmitCType);
+  let results = (outs EmitCType);
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
 // Base class for binary operations.
 class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
     EmitC_Op<mnemonic, traits> {
-  let arguments = (ins AnyType:$lhs, AnyType:$rhs);
-  let results = (outs AnyType);
+  let arguments = (ins EmitCType:$lhs, EmitCType:$rhs);
+  let results = (outs EmitCType);
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
@@ -97,9 +97,9 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
   }];
   let arguments = (ins
     Arg<StrAttr, "the operator to apply">:$applicableOperator,
-    AnyType:$operand
+    EmitCType:$operand
   );
-  let results = (outs AnyType:$result);
+  let results = (outs EmitCType:$result);
   let assemblyFormat = [{
     $applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results)
   }];
@@ -240,9 +240,9 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
     Arg<StrAttr, "the C++ function to call">:$callee,
     Arg<OptionalAttr<ArrayAttr>, "the order of operands and further attributes">:$args,
     Arg<OptionalAttr<ArrayAttr>, "template arguments">:$template_args,
-    Variadic<AnyType>:$operands
+    Variadic<EmitCType>:$operands
   );
-  let results = (outs Variadic<AnyType>);
+  let results = (outs Variadic<EmitCType>);
   let builders = [
     OpBuilder<(ins
       "::mlir::TypeRange":$resultTypes,
@@ -284,8 +284,8 @@ def EmitC_CastOp : EmitC_Op<"cast",
     ```
   }];
 
-  let arguments = (ins AnyType:$source);
-  let results = (outs AnyType:$dest);
+  let arguments = (ins EmitCType:$source);
+  let results = (outs EmitCType:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
 }
 
@@ -323,9 +323,9 @@ def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
   }];
 
   let arguments = (ins EmitC_CmpPredicateAttr:$predicate,
-                       AnyType:$lhs,
-                       AnyType:$rhs);
-  let results = (outs AnyType);
+                       EmitCType:$lhs,
+                       EmitCType:$rhs);
+  let results = (outs EmitCType);
 
   let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)";
 }
@@ -353,7 +353,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
   }];
 
   let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
-  let results = (outs AnyType);
+  let results = (outs EmitCType);
 
   let hasFolder = 1;
   let hasVerifier = 1;
@@ -423,7 +423,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
   }];
 
   let arguments = (ins UnitAttr:$do_not_inline);
-  let results = (outs AnyType:$result);
+  let results = (outs EmitCType:$result);
   let regions = (region SizedRegion<1>:$region);
 
   let hasVerifier = 1;
@@ -531,8 +531,8 @@ def EmitC_CallOp : EmitC_Op<"call",
     %2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32
     ```
   }];
-  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
-  let results = (outs Variadic<AnyType>);
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<EmitCType>:$operands);
+  let results = (outs Variadic<EmitCType>);
 
   let builders = [
     OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
@@ -722,7 +722,7 @@ def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">,
     }
     ```
   }];
-  let arguments = (ins Optional<AnyType>:$operand);
+  let arguments = (ins Optional<EmitCType>:$operand);
 
   let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?";
   let hasVerifier = 1;
@@ -766,7 +766,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
   }];
 
   let arguments = (ins StrAttr:$value);
-  let results = (outs AnyType:$result);
+  let results = (outs EmitCType:$result);
 
   let hasVerifier = 1;
   let assemblyFormat = "$value attr-dict `:` type($result)";
@@ -932,8 +932,8 @@ def EmitC_ConditionalOp : EmitC_Op<"conditional",
     int32_t v6 = v3 ? v4 : v5;
     ```
   }];
-  let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
-  let results = (outs AnyType:$result);
+  let arguments = (ins I1:$condition, EmitCType:$true_value, EmitCType:$false_value);
+  let results = (outs EmitCType:$result);
   let assemblyFormat = "operands attr-dict `:` type($result)";
 }
 
@@ -1009,7 +1009,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
   }];
 
   let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
-  let results = (outs AnyType);
+  let results = (outs EmitCType);
 
   let hasVerifier = 1;
 }
@@ -1068,7 +1068,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
     ```
   }];
 
-  let arguments = (ins AnyType:$var, AnyType:$value);
+  let arguments = (ins EmitCType:$var, EmitCType:$value);
   let results = (outs);
 
   let hasVerifier = 1;
@@ -1089,7 +1089,7 @@ def EmitC_YieldOp : EmitC_Op<"yield",
     value is yielded.
   }];
 
-  let arguments = (ins Optional<AnyType>:$result);
+  let arguments = (ins Optional<EmitCType>:$result);
   let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
 
   let hasVerifier = 1;
@@ -1173,8 +1173,8 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
       EmitC_OpaqueType,
       EmitC_PointerType]>,
     "the value to subscript">:$value,
-    Variadic<AnyType>:$indices);
-  let results = (outs AnyType:$result);
+    Variadic<EmitCType>:$indices);
+  let results = (outs EmitCType:$result);
 
   let builders = [
     OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
index bce5807230ce49..444395b915e250 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
@@ -22,6 +22,9 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
 // EmitC type definitions
 //===----------------------------------------------------------------------===//
 
+def EmitCType : Type<CPred<"emitc::isSupportedEmitCType($_self)">,
+    "type supported by EmitC">;
+
 def EmitCIntegerType : Type<CPred<"emitc::isSupportedIntegerType($_self)">,
     "integer type supported by EmitC">;
 
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 7cbf28b627342a..b037ef3c0b4152 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -10,11 +10,15 @@
 #include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Types.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 
 using namespace mlir;
 using namespace mlir::emitc;
@@ -54,6 +58,40 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
   builder.create<emitc::YieldOp>(loc);
 }
 
+bool mlir::emitc::isSupportedEmitCType(Type type) {
+  // Arrays may not be nested in pointers, arrays, tensors or tuples.
+  if (llvm::isa<emitc::OpaqueType>(type) || type.isIndex())
+    return true;
+  if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type)) {
+    Type pointee = ptrType.getPointee();
+    return !llvm::isa<emitc::ArrayType>(pointee) &&
+           isSupportedEmitCType(pointee);
+  }
+  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
+    Type elemType = arrayType.getElementType();
+    return !llvm::isa<emitc::ArrayType>(elemType) &&
+           isSupportedEmitCType(elemType);
+  }
+  if (llvm::isa<IntegerType>(type))
+    return isSupportedIntegerType(type);
+  if (llvm::isa<FloatType>(type))
+    return isSupportedFloatType(type);
+  if (auto tensorType = llvm::dyn_cast<TensorType>(type)) {
+    // Tensors must be statically shaped.
+    if (!tensorType.hasStaticShape())
+      return false;
+    Type elemType = tensorType.getElementType();
+    return !llvm::isa<emitc::ArrayType>(elemType) &&
+           isSupportedEmitCType(elemType);
+  }
+  if (auto tupleType = llvm::dyn_cast<TupleType>(type))
+    return llvm::all_of(tupleType.getTypes(), [](Type elemType) {
+      return !llvm::isa<emitc::ArrayType>(elemType) &&
+             isSupportedEmitCType(elemType);
+    });
+  return false;
+}
+
 bool mlir::emitc::isSupportedIntegerType(Type type) {
   if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
     switch (intType.getWidth()) {
diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir
index f9d517bf689b95..0ad8d4eabe6b8b 100644
--- a/mlir/test/Dialect/EmitC/invalid_types.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_types.mlir
@@ -97,3 +97,51 @@ func.func @illegal_float_type(%arg0: f80, %arg1: f80) {
     %mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80
     return
 }
+
+// -----
+
+func.func @illegal_pointee_type() {
+    // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got '!emitc.ptr<i11>'}}
+    %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i11>
+    return
+}
+
+// -----
+
+func.func @illegal_non_static_tensor_shape_type() {
+    // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<?xf32>'}}
+    %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<?xf32>
+    return
+}
+
+// -----
+
+func.func @illegal_tensor_array_element_type() {
+    // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<!emitc.array<9xi16>>'}}
+    %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<!emitc.array<9xi16>>
+    return
+}
+
+// -----
+
+func.func @illegal_tensor_integer_element_type() {
+    // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<9xi11>'}}
+    %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<9xi11>
+    return
+}
+
+// -----
+
+func.func @illegal_tuple_array_element_type() {
+    // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple<!emitc.array<9xf32>, f32>'}}
+    %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple<!emitc.array<9xf32>, f32>
+    return
+}
+
+// -----
+
+func.func @illegal_tuple_float_element_type() {
+    // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple<i32, f80>'}}
+    %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple<i32, f80>
+    return
+}



More information about the Mlir-commits mailing list