[Mlir-commits] [mlir] b4fa28b - [mlir] Add ElementwiseMappable trait and apply it to std elementwise ops.
Sean Silva
llvmlistbot at llvm.org
Tue Nov 10 13:46:32 PST 2020
Author: Sean Silva
Date: 2020-11-10T13:44:44-08:00
New Revision: b4fa28b408662a638f7d835b302db08633ccb91e
URL: https://github.com/llvm/llvm-project/commit/b4fa28b408662a638f7d835b302db08633ccb91e
DIFF: https://github.com/llvm/llvm-project/commit/b4fa28b408662a638f7d835b302db08633ccb91e.diff
LOG: [mlir] Add ElementwiseMappable trait and apply it to std elementwise ops.
This patch adds an `ElementwiseMappable` trait as discussed in the RFC
here:
https://llvm.discourse.group/t/rfc-std-elementwise-ops-on-tensors/2113/23
This trait can power a number of transformations and analyses.
A subsequent patch adds a convert-elementwise-to-linalg pass that exhibits
how this trait allows writing generic transformations.
See https://reviews.llvm.org/D90354 for that patch.
This trait slightly changes some verifier messages, but the diagnostics
are usually about as good. I fiddled with the ordering of the trait in
the .td file trait lists to minimize the changes here.
Differential Revision: https://reviews.llvm.org/D90731
Added:
Modified:
mlir/docs/Traits.md
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/Operation.cpp
mlir/test/Dialect/Standard/invalid.mlir
mlir/test/IR/invalid-ops.mlir
mlir/test/IR/traits.mlir
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
index c6db640f7f5f..0b85929aab93 100644
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -239,6 +239,20 @@ This trait requires that the operands are either vector or tensor types.
This trait adds the property that the operation is commutative, i.e. `X op Y ==
Y op X`
+### ElementwiseMappable
+
+* `OpTrait::ElementwiseMappable` -- `ElementwiseMappable`
+
+This trait tags scalar ops that also can be applied to vectors/tensors, with
+their semantics on vectors/tensors being elementwise application. This trait
+establishes a set of properties that allow reasoning about / converting between
+scalar/vector/tensor code. These same properties allow blanket implementations
+of various analyses/transformations for all `ElementwiseMappable` ops.
+
+Note: Not all ops that are "elementwise" in some abstract sense satisfy this
+trait. In particular, broadcasting behavior is not allowed. See the comments on
+`OpTrait::ElementwiseMappable` for the precise requirements.
+
### Function-Like
* `OpTrait::FunctionLike`
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 6037b3b32683..d15f06b37fa5 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -67,6 +67,11 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
let hasFolder = 1;
}
+// Base class for arithmetic cast operations.
+class ArithmeticCastOp<string mnemonic, list<OpTrait> traits = []> :
+ CastOp<mnemonic, !listconcat(traits, [ElementwiseMappable])> {
+}
+
// Base class for unary ops. Requires single operand and result. Individual
// classes will have `operand` accessor.
class UnaryOp<string mnemonic, list<OpTrait> traits = []> :
@@ -88,7 +93,8 @@ class UnaryOpSameOperandAndResultType<string mnemonic,
class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
UnaryOpSameOperandAndResultType<mnemonic,
!listconcat(traits,
- [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
+ [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
+ ElementwiseMappable])>,
Arguments<(ins FloatLike:$operand)>;
// Base class for standard arithmetic operations. Requires operands and
@@ -96,7 +102,9 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
Op<StandardOps_Dialect, mnemonic,
- !listconcat(traits, [NoSideEffect, SameOperandsAndResultType])> {
+ !listconcat(traits, [NoSideEffect,
+ SameOperandsAndResultType,
+ ElementwiseMappable])> {
let results = (outs AnyType);
@@ -1152,10 +1160,10 @@ def CmpFPredicateAttr : I64EnumAttr<
}
def CmpFOp : Std_Op<"cmpf",
- [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
- TypesMatchWith<
+ [NoSideEffect, SameTypeOperands,
+ SameOperandsAndResultShape, TypesMatchWith<
"result type has i1 element type and same shape as operands",
- "lhs", "result", "getI1SameShape($_self)">]> {
+ "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> {
let summary = "floating-point comparison operation";
let description = [{
The `cmpf` operation compares its two operands according to the float
@@ -1236,10 +1244,10 @@ def CmpIPredicateAttr : I64EnumAttr<
}
def CmpIOp : Std_Op<"cmpi",
- [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
- TypesMatchWith<
+ [NoSideEffect, SameTypeOperands,
+ SameOperandsAndResultShape, TypesMatchWith<
"result type has i1 element type and same shape as operands",
- "lhs", "result", "getI1SameShape($_self)">]> {
+ "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> {
let summary = "integer comparison operation";
let description = [{
The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -1926,7 +1934,7 @@ def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
// FPExtOp
//===----------------------------------------------------------------------===//
-def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> {
+def FPExtOp : ArithmeticCastOp<"fpext">, Arguments<(ins AnyType:$in)> {
let summary = "cast from floating-point to wider floating-point";
let description = [{
Cast a floating-point value to a larger floating-point-typed value.
@@ -1947,7 +1955,7 @@ def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> {
// FPToSIOp
//===----------------------------------------------------------------------===//
-def FPToSIOp : CastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
+def FPToSIOp : ArithmeticCastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
let summary = "cast from floating-point type to integer type";
let description = [{
Cast from a value interpreted as floating-point to the nearest (rounding
@@ -1967,7 +1975,7 @@ def FPToSIOp : CastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
// FPToUIOp
//===----------------------------------------------------------------------===//
-def FPToUIOp : CastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
+def FPToUIOp : ArithmeticCastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
let summary = "cast from floating-point type to integer type";
let description = [{
Cast from a value interpreted as floating-point to the nearest (rounding
@@ -1987,7 +1995,7 @@ def FPToUIOp : CastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
// FPTruncOp
//===----------------------------------------------------------------------===//
-def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
+def FPTruncOp : ArithmeticCastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
let summary = "cast from floating-point to narrower floating-point";
let description = [{
Truncate a floating-point value to a smaller floating-point-typed value.
@@ -2131,7 +2139,7 @@ def ImOp : Std_Op<"im",
// IndexCastOp
//===----------------------------------------------------------------------===//
-def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
+def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
let summary = "cast between index and integer types";
let description = [{
Casts between integer scalars and 'index' scalars. Index is an integer of
@@ -2714,7 +2722,8 @@ def RsqrtOp : FloatUnaryOp<"rsqrt"> {
//===----------------------------------------------------------------------===//
def SelectOp : Std_Op<"select", [NoSideEffect,
- AllTypesMatch<["true_value", "false_value", "result"]>]> {
+ AllTypesMatch<["true_value", "false_value", "result"]>,
+ ElementwiseMappable]> {
let summary = "select operation";
let description = [{
The `select` operation chooses one value based on a binary condition
@@ -2945,7 +2954,7 @@ def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
//===----------------------------------------------------------------------===//
def SignExtendIOp : Std_Op<"sexti",
- [NoSideEffect, SameOperandsAndResultShape]> {
+ [NoSideEffect, SameOperandsAndResultShape, ElementwiseMappable]> {
let summary = "integer sign extension operation";
let description = [{
The integer sign extension operation takes an integer input of
@@ -2987,7 +2996,7 @@ def SignExtendIOp : Std_Op<"sexti",
// SIToFPOp
//===----------------------------------------------------------------------===//
-def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
+def SIToFPOp : ArithmeticCastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
let summary = "cast from integer type to floating-point";
let description = [{
Cast from a value interpreted as signed or vector of signed integers to the
@@ -3786,7 +3795,9 @@ def TransposeOp : Std_Op<"transpose", [NoSideEffect]>,
// TruncateIOp
//===----------------------------------------------------------------------===//
-def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
+def TruncateIOp : Std_Op<"trunci", [NoSideEffect,
+ SameOperandsAndResultShape,
+ ElementwiseMappable]> {
let summary = "integer truncation operation";
let description = [{
The integer truncation operation takes an integer input of
@@ -3826,7 +3837,7 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
// UIToFPOp
//===----------------------------------------------------------------------===//
-def UIToFPOp : CastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
+def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
let summary = "cast from unsigned integer type to floating-point";
let description = [{
Cast from a value interpreted as unsigned integer or vector of unsigned
@@ -4053,7 +4064,9 @@ def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
// ZeroExtendIOp
//===----------------------------------------------------------------------===//
-def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]> {
+def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect,
+ SameOperandsAndResultShape,
+ ElementwiseMappable]> {
let summary = "integer zero extension operation";
let description = [{
The integer zero extension operation takes an integer input of
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 814b8573d8b3..b031769022d9 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1750,6 +1750,9 @@ def Terminator : NativeOpTrait<"IsTerminator">;
// Op can be safely normalized in the presence of MemRefs with
// non-identity maps.
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
+// Op can be systematically interconverted between scalar and vector/tensor
+// form by mapping elementwise based on the type.
+def ElementwiseMappable : NativeOpTrait<"ElementwiseMappable">;
// Op's regions have a single block with the specified terminator.
class SingleBlockImplicitTerminator<string op>
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index f1be18758756..d673b2b8161a 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -321,6 +321,7 @@ LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyNoRegionArguments(Operation *op);
+LogicalResult verifyElementwiseMappable(Operation *op);
} // namespace impl
/// Helper class for implementing traits. Clients are not expected to interact
@@ -1182,6 +1183,93 @@ template <typename ConcrentType>
struct MemRefsNormalizable
: public TraitBase<ConcrentType, MemRefsNormalizable> {};
+/// This trait tags scalar ops that also can be applied to vectors/tensors, with
+/// their semantics on vectors/tensors being elementwise application.
+///
+/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
+/// trait. In particular, broadcasting behavior is not allowed. This trait
+/// describes a set of invariants that allow systematic
+/// vectorization/tensorization, and the reverse, scalarization. The properties
+/// needed for this also can be used to implement a number of
+/// transformations/analyses/interfaces.
+///
+/// An `ElementwiseMappable` op must satisfy the following properties:
+///
+/// 1. If any result is a vector (resp. tensor), then at least one operand must
+/// be a vector (resp. tensor).
+/// 2. If any operand is a vector (resp. tensor), then there must be at least
+/// one result, and all results must be vectors (resp. tensors).
+/// 3. The static types of all vector (resp. tensor) operands and results must
+/// have the same shape.
+/// 4. In the case of tensor operands, the dynamic shapes of all tensor operands
+/// must be the same, otherwise the op has undefined behavior.
+/// 5. ("systematic scalarization" property) If an op has vector/tensor
+/// operands/results, then the same op, with the operand/result types changed to
+/// their corresponding element type, shall be a verifier-valid op.
+/// 6. The semantics of the op on vectors (resp. tensors) shall be the same as
+/// applying the scalarized version of the op for each corresponding element of
+/// the vector (resp. tensor) operands in parallel.
+/// 7. ("systematic vectorization/tensorization" property) If an op has
+/// scalar operands/results, the op shall remain verifier-valid if all scalar
+/// operands are replaced with vectors/tensors of the same shape and
+/// corresponding element types.
+///
+/// Together, these properties provide an easy way for scalar operations to
+/// conveniently generalize their behavior to vectors/tensors, and systematize
+/// conversion between these forms.
+///
+/// Examples:
+/// ```
+/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
+/// // Applying the systematic vectorization/tensorization property, this op
+/// // must also be valid:
+/// %tensor = "std.addf"(%a_tensor, %b_tensor)
+/// : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+///
+/// // These properties generalize well to the cases of non-scalar operands.
+/// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val)
+/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+/// // Applying the systematic vectorization / tensorization property, this
+/// // op must also be valid:
+/// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
+/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+/// -> tensor<?xf32>
+/// // Applying the systematic scalarization property, this op must also
+/// // be valid.
+/// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
+/// : (i1, f32, f32) -> f32
+/// ```
+///
+/// TODO: Avoid hardcoding vector/tensor, and generalize this to any type
+/// implementing a new "ElementwiseMappableTypeInterface" that describes types
+/// for which it makes sense to apply a scalar function to each element.
+///
+/// Rationale:
+/// - 1. and 2. guarantee a well-defined iteration space for 6.
+/// - These also exclude the cases of 0 non-scalar operands or 0 non-scalar
+/// results, which complicate a generic definition of the iteration space.
+/// - 3. guarantees that folding can be done across scalars/vectors/tensors
+/// with the same pattern, as otherwise lots of special handling of type
+/// mismatches would be needed.
+/// - 4. guarantees that no error handling cases need to be considered.
+/// - Higher-level dialects should reify any needed guards / error handling
+/// code before lowering to an ElementwiseMappable op.
+/// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar
+/// semantics and provide a constructive procedure for IR transformations
+/// to e.g. create scalar loop bodies from tensor ops.
+/// - 7. provides the reverse of 5., which when chained together allows
+/// reasoning about the relationship between the tensor and vector case.
+/// Additionally, it permits reasoning about promoting scalars to
+/// vectors/tensors via broadcasting in cases like `%select_scalar_pred`
+/// above.
+template <typename ConcreteType>
+struct ElementwiseMappable
+ : public TraitBase<ConcreteType, ElementwiseMappable> {
+ static LogicalResult verifyTrait(Operation *op) {
+ return ::mlir::OpTrait::impl::verifyElementwiseMappable(op);
+ }
+};
+
} // end namespace OpTrait
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index c85efe83afe8..5fed2ec8e753 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1068,6 +1068,57 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
return success();
}
+/// Checks if two ShapedTypes are the same, ignoring the element type.
+static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) {
+ if (a.getTypeID() != b.getTypeID())
+ return false;
+ if (!a.hasRank())
+ return !b.hasRank();
+ return a.getShape() == b.getShape();
+}
+
+LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) {
+ auto isMappableType = [](Type type) {
+ return type.isa<VectorType, TensorType>();
+ };
+ auto resultMappableTypes = llvm::to_vector<1>(
+ llvm::make_filter_range(op->getResultTypes(), isMappableType));
+ auto operandMappableTypes = llvm::to_vector<2>(
+ llvm::make_filter_range(op->getOperandTypes(), isMappableType));
+
+ // If the op only has scalar operand/result types, then we have nothing to
+ // check.
+ if (resultMappableTypes.empty() && operandMappableTypes.empty())
+ return success();
+
+ if (!resultMappableTypes.empty() && operandMappableTypes.empty())
+ return op->emitOpError("if a result is non-scalar, then at least one "
+ "operand must be non-scalar");
+
+ assert(!operandMappableTypes.empty());
+
+ if (resultMappableTypes.empty())
+ return op->emitOpError("if an operand is non-scalar, then there must be at "
+ "least one non-scalar result");
+
+ if (resultMappableTypes.size() != op->getNumResults())
+ return op->emitOpError(
+ "if an operand is non-scalar, then all results must be non-scalar");
+
+ auto mustMatchType = operandMappableTypes[0].cast<ShapedType>();
+ for (auto type :
+ llvm::concat<Type>(resultMappableTypes, operandMappableTypes)) {
+ if (!areSameShapedTypeIgnoringElementType(type.cast<ShapedType>(),
+ mustMatchType)) {
+ return op->emitOpError() << "all non-scalar operands/results must have "
+ "the same shape and base type: found "
+ << type << " and " << mustMatchType;
+ }
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// BinaryOp implementation
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 3c4be54a1d43..04ecdc64351a 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
- // expected-error @+1 {{requires the same shape for all operands and results}}
+ // expected-error @+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xi64>' and 'tensor<index>'}}
%0 = index_cast %arg0 : tensor<index> to tensor<2xi64>
return %0 : tensor<2xi64>
}
@@ -9,7 +9,7 @@ func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
// -----
func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
- // expected-error @+1 {{requires the same shape for all operands and results}}
+ // expected-error @+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}}
%0 = index_cast %arg0 : tensor<index> to i64
return %0 : i64
}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index b59353aa2f7c..76aff5c6d401 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -267,7 +267,7 @@ func @func_with_ops(i1, i32, i64) {
func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
- // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'vector<42xi1>', but got 'vector<12xi1>'}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<42xi32>' and 'vector<12xi1>'}}
%r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
}
@@ -275,7 +275,7 @@ func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
- // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'tensor<42xi1>', but got 'tensor<12xi1>'}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<42xi32>' and 'tensor<12xi1>'}}
%r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
}
@@ -685,7 +685,7 @@ func @fpext_f32_to_i32(%arg0 : f32) {
// -----
func @fpext_vec(%arg0 : vector<2xf16>) {
- // expected-error@+1 {{requires the same shape for all operands and results}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}}
%0 = fpext %arg0 : vector<2xf16> to vector<3xf32>
return
}
@@ -757,7 +757,7 @@ func @fptrunc_f32_to_i32(%arg0 : f32) {
// -----
func @fptrunc_vec(%arg0 : vector<2xf16>) {
- // expected-error@+1 {{requires the same shape for all operands and results}}
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}}
%0 = fptrunc %arg0 : vector<2xf16> to vector<3xf32>
return
}
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 206881a1c767..4dd6ef7dcf37 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -166,6 +166,84 @@ func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<10xf3
// -----
+func @failedElementwiseMappable_different_rankedness(%arg0: tensor<?xf32>, %arg1: tensor<*xf32>) {
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<*xf32>' and 'tensor<?xf32>'}}
+ %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<*xf32>) -> tensor<*xf32>
+}
+
+// -----
+
+func @failedElementwiseMappable_different_rank(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) {
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<?x?xf32>' and 'tensor<?xf32>'}}
+ %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+}
+
+// -----
+
+func @failedElementwiseMappable_different_shape(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<5xf32>' and 'tensor<?xf32>'}}
+ %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<5xf32>) -> tensor<?xf32>
+}
+
+// -----
+
+func @failedElementwiseMappable_different_base_type(%arg0: vector<2xf32>, %arg1: tensor<2xf32>) {
+ // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xf32>' and 'vector<2xf32>'}}
+ %0 = "test.elementwise_mappable"(%arg0, %arg1) : (vector<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+ return
+}
+
+// -----
+
+func @failedElementwiseMappable_non_scalar_output(%arg0: vector<2xf32>) {
+ // expected-error@+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}}
+ %0 = "test.elementwise_mappable"(%arg0) : (vector<2xf32>) -> f32
+ return
+}
+
+// -----
+
+func @failedElementwiseMappable_non_scalar_result_all_scalar_input(%arg0: f32) {
+ // expected-error@+1 {{if a result is non-scalar, then at least one operand must be non-scalar}}
+ %0 = "test.elementwise_mappable"(%arg0) : (f32) -> tensor<f32>
+ return
+}
+
+// -----
+
+func @failedElementwiseMappable_mixed_scalar_non_scalar_results(%arg0: tensor<10xf32>) {
+ // expected-error@+1 {{if an operand is non-scalar, then all results must be non-scalar}}
+ %0, %1 = "test.elementwise_mappable"(%arg0) : (tensor<10xf32>) -> (f32, tensor<10xf32>)
+ return
+}
+
+// -----
+
+func @failedElementwiseMappable_zero_results(%arg0: tensor<10xf32>) {
+ // expected-error@+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}}
+ "test.elementwise_mappable"(%arg0) : (tensor<10xf32>) -> ()
+ return
+}
+
+// -----
+
+func @failedElementwiseMappable_zero_operands() {
+ // expected-error@+1 {{if a result is non-scalar, then at least one operand must be non-scalar}}
+ "test.elementwise_mappable"() : () -> (tensor<6xf32>)
+ return
+}
+
+// -----
+
+func @succeededElementwiseMappable(%arg0: vector<2xf32>) {
+ // Check that varying element types are allowed.
+ // CHECK: test.elementwise_mappable
+ %0 = "test.elementwise_mappable"(%arg0) : (vector<2xf32>) -> vector<2xf16>
+ return
+}
+
+// -----
+
func @failedHasParent_wrong_parent() {
"some.op"() ({
// expected-error@+1 {{'test.child' op expects parent op 'test.parent'}}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 1ed6c545ff83..209d26c8feab 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -370,6 +370,12 @@ def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type",
let results = (outs Variadic<AnyType>);
}
+def ElementwiseMappableOp : TEST_Op<"elementwise_mappable",
+ [ElementwiseMappable]> {
+ let arguments = (ins Variadic<AnyType>);
+ let results = (outs Variadic<AnyType>);
+}
+
def ArgAndResHaveFixedElementTypesOp :
TEST_Op<"arg_and_res_have_fixed_element_types",
[PredOpTrait<"fixed type combination",
More information about the Mlir-commits
mailing list