[Mlir-commits] [mlir] 32b327e - [mlir][ods] Use lambda in element type check pred rather than repeated casts

Jacques Pienaar llvmlistbot at llvm.org
Wed Nov 10 16:27:47 PST 2021


Author: Jacques Pienaar
Date: 2021-11-10T16:27:37-08:00
New Revision: 32b327e4ed8c9ad9f99b88ccaaaab74a4cf809db

URL: https://github.com/llvm/llvm-project/commit/32b327e4ed8c9ad9f99b88ccaaaab74a4cf809db
DIFF: https://github.com/llvm/llvm-project/commit/32b327e4ed8c9ad9f99b88ccaaaab74a4cf809db.diff

LOG: [mlir][ods] Use lambda in element type check pred rather than repeated casts

Avoids repeated cast & getElementType calls. This is only a local change for ShapedType
containers, but it reduces one model case from 24.7s to 24.04s.

Resultant code generated change:
https://gist.github.com/jpienaar/7ffd2e9b0737134ba2ea2729b91c9572

Differential Revision: https://reviews.llvm.org/D113621

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/test/mlir-tblgen/predicate.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 360832f9478b8..32535c23076e1 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -566,20 +566,21 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
     Type<And<[containerPred,
                 SubstLeaves<"$_self", !cast<string>(elementTypeCall),
                 etype.predicate>]>,
-         descr # " of " # etype.summary # " values", cppClassName> {
-  // The type of elements in the container.
-  Type elementType = etype;
-
-  // Call to retrieve.
-  code getElementTypeCall = elementTypeCall;
-}
+         descr # " of " # etype.summary # " values", cppClassName>;
 
+// A type that satisfies `containerPred` and whose element type is one of
+// `allowedTypes`. The element-type predicate is evaluated inside an
+// immediately-invoked lambda so the ShapedType cast and getElementType() call
+// happen once instead of once per leaf of the element-type predicate.
 class ShapedContainerType<list<Type> allowedTypes,
                           Pred containerPred, string descr,
                           string cppClassName = "::mlir::Type"> :
-    ContainerType<AnyTypeOf<allowedTypes>, containerPred,
-                  "$_self.cast<::mlir::ShapedType>().getElementType()", descr,
-                  cppClassName>;
+    Type<And<[containerPred,
+              Concat<"[](::mlir::Type elementType) { return ",
+                SubstLeaves<"$_self", "elementType",
+                AnyTypeOf<allowedTypes>.predicate>,
+                "; }($_self.cast<::mlir::ShapedType>().getElementType())">]>,
+         descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppClassName>;
 
 // Whether a shaped type is ranked.
 def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">;

diff  --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
index 2170bfb829d55..f0130ac52f6bc 100644
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -25,11 +25,11 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
 // CHECK-NOT:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
 
 // CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK-NEXT:  if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
+// CHECK-NEXT:  if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return (true); }(type.cast<::mlir::ShapedType>().getElementType())))) {
 // CHECK-NEXT:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
 
 // CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK-NEXT:  if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
+// CHECK-NEXT:  if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(type.cast<::mlir::ShapedType>().getElementType())))) {
 // CHECK-NEXT:    return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
 
 // CHECK-LABEL: OpA::verify


        


More information about the Mlir-commits mailing list