[Mlir-commits] [mlir] bcc9b37 - Split `ElementwiseMappable` trait into four more precise traits.
Frederik Gossen
llvmlistbot at llvm.org
Tue Mar 2 06:31:40 PST 2021
Author: Frederik Gossen
Date: 2021-03-02T15:31:19+01:00
New Revision: bcc9b371e43be8fa3fba65f2363eaf767731e0c7
URL: https://github.com/llvm/llvm-project/commit/bcc9b371e43be8fa3fba65f2363eaf767731e0c7
DIFF: https://github.com/llvm/llvm-project/commit/bcc9b371e43be8fa3fba65f2363eaf767731e0c7.diff
LOG: Split `ElementwiseMappable` trait into four more precise traits.
Some elementwise operations are not scalarizable, vectorizable, or tensorizable.
Split `ElementwiseMappable` trait into the following, more precise traits.
- `Elementwise`
- `Scalarizable`
- `Vectorizable`
- `Tensorizable`
This allows for reuse of `Elementwise` in dialects like HLO.
Differential Revision: https://reviews.llvm.org/D97674
Added:
Modified:
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/IR/Operation.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 960f5f64eec3..5e4648da488d 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -20,11 +20,9 @@ class Complex_Op<string mnemonic, list<OpTrait> traits = []>
// floating-point element type. These operations take two operands and return
// one result, all of which must be complex numbers of the same type.
class ComplexArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
- Complex_Op<mnemonic,
- !listconcat(traits, [NoSideEffect,
- SameOperandsAndResultType,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
- ElementwiseMappable])> {
+ Complex_Op<mnemonic, traits # [NoSideEffect, SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+ ElementwiseMappable.traits> {
let arguments = (ins Complex<AnyFloat>:$lhs, Complex<AnyFloat>:$rhs);
let results = (outs Complex<AnyFloat>:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 0ac5c0e05706..7ce4dc3a7aee 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -17,10 +17,9 @@ class MathOp<string mnemonic, list<OpTrait> traits = []>
: Op<Math_Dialect, mnemonic, traits # [NoSideEffect]>;
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
- MathOp<mnemonic,
- traits # [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
- ElementwiseMappable,
- SameOperandsAndResultType]> {
+ MathOp<mnemonic, traits #
+ [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
+ SameOperandsAndResultType] # ElementwiseMappable.traits> {
let arguments = (ins FloatLike:$operand);
let results = (outs FloatLike:$result);
@@ -29,10 +28,9 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
}
class FloatBinaryOp<string mnemonic, list<OpTrait> traits = []> :
- MathOp<mnemonic,
- traits # [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
- ElementwiseMappable,
- SameOperandsAndResultType]> {
+ MathOp<mnemonic, traits # [
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
+ SameOperandsAndResultType] # ElementwiseMappable.traits> {
let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
let results = (outs FloatLike:$result);
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index fe054c59ae6e..24ea8ea135d6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -71,9 +71,9 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
// Base class for arithmetic cast operations.
class ArithmeticCastOp<string mnemonic, list<OpTrait> traits = []> :
- CastOp<mnemonic,
- !listconcat(traits, [ElementwiseMappable,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface>])> {
+ CastOp<mnemonic, traits # [
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+ ElementwiseMappable.traits> {
}
// Base class for unary ops. Requires single operand and result. Individual
@@ -95,21 +95,18 @@ class UnaryOpSameOperandAndResultType<string mnemonic,
}
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
- UnaryOpSameOperandAndResultType<mnemonic,
- !listconcat(traits,
- [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
- ElementwiseMappable])>,
- Arguments<(ins FloatLike:$operand)>;
+ UnaryOpSameOperandAndResultType<mnemonic, traits # [
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+ ElementwiseMappable.traits>, Arguments<(ins FloatLike:$operand)>;
// Base class for standard arithmetic operations. Requires operands and
// results to be of the same type, but does not constrain them to specific
// types.
class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
- Op<StandardOps_Dialect, mnemonic,
- !listconcat(traits, [NoSideEffect,
- SameOperandsAndResultType,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
- ElementwiseMappable])> {
+ Op<StandardOps_Dialect, mnemonic, traits # [NoSideEffect,
+ SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+ ElementwiseMappable.traits> {
let results = (outs AnyType:$result);
@@ -930,12 +927,10 @@ def CmpFPredicateAttr : I64EnumAttr<
let cppNamespace = "::mlir";
}
-def CmpFOp : Std_Op<"cmpf",
- [NoSideEffect, SameTypeOperands, ElementwiseMappable,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
- TypesMatchWith<
- "result type has i1 element type and same shape as operands",
- "lhs", "result", "getI1SameShape($_self)">]> {
+def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
+ "result type has i1 element type and same shape as operands",
+ "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> {
let summary = "floating-point comparison operation";
let description = [{
The `cmpf` operation compares its two operands according to the float
@@ -1015,12 +1010,10 @@ def CmpIPredicateAttr : I64EnumAttr<
let cppNamespace = "::mlir";
}
-def CmpIOp : Std_Op<"cmpi",
- [NoSideEffect, SameTypeOperands, ElementwiseMappable,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
- TypesMatchWith<
- "result type has i1 element type and same shape as operands",
- "lhs", "result", "getI1SameShape($_self)">]> {
+def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>, TypesMatchWith<
+ "result type has i1 element type and same shape as operands",
+ "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> {
let summary = "integer comparison operation";
let description = [{
The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -2160,8 +2153,9 @@ def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
//===----------------------------------------------------------------------===//
def SelectOp : Std_Op<"select", [NoSideEffect,
- AllTypesMatch<["true_value", "false_value", "result"]>,
- ElementwiseMappable, DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
+ AllTypesMatch<["true_value", "false_value", "result"]>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+ ElementwiseMappable.traits> {
let summary = "select operation";
let description = [{
The `select` operation chooses one value based on a binary condition
@@ -2391,9 +2385,9 @@ def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> {
// SignExtendIOp
//===----------------------------------------------------------------------===//
-def SignExtendIOp : Std_Op<"sexti",
- [NoSideEffect, ElementwiseMappable,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface>]> {
+def SignExtendIOp : Std_Op<"sexti", [NoSideEffect,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+ ElementwiseMappable.traits> {
let summary = "integer sign extension operation";
let description = [{
The integer sign extension operation takes an integer input of
@@ -3220,9 +3214,9 @@ def TransposeOp : Std_Op<"transpose", [NoSideEffect]>,
// TruncateIOp
//===----------------------------------------------------------------------===//
-def TruncateIOp : Std_Op<"trunci",
- [NoSideEffect, ElementwiseMappable,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
+def TruncateIOp : Std_Op<"trunci", [NoSideEffect,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+ ElementwiseMappable.traits> {
let summary = "integer truncation operation";
let description = [{
The integer truncation operation takes an integer input of
@@ -3463,9 +3457,9 @@ def XOrOp : IntBinaryOp<"xor", [Commutative]> {
// ZeroExtendIOp
//===----------------------------------------------------------------------===//
-def ZeroExtendIOp : Std_Op<"zexti",
- [NoSideEffect, ElementwiseMappable,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface>,]> {
+def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
+ ElementwiseMappable.traits> {
let summary = "integer zero extension operation";
let description = [{
The integer zero extension operation takes an integer input of
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index c0089f7a2a51..f0f5f1d0e4ea 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1785,9 +1785,25 @@ def Terminator : NativeOpTrait<"IsTerminator">;
// Op can be safely normalized in the presence of MemRefs with
// non-identity maps.
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
-// Op can be systematically interconverted between scalar and vector/tensor
-// form by mapping elementwise based on the type.
-def ElementwiseMappable : NativeOpTrait<"ElementwiseMappable">;
+// Op is elementwise on tensor/vector operands and results.
+def Elementwise : NativeOpTrait<"Elementwise">;
+// Elementwise op can be applied to scalars instead of tensor/vector operands.
+def Scalarizable : NativeOpTrait<"Scalarizable">;
+// Elementwise op can be applied to all-vector operands.
+def Vectorizable : NativeOpTrait<"Vectorizable">;
+// Elementwise op can be applied to all-tensor operands.
+def Tensorizable : NativeOpTrait<"Tensorizable">;
+
+// Group together `Elementwise`, `Scalarizable`, `Vectorizable`, and
+// `Tensorizable` for convenience. This is a plain record, not a trait itself:
+// splice its list into an op's traits, e.g.
+// `[NoSideEffect] # ElementwiseMappable.traits`.
+def ElementwiseMappable {
+ list<OpTrait> traits = [
+ Elementwise,
+ Scalarizable,
+ Vectorizable,
+ Tensorizable,
+ ];
+}
// Op's regions have a single block with the specified terminator.
class SingleBlockImplicitTerminator<string op>
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index bfddcea950f2..c65e653f2a50 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -282,7 +282,7 @@ LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyNoRegionArguments(Operation *op);
-LogicalResult verifyElementwiseMappable(Operation *op);
+LogicalResult verifyElementwise(Operation *op);
} // namespace impl
/// Helper class for implementing traits. Clients are not expected to interact
@@ -1213,93 +1213,144 @@ template <typename ConcrentType>
struct MemRefsNormalizable
: public TraitBase<ConcrentType, MemRefsNormalizable> {};
-/// This trait tags scalar ops that also can be applied to vectors/tensors, with
-/// their semantics on vectors/tensors being elementwise application.
+/// This trait tags element-wise ops that operate on scalars, vectors, or
+/// tensors.
///
/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
-/// trait. In particular, broadcasting behavior is not allowed. This trait
-/// describes a set of invariants that allow systematic
-/// vectorization/tensorization, and the reverse, scalarization. The properties
-/// needed for this also can be used to implement a number of
-/// transformations/analyses/interfaces.
+/// trait. In particular, broadcasting behavior is not allowed.
///
-/// An `ElementwiseMappable` op must satisfy the following properties:
+/// An `Elementwise` op must satisfy the following properties:
///
-/// 1. If any result is a vector (resp. tensor), then at least one operand must
-/// be a vector (resp. tensor).
-/// 2. If any operand is a vector (resp. tensor), then there must be at least
-/// one result, and all results must be vectors (resp. tensors).
-/// 3. The static types of all vector (resp. tensor) operands and results must
-/// have the same shape.
-/// 4. In the case of tensor operands, the dynamic shapes of all tensor operands
-/// must be the same, otherwise the op has undefined behavior.
-/// 5. ("systematic scalarization" property) If an op has vector/tensor
-/// operands/results, then the same op, with the operand/result types changed to
-/// their corresponding element type, shall be a verifier-valid op.
-/// 6. The semantics of the op on vectors (resp. tensors) shall be the same as
-/// applying the scalarized version of the op for each corresponding element of
-/// the vector (resp. tensor) operands in parallel.
-/// 7. ("systematic vectorization/tensorization" property) If an op has
-/// scalar operands/results, the op shall remain verifier-valid if all scalar
-/// operands are replaced with vectors/tensors of the same shape and
-/// corresponding element types.
+/// 1. If any result is a vector/tensor then at least one operand must also be a
+/// vector/tensor.
+/// 2. If any operand is a vector/tensor then there must be at least one result
+/// and all results must be vectors/tensors.
+/// 3. All operand and result vector/tensor types must be of the same shape. The
+/// shape may be dynamic in which case the op's behavior is undefined for
+/// non-matching shapes.
+/// 4. The operation must be elementwise on its vector/tensor operands and
+/// results. When applied to single-element vectors/tensors, the result must
+/// be the same per element.
///
-/// Together, these properties provide an easy way for scalar operations to
-/// conveniently generalize their behavior to vectors/tensors, and systematize
-/// conversion between these forms.
+/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
+/// interface `ElementwiseTypeInterface` that describes the container types for
+/// which the operation is elementwise.
///
-/// Examples:
+/// Rationale:
+/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
+/// of 0 non-scalar operands or 0 non-scalar results, which complicate a
+/// generic definition of the iteration space.
+/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
+/// the same pattern, as otherwise lots of special handling for type
+/// mismatches would be needed.
+/// - The undefined behavior for non-matching dynamic shapes in 3. guarantees
+/// that no error handling is needed. Higher-level dialects should reify any
+/// needed guards or error handling code before lowering to an `Elementwise`
+/// op.
+template <typename ConcreteType>
+struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
+ static LogicalResult verifyTrait(Operation *op) {
+ return ::mlir::OpTrait::impl::verifyElementwise(op);
+ }
+};
+
+/// This trait tags `Elementwise` operations that can be systematically
+/// scalarized. All vector/tensor operands and results are then replaced by
+/// scalars of the respective element type. Semantically, this is the operation
+/// on a single element per vector/tensor.
+///
+/// Rationale:
+/// Allows defining the vector/tensor semantics of elementwise operations based
+/// on scalars. This provides a constructive procedure for IR transformations
+/// to, e.g., create scalar loop bodies from tensor ops.
+///
+/// Example:
/// ```
-/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
-/// // Applying the systematic vectorization/tensorization property, this op
-/// // must also be valid:
-/// %tensor = "std.addf"(%a_tensor, %b_tensor)
-/// : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>)
+/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val)
+/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+/// -> tensor<?xf32>
+/// ```
+/// can be scalarized to
///
-/// // These properties generalize well to the cases of non-scalar operands.
-/// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val)
-/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
-/// // Applying the systematic vectorization / tensorization property, this
-/// // op must also be valid:
-/// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
-/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
-/// -> tensor<?xf32>
-/// // Applying the systematic scalarization property, this op must also
-/// // be valid.
-/// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
-/// : (i1, f32, f32) -> f32
/// ```
+/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
+/// : (i1, f32, f32) -> f32
+/// ```
+template <typename ConcreteType>
+struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
+ static LogicalResult verifyTrait(Operation *op) {
+ // No runtime checks: only require, at compile time, that the op also has
+ // the `Elementwise` trait.
+ static_assert(
+ ConcreteType::template hasTrait<Elementwise>(),
+ "`Scalarizable` trait is only applicable to `Elementwise` ops.");
+ return success();
+ }
+};
+
+/// This trait tags `Elementwise` operations that can be systematically
+/// vectorized. All scalar operands and results are then replaced by vectors
+/// with the respective element type. Semantically, this is the operation on
+/// multiple arguments simultaneously.
///
-/// TODO: Avoid hardcoding vector/tensor, and generalize this to any type
-/// implementing a new "ElementwiseMappableTypeInterface" that describes types
-/// for which it makes sense to apply a scalar function to each element.
+/// Rationale:
+/// Provides the reverse of `Scalarizable` which, when chained together, allows
+/// reasoning about the relationship between the tensor and vector case.
+/// Additionally, it permits reasoning about promoting scalars to vectors via
+/// broadcasting, e.g. a `std.select` with a scalar `i1` predicate and vector
+/// operands.
+template <typename ConcreteType>
+struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
+ static LogicalResult verifyTrait(Operation *op) {
+ // No runtime checks: only require, at compile time, that the op also has
+ // the `Elementwise` trait.
+ static_assert(
+ ConcreteType::template hasTrait<Elementwise>(),
+ "`Vectorizable` trait is only applicable to `Elementwise` ops.");
+ return success();
+ }
+};
+
+/// This trait tags `Elementwise` operations that can be systematically
+/// tensorized. All scalar operands and results are then replaced by tensors
+/// with the respective element type. Semantically, this is the operation on
+/// multiple arguments simultaneously.
///
/// Rationale:
-/// - 1. and 2. guarantee a well-defined iteration space for 6.
-/// - These also exclude the cases of 0 non-scalar operands or 0 non-scalar
-/// results, which complicate a generic definition of the iteration space.
-/// - 3. guarantees that folding can be done across scalars/vectors/tensors
-/// with the same pattern, as otherwise lots of special handling of type
-/// mismatches would be needed.
-/// - 4. guarantees that no error handling cases need to be considered.
-/// - Higher-level dialects should reify any needed guards / error handling
-/// code before lowering to an ElementwiseMappable op.
-/// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar
-/// semantics and provide a constructive procedure for IR transformations
-/// to e.g. create scalar loop bodies from tensor ops.
-/// - 7. provides the reverse of 5., which when chained together allows
-/// reasoning about the relationship between the tensor and vector case.
-/// Additionally, it permits reasoning about promoting scalars to
-/// vectors/tensors via broadcasting in cases like `%select_scalar_pred`
-/// above.
+/// Provides the reverse of `Scalarizable` which, when chained together, allows
+/// reasoning about the relationship between the tensor and vector case.
+/// Additionally, it permits reasoning about promoting scalars to tensors via
+/// broadcasting, as in the `%scalar_pred` example below.
+///
+/// Examples:
+/// ```
+/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
+/// ```
+/// can be tensorized to
+/// ```
+/// %tensor = "std.addf"(%a_tensor, %b_tensor) : (tensor<?xf32>, tensor<?xf32>)
+/// -> tensor<?xf32>
+/// ```
+///
+/// ```
+/// %scalar_pred = "std.select"(%pred, %true_val, %false_val)
+/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+/// ```
+/// can be tensorized to
+/// ```
+/// %tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
+/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+/// -> tensor<?xf32>
+/// ```
template <typename ConcreteType>
-struct ElementwiseMappable
- : public TraitBase<ConcreteType, ElementwiseMappable> {
+struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
static LogicalResult verifyTrait(Operation *op) {
- return ::mlir::OpTrait::impl::verifyElementwiseMappable(op);
+ // No runtime checks: only require, at compile time, that the op also has
+ // the `Elementwise` trait.
+ static_assert(
+ ConcreteType::template hasTrait<Elementwise>(),
+ "`Tensorizable` trait is only applicable to `Elementwise` ops.");
+ return success();
}
};
+/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
+/// provide an easy way for scalar operations to conveniently generalize their
+/// behavior to vectors/tensors, and systematize conversion between these forms.
+///
+/// Returns true iff `op` has all four of these traits. Use this where a check
+/// for the former single `ElementwiseMappable` trait was previously made.
+bool hasElementwiseMappableTraits(Operation *op);
+
} // end namespace OpTrait
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 3012e920a82f..1d50e067cff3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -18,7 +18,7 @@
using namespace mlir;
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
- if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+ if (!OpTrait::hasElementwiseMappableTraits(op))
return false;
// TODO: The conversion pattern can be made to work for `any_of` here, but
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index e350d8931753..f471ab0ebd75 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -205,7 +205,7 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)};
// 3. Only ElementwiseMappable are allowed in the generic vectorization.
- if (!op->hasTrait<OpTrait::ElementwiseMappable>())
+ if (!OpTrait::hasElementwiseMappableTraits(op))
return VectorizationResult{VectorizationStatus::Failure, nullptr};
// 4. Generic vectorization path for ElementwiseMappable ops.
@@ -323,7 +323,7 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
return false;
for (Operation &op : r.front()) {
if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
- op.hasTrait<OpTrait::ElementwiseMappable>()) ||
+ OpTrait::hasElementwiseMappableTraits(&op)) ||
llvm::any_of(op.getResultTypes(),
[](Type type) { return !type.isIntOrIndexOrFloat(); }))
return false;
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index b1f2ad66efdb..f9ccda5a2544 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1085,7 +1085,7 @@ static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) {
return a.getShape() == b.getShape();
}
-LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) {
+LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
auto isMappableType = [](Type type) {
return type.isa<VectorType, TensorType>();
};
@@ -1127,6 +1127,11 @@ LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) {
return success();
}
+/// Checks each trait separately: `ElementwiseMappable` is only a TableGen
+/// grouping of the four traits, not a C++ trait class `hasTrait` can query.
+bool OpTrait::hasElementwiseMappableTraits(Operation *op) {
+ return op->hasTrait<Elementwise>() && op->hasTrait<Scalarizable>() &&
+ op->hasTrait<Vectorizable>() && op->hasTrait<Tensorizable>();
+}
+
//===----------------------------------------------------------------------===//
// BinaryOp implementation
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 4893ac3d8492..828c31d82e11 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -356,7 +356,8 @@ def SameOperandElementTypeOp : TEST_Op<"same_operand_element_type",
let results = (outs AnyType);
}
-def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_element_type",
+def SameOperandAndResultElementTypeOp :
+ TEST_Op<"same_operand_and_result_element_type",
[SameOperandsAndResultElementType]> {
let arguments = (ins Variadic<AnyType>);
let results = (outs Variadic<AnyType>);
@@ -379,7 +380,7 @@ def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type",
}
def ElementwiseMappableOp : TEST_Op<"elementwise_mappable",
- [ElementwiseMappable]> {
+ ElementwiseMappable.traits> {
let arguments = (ins Variadic<AnyType>);
let results = (outs Variadic<AnyType>);
}
More information about the Mlir-commits
mailing list