[Mlir-commits] [mlir] bcc9b37 - Split `ElementwiseMappable` trait into four more precise traits.

Frederik Gossen llvmlistbot at llvm.org
Tue Mar 2 06:31:40 PST 2021


Author: Frederik Gossen
Date: 2021-03-02T15:31:19+01:00
New Revision: bcc9b371e43be8fa3fba65f2363eaf767731e0c7

URL: https://github.com/llvm/llvm-project/commit/bcc9b371e43be8fa3fba65f2363eaf767731e0c7
DIFF: https://github.com/llvm/llvm-project/commit/bcc9b371e43be8fa3fba65f2363eaf767731e0c7.diff

LOG: Split `ElementwiseMappable` trait into four more precise traits.

Some elementwise operations are not scalarizable, vectorizable, or tensorizable.
Split `ElementwiseMappable` trait into the following, more precise traits.
  - `Elementwise`
  - `Scalarizable`
  - `Vectorizable`
  - `Tensorizable`
This allows for reuse of `Elementwise` in dialects like HLO.

Differential Revision: https://reviews.llvm.org/D97674

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/IR/Operation.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 960f5f64eec3..5e4648da488d 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -20,11 +20,9 @@ class Complex_Op<string mnemonic, list<OpTrait> traits = []>
 // floating-point element type. These operations take two operands and return
 // one result, all of which must be complex numbers of the same type.
 class ComplexArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
-    Complex_Op<mnemonic,
-       !listconcat(traits, [NoSideEffect,
-                            SameOperandsAndResultType,
-                            DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
-                            ElementwiseMappable])> {
+    Complex_Op<mnemonic, traits # [NoSideEffect, SameOperandsAndResultType,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+    ElementwiseMappable.traits> {
   let arguments = (ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs);
   let results = (outs Complex<AnyFloat>:$result);
   let assemblyFormat = "$lhs `,` $rhs  attr-dict `:` type($result)";

diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 0ac5c0e05706..7ce4dc3a7aee 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -17,10 +17,9 @@ class MathOp<string mnemonic, list<OpTrait> traits = []>
     : Op<Math_Dialect, mnemonic, traits # [NoSideEffect]>;
 
 class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
-    MathOp<mnemonic,
-      traits # [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
-                ElementwiseMappable,
-                SameOperandsAndResultType]> {
+    MathOp<mnemonic, traits #
+    [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
+    SameOperandsAndResultType] # ElementwiseMappable.traits> {
   let arguments = (ins FloatLike:$operand);
 
   let results = (outs FloatLike:$result);
@@ -29,10 +28,9 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
 }
 
 class FloatBinaryOp<string mnemonic, list<OpTrait> traits = []> :
-    MathOp<mnemonic,
-      traits # [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
-                ElementwiseMappable,
-                SameOperandsAndResultType]> {
+    MathOp<mnemonic, traits # [
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
+    SameOperandsAndResultType] # ElementwiseMappable.traits> {
   let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
   let results = (outs FloatLike:$result);
 

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index fe054c59ae6e..24ea8ea135d6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -71,9 +71,9 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
 
 // Base class for arithmetic cast operations.
 class ArithmeticCastOp<string mnemonic, list<OpTrait> traits = []> :
-    CastOp<mnemonic,
-    !listconcat(traits, [ElementwiseMappable,
-                         DeclareOpInterfaceMethods<VectorUnrollOpInterface>])> {
+    CastOp<mnemonic, traits # [
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+    ElementwiseMappable.traits> {
 }
 
 // Base class for unary ops. Requires single operand and result. Individual
@@ -95,21 +95,18 @@ class UnaryOpSameOperandAndResultType<string mnemonic,
 }
 
 class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
-    UnaryOpSameOperandAndResultType<mnemonic,
-      !listconcat(traits,
-                  [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
-                   ElementwiseMappable])>,
-    Arguments<(ins FloatLike:$operand)>;
+    UnaryOpSameOperandAndResultType<mnemonic, traits # [
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+    ElementwiseMappable.traits>, Arguments<(ins FloatLike:$operand)>;
 
 // Base class for standard arithmetic operations.  Requires operands and
 // results to be of the same type, but does not constrain them to specific
 // types.
 class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
-    Op<StandardOps_Dialect, mnemonic,
-       !listconcat(traits, [NoSideEffect,
-                            SameOperandsAndResultType,
-                            DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
-                            ElementwiseMappable])> {
+    Op<StandardOps_Dialect, mnemonic, traits # [NoSideEffect,
+    SameOperandsAndResultType,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+    ElementwiseMappable.traits> {
 
   let results = (outs AnyType:$result);
 
@@ -930,12 +927,10 @@ def CmpFPredicateAttr : I64EnumAttr<
   let cppNamespace = "::mlir";
 }
 
-def CmpFOp : Std_Op<"cmpf",
-    [NoSideEffect, SameTypeOperands, ElementwiseMappable,
-     DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
-     TypesMatchWith<
-       "result type has i1 element type and same shape as operands",
-       "lhs", "result", "getI1SameShape($_self)">]> {
+def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
+    "result type has i1 element type and same shape as operands",
+    "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> {
   let summary = "floating-point comparison operation";
   let description = [{
     The `cmpf` operation compares its two operands according to the float
@@ -1015,12 +1010,10 @@ def CmpIPredicateAttr : I64EnumAttr<
   let cppNamespace = "::mlir";
 }
 
-def CmpIOp : Std_Op<"cmpi",
-    [NoSideEffect, SameTypeOperands, ElementwiseMappable,
-     DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
-     TypesMatchWith<
-       "result type has i1 element type and same shape as operands",
-       "lhs", "result", "getI1SameShape($_self)">]> {
+def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
+    "result type has i1 element type and same shape as operands",
+    "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> {
   let summary = "integer comparison operation";
   let description = [{
     The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -2160,8 +2153,9 @@ def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
 //===----------------------------------------------------------------------===//
 
 def SelectOp : Std_Op<"select", [NoSideEffect,
-     AllTypesMatch<["true_value", "false_value", "result"]>,
-     ElementwiseMappable, DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
+    AllTypesMatch<["true_value", "false_value", "result"]>,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+    ElementwiseMappable.traits> {
   let summary = "select operation";
   let description = [{
     The `select` operation chooses one value based on a binary condition
@@ -2391,9 +2385,9 @@ def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> {
 // SignExtendIOp
 //===----------------------------------------------------------------------===//
 
-def SignExtendIOp : Std_Op<"sexti",
-    [NoSideEffect, ElementwiseMappable,
-    DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
+def SignExtendIOp : Std_Op<"sexti", [NoSideEffect,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+    ElementwiseMappable.traits> {
   let summary = "integer sign extension operation";
   let description = [{
     The integer sign extension operation takes an integer input of
@@ -3220,9 +3214,9 @@ def TransposeOp : Std_Op<"transpose", [NoSideEffect]>,
 // TruncateIOp
 //===----------------------------------------------------------------------===//
 
-def TruncateIOp : Std_Op<"trunci",
-  [NoSideEffect, ElementwiseMappable,
-   DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
+def TruncateIOp : Std_Op<"trunci", [NoSideEffect,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+    ElementwiseMappable.traits> {
   let summary = "integer truncation operation";
   let description = [{
     The integer truncation operation takes an integer input of
@@ -3463,9 +3457,9 @@ def XOrOp : IntBinaryOp<"xor", [Commutative]> {
 // ZeroExtendIOp
 //===----------------------------------------------------------------------===//
 
-def ZeroExtendIOp : Std_Op<"zexti",
-  [NoSideEffect, ElementwiseMappable,
-   DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
+def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+    ElementwiseMappable.traits> {
   let summary = "integer zero extension operation";
   let description = [{
     The integer zero extension operation takes an integer input of

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index c0089f7a2a51..f0f5f1d0e4ea 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1785,9 +1785,25 @@ def Terminator : NativeOpTrait<"IsTerminator">;
 // Op can be safely normalized in the presence of MemRefs with
 // non-identity maps.
 def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
-// Op can be systematically interconverted between scalar and vector/tensor
-// form by mapping elementwise based on the type.
-def ElementwiseMappable : NativeOpTrait<"ElementwiseMappable">;
+// Op is elementwise on tensor/vector operands and results.
+def Elementwise : NativeOpTrait<"Elementwise">;
+// Elementwise op can be applied to scalars instead of tensor/vector operands.
+def Scalarizable : NativeOpTrait<"Scalarizable">;
+// Elementwise op can be applied to all-vector operands.
+def Vectorizable : NativeOpTrait<"Vectorizable">;
+// Elementwise op can be applied to all-tensor operands.
+def Tensorizable : NativeOpTrait<"Tensorizable">;
+
+// Group together `Elementwise`, `Scalarizable`, `Vectorizable`, and
+// `Tensorizable` for convenience.
+def ElementwiseMappable {
+  list<OpTrait> traits = [
+    Elementwise,
+    Scalarizable,
+    Vectorizable,
+    Tensorizable,
+  ];
+}
 
 // Op's regions have a single block with the specified terminator.
 class SingleBlockImplicitTerminator<string op>

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index bfddcea950f2..c65e653f2a50 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -282,7 +282,7 @@ LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
 LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
 LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
 LogicalResult verifyNoRegionArguments(Operation *op);
-LogicalResult verifyElementwiseMappable(Operation *op);
+LogicalResult verifyElementwise(Operation *op);
 } // namespace impl
 
 /// Helper class for implementing traits.  Clients are not expected to interact
@@ -1213,93 +1213,144 @@ template <typename ConcrentType>
 struct MemRefsNormalizable
     : public TraitBase<ConcrentType, MemRefsNormalizable> {};
 
-/// This trait tags scalar ops that also can be applied to vectors/tensors, with
-/// their semantics on vectors/tensors being elementwise application.
+/// This trait tags element-wise ops that operate on scalars, vectors, or
+/// tensors.
 ///
 /// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
-/// trait. In particular, broadcasting behavior is not allowed. This trait
-/// describes a set of invariants that allow systematic
-/// vectorization/tensorization, and the reverse, scalarization. The properties
-/// needed for this also can be used to implement a number of
-/// transformations/analyses/interfaces.
+/// trait. In particular, broadcasting behavior is not allowed.
 ///
-/// An `ElementwiseMappable` op must satisfy the following properties:
+/// An `Elementwise` op must satisfy the following properties:
 ///
-/// 1. If any result is a vector (resp. tensor), then at least one operand must
-/// be a vector (resp. tensor).
-/// 2. If any operand is a vector (resp. tensor), then there must be at least
-/// one result, and all results must be vectors (resp. tensors).
-/// 3. The static types of all vector (resp. tensor) operands and results must
-/// have the same shape.
-/// 4. In the case of tensor operands, the dynamic shapes of all tensor operands
-/// must be the same, otherwise the op has undefined behavior.
-/// 5. ("systematic scalarization" property) If an op has vector/tensor
-/// operands/results, then the same op, with the operand/result types changed to
-/// their corresponding element type, shall be a verifier-valid op.
-/// 6. The semantics of the op on vectors (resp. tensors) shall be the same as
-/// applying the scalarized version of the op for each corresponding element of
-/// the vector (resp. tensor) operands in parallel.
-/// 7. ("systematic vectorization/tensorization" property) If an op has
-/// scalar operands/results, the op shall remain verifier-valid if all scalar
-/// operands are replaced with vectors/tensors of the same shape and
-/// corresponding element types.
+/// 1. If any result is a vector/tensor then at least one operand must also be a
+///    vector/tensor.
+/// 2. If any operand is a vector/tensor then there must be at least one result
+///    and all results must be vectors/tensors.
+/// 3. All operand and result vector/tensor types must be of the same shape. The
+///    shape may be dynamic, in which case the op's behavior is undefined for
+///    non-matching shapes.
+/// 4. The operation must be elementwise on its vector/tensor operands and
+///    results. When applied to single-element vectors/tensors, the result must
+///    be the same per element.
 ///
-/// Together, these properties provide an easy way for scalar operations to
-/// conveniently generalize their behavior to vectors/tensors, and systematize
-/// conversion between these forms.
+/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
+/// interface `ElementwiseTypeInterface` that describes the container types for
+/// which the operation is elementwise.
 ///
-/// Examples:
+/// Rationale:
+/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
+///   of 0 non-scalar operands or 0 non-scalar results, which complicate a
+///   generic definition of the iteration space.
+/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
+///   the same pattern, as otherwise lots of special handling for type
+///   mismatches would be needed.
+/// - 4. guarantees that no error handling is needed. Higher-level dialects
+///   should reify any needed guards or error handling code before lowering to
+///   an `Elementwise` op.
+template <typename ConcreteType>
+struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
+  static LogicalResult verifyTrait(Operation *op) {
+    return ::mlir::OpTrait::impl::verifyElementwise(op);
+  }
+};
+
+/// This trait tags `Elementwise` operations that can be systematically
+/// scalarized. All vector/tensor operands and results are then replaced by
+/// scalars of the respective element type. Semantically, this is the operation
+/// on a single element per vector/tensor.
+///
+/// Rationale:
+/// Allows defining the vector/tensor semantics of elementwise operations based
+/// on scalars. This provides a constructive procedure for IR transformations
+/// to, e.g., create scalar loop bodies from tensor ops.
+///
+/// Example:
 /// ```
-/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
-/// // Applying the systematic vectorization/tensorization property, this op
-/// // must also be valid:
-/// %tensor = "std.addf"(%a_tensor, %b_tensor)
-///           : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>)
+/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val)
+///                      : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+///                      -> tensor<?xf32>
+/// ```
+/// can be scalarized to
 ///
-/// // These properties generalize well to the cases of non-scalar operands.
-/// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val)
-///                       : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-/// // Applying the systematic vectorization / tensorization property, this
-/// // op must also be valid:
-/// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
-///                       : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
-///                       -> tensor<?xf32>
-/// // Applying the systematic scalarization property, this op must also
-/// // be valid.
-/// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
-///                  : (i1, f32, f32) -> f32
 /// ```
+/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
+///                      : (i1, f32, f32) -> f32
+/// ```
+template <typename ConcreteType>
+struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(
+        ConcreteType::template hasTrait<Elementwise>(),
+        "`Scalarizable` trait is only applicable to `Elementwise` ops.");
+    return success();
+  }
+};
+
+/// This trait tags `Elementwise` operations that can be systematically
+/// vectorized. All scalar operands and results are then replaced by vectors
+/// with the respective element type. Semantically, this is the operation on
+/// multiple arguments simultaneously.
 ///
-/// TODO: Avoid hardcoding vector/tensor, and generalize this to any type
-/// implementing a new "ElementwiseMappableTypeInterface" that describes types
-/// for which it makes sense to apply a scalar function to each element.
+/// Rationale:
+/// Provide the reverse of `Scalarizable` which, when chained together, allows
+/// reasoning about the relationship between the tensor and vector case.
+/// Additionally, it permits reasoning about promoting scalars to vectors via
+/// broadcasting, e.g. for a `std.select` with a scalar predicate.
+template <typename ConcreteType>
+struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(
+        ConcreteType::template hasTrait<Elementwise>(),
+        "`Vectorizable` trait is only applicable to `Elementwise` ops.");
+    return success();
+  }
+};
+
+/// This trait tags `Elementwise` operations that can be systematically
+/// tensorized. All scalar operands and results are then replaced by tensors
+/// with the respective element type. Semantically, this is the operation on
+/// multiple arguments simultaneously.
 ///
 /// Rationale:
-/// - 1. and 2. guarantee a well-defined iteration space for 6.
-///   - These also exclude the cases of 0 non-scalar operands or 0 non-scalar
-///     results, which complicate a generic definition of the iteration space.
-/// - 3. guarantees that folding can be done across scalars/vectors/tensors
-///   with the same pattern, as otherwise lots of special handling of type
-///   mismatches would be needed.
-/// - 4. guarantees that no error handling cases need to be considered.
-///   - Higher-level dialects should reify any needed guards / error handling
-///   code before lowering to an ElementwiseMappable op.
-/// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar
-///   semantics and provide a constructive procedure for IR transformations
-///   to e.g. create scalar loop bodies from tensor ops.
-/// - 7. provides the reverse of 5., which when chained together allows
-///   reasoning about the relationship between the tensor and vector case.
-///   Additionally, it permits reasoning about promoting scalars to
-///   vectors/tensors via broadcasting in cases like `%select_scalar_pred`
-///   above.
+/// Provide the reverse of `Scalarizable` which, when chained together, allows
+/// reasoning about the relationship between the tensor and vector case.
+/// Additionally, it permits reasoning about promoting scalars to tensors via
+/// broadcasting in cases like `%scalar_pred` below.
+///
+/// Examples:
+/// ```
+/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
+/// ```
+/// can be tensorized to
+/// ```
+/// %tensor = "std.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
+///               -> tensor<?xf32>
+/// ```
+///
+/// ```
+/// %scalar_pred = "std.select"(%pred, %true_val, %false_val)
+///                    : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+/// ```
+/// can be tensorized to
+/// ```
+/// %tensor_pred = "std.select"(%pred, %true_val, %false_val)
+///                    : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+///                    -> tensor<?xf32>
+/// ```
 template <typename ConcreteType>
-struct ElementwiseMappable
-    : public TraitBase<ConcreteType, ElementwiseMappable> {
+struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
   static LogicalResult verifyTrait(Operation *op) {
-    return ::mlir::OpTrait::impl::verifyElementwiseMappable(op);
+    static_assert(
+        ConcreteType::template hasTrait<Elementwise>(),
+        "`Tensorizable` trait is only applicable to `Elementwise` ops.");
+    return success();
   }
 };
 
+/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
+/// provide an easy way for scalar operations to conveniently generalize their
+/// behavior to vectors/tensors, and systematize conversion between these forms.
+bool hasElementwiseMappableTraits(Operation *op);
+
 } // end namespace OpTrait
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 3012e920a82f..1d50e067cff3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -18,7 +18,7 @@
 using namespace mlir;
 
 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
-  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+  if (!OpTrait::hasElementwiseMappableTraits(op))
     return false;
 
   // TODO: The conversion pattern can be made to work for `any_of` here, but

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index e350d8931753..f471ab0ebd75 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -205,7 +205,7 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
     return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};
 
   // 3. Only ElementwiseMappable are allowed in the generic vectorization.
-  if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+  if (!OpTrait::hasElementwiseMappableTraits(op))
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
 
   // 4. Generic vectorization path for ElementwiseMappable ops.
@@ -323,7 +323,7 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
     return false;
   for (Operation &op : r.front()) {
     if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
-          op.hasTrait<OpTrait::ElementwiseMappable>()) ||
+          OpTrait::hasElementwiseMappableTraits(&op)) ||
         llvm::any_of(op.getResultTypes(),
                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
       return false;

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index b1f2ad66efdb..f9ccda5a2544 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1085,7 +1085,7 @@ static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) {
   return a.getShape() == b.getShape();
 }
 
-LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) {
+LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
   auto isMappableType = [](Type type) {
     return type.isa<VectorType, TensorType>();
   };
@@ -1127,6 +1127,11 @@ LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) {
   return success();
 }
 
+bool OpTrait::hasElementwiseMappableTraits(Operation *op) {
+  return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() &&
+         op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>();
+}
+
 //===----------------------------------------------------------------------===//
 // BinaryOp implementation
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 4893ac3d8492..828c31d82e11 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -356,7 +356,8 @@ def SameOperandElementTypeOp : TEST_Op<"same_operand_element_type",
   let results = (outs AnyType);
 }
 
-def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_element_type",
+def SameOperandAndResultElementTypeOp :
+    TEST_Op<"same_operand_and_result_element_type",
     [SameOperandsAndResultElementType]> {
   let arguments = (ins Variadic<AnyType>);
   let results = (outs Variadic<AnyType>);
@@ -379,7 +380,7 @@ def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type",
 }
 
 def ElementwiseMappableOp : TEST_Op<"elementwise_mappable",
-    [ElementwiseMappable]> {
+    ElementwiseMappable.traits> {
   let arguments = (ins Variadic<AnyType>);
   let results = (outs Variadic<AnyType>);
 }


        


More information about the Mlir-commits mailing list