[Mlir-commits] [mlir] b4fa28b - [mlir] Add ElementwiseMappable trait and apply it to std elementwise ops.

Sean Silva llvmlistbot at llvm.org
Tue Nov 10 13:46:32 PST 2020


Author: Sean Silva
Date: 2020-11-10T13:44:44-08:00
New Revision: b4fa28b408662a638f7d835b302db08633ccb91e

URL: https://github.com/llvm/llvm-project/commit/b4fa28b408662a638f7d835b302db08633ccb91e
DIFF: https://github.com/llvm/llvm-project/commit/b4fa28b408662a638f7d835b302db08633ccb91e.diff

LOG: [mlir] Add ElementwiseMappable trait and apply it to std elementwise ops.

This patch adds an `ElementwiseMappable` trait as discussed in the RFC
here:
https://llvm.discourse.group/t/rfc-std-elementwise-ops-on-tensors/2113/23

This trait can power a number of transformations and analyses.
A subsequent patch adds a convert-elementwise-to-linalg pass that exhibits
how this trait allows writing generic transformations.
See https://reviews.llvm.org/D90354 for that patch.

This trait slightly changes some verifier messages, but the diagnostics
are usually about as good. I fiddled with the ordering of the trait in
the .td file trait lists to minimize the changes here.

Differential Revision: https://reviews.llvm.org/D90731

Added: 
    

Modified: 
    mlir/docs/Traits.md
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/lib/IR/Operation.cpp
    mlir/test/Dialect/Standard/invalid.mlir
    mlir/test/IR/invalid-ops.mlir
    mlir/test/IR/traits.mlir
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
index c6db640f7f5f..0b85929aab93 100644
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -239,6 +239,20 @@ This trait requires that the operands are either vector or tensor types.
 This trait adds the property that the operation is commutative, i.e. `X op Y ==
 Y op X`
 
+### ElementwiseMappable
+
+* `OpTrait::ElementwiseMappable` -- `ElementwiseMappable`
+
+This trait tags scalar ops that also can be applied to vectors/tensors, with
+their semantics on vectors/tensors being elementwise application. This trait
+establishes a set of properties that allow reasoning about / converting between
+scalar/vector/tensor code. These same properties allow blanket implementations
+of various analyses/transformations for all `ElementwiseMappable` ops.
+
+Note: Not all ops that are "elementwise" in some abstract sense satisfy this
+trait. In particular, broadcasting behavior is not allowed. See the comments on
+`OpTrait::ElementwiseMappable` for the precise requirements.
+
 ### Function-Like
 
 *   `OpTrait::FunctionLike`

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 6037b3b32683..d15f06b37fa5 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -67,6 +67,11 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
   let hasFolder = 1;
 }
 
+// Base class for arithmetic cast operations.
+class ArithmeticCastOp<string mnemonic, list<OpTrait> traits = []> :
+    CastOp<mnemonic, !listconcat(traits, [ElementwiseMappable])> {
+}
+
 // Base class for unary ops. Requires single operand and result. Individual
 // classes will have `operand` accessor.
 class UnaryOp<string mnemonic, list<OpTrait> traits = []> :
@@ -88,7 +93,8 @@ class UnaryOpSameOperandAndResultType<string mnemonic,
 class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
     UnaryOpSameOperandAndResultType<mnemonic,
       !listconcat(traits,
-                  [DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
+                  [DeclareOpInterfaceMethods<VectorUnrollOpInterface>,
+                   ElementwiseMappable])>,
     Arguments<(ins FloatLike:$operand)>;
 
 // Base class for standard arithmetic operations.  Requires operands and
@@ -96,7 +102,9 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
 // types.  Individual classes will have `lhs` and `rhs` accessor to operands.
 class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
     Op<StandardOps_Dialect, mnemonic,
-       !listconcat(traits, [NoSideEffect, SameOperandsAndResultType])> {
+       !listconcat(traits, [NoSideEffect,
+                            SameOperandsAndResultType,
+                            ElementwiseMappable])> {
 
   let results = (outs AnyType);
 
@@ -1152,10 +1160,10 @@ def CmpFPredicateAttr : I64EnumAttr<
 }
 
 def CmpFOp : Std_Op<"cmpf",
-    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
-     TypesMatchWith<
+    [NoSideEffect, SameTypeOperands,
+     SameOperandsAndResultShape, TypesMatchWith<
        "result type has i1 element type and same shape as operands",
-       "lhs", "result", "getI1SameShape($_self)">]> {
+       "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> {
   let summary = "floating-point comparison operation";
   let description = [{
     The `cmpf` operation compares its two operands according to the float
@@ -1236,10 +1244,10 @@ def CmpIPredicateAttr : I64EnumAttr<
 }
 
 def CmpIOp : Std_Op<"cmpi",
-    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape,
-     TypesMatchWith<
+    [NoSideEffect, SameTypeOperands,
+     SameOperandsAndResultShape, TypesMatchWith<
        "result type has i1 element type and same shape as operands",
-       "lhs", "result", "getI1SameShape($_self)">]> {
+       "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> {
   let summary = "integer comparison operation";
   let description = [{
     The `cmpi` operation is a generic comparison for integer-like types. Its two
@@ -1926,7 +1934,7 @@ def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
 // FPExtOp
 //===----------------------------------------------------------------------===//
 
-def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> {
+def FPExtOp : ArithmeticCastOp<"fpext">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from floating-point to wider floating-point";
   let description = [{
     Cast a floating-point value to a larger floating-point-typed value.
@@ -1947,7 +1955,7 @@ def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> {
 // FPToSIOp
 //===----------------------------------------------------------------------===//
 
-def FPToSIOp : CastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
+def FPToSIOp : ArithmeticCastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from floating-point type to integer type";
   let description = [{
     Cast from a value interpreted as floating-point to the nearest (rounding
@@ -1967,7 +1975,7 @@ def FPToSIOp : CastOp<"fptosi">, Arguments<(ins AnyType:$in)> {
 // FPToUIOp
 //===----------------------------------------------------------------------===//
 
-def FPToUIOp : CastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
+def FPToUIOp : ArithmeticCastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from floating-point type to integer type";
   let description = [{
     Cast from a value interpreted as floating-point to the nearest (rounding
@@ -1987,7 +1995,7 @@ def FPToUIOp : CastOp<"fptoui">, Arguments<(ins AnyType:$in)> {
 // FPTruncOp
 //===----------------------------------------------------------------------===//
 
-def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
+def FPTruncOp : ArithmeticCastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from floating-point to narrower floating-point";
   let description = [{
     Truncate a floating-point value to a smaller floating-point-typed value.
@@ -2131,7 +2139,7 @@ def ImOp : Std_Op<"im",
 // IndexCastOp
 //===----------------------------------------------------------------------===//
 
-def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
+def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> {
   let summary = "cast between index and integer types";
   let description = [{
     Casts between integer scalars and 'index' scalars. Index is an integer of
@@ -2714,7 +2722,8 @@ def RsqrtOp : FloatUnaryOp<"rsqrt"> {
 //===----------------------------------------------------------------------===//
 
 def SelectOp : Std_Op<"select", [NoSideEffect,
-     AllTypesMatch<["true_value", "false_value", "result"]>]> {
+     AllTypesMatch<["true_value", "false_value", "result"]>,
+     ElementwiseMappable]> {
   let summary = "select operation";
   let description = [{
     The `select` operation chooses one value based on a binary condition
@@ -2945,7 +2954,7 @@ def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
 //===----------------------------------------------------------------------===//
 
 def SignExtendIOp : Std_Op<"sexti",
-    [NoSideEffect, SameOperandsAndResultShape]> {
+    [NoSideEffect, SameOperandsAndResultShape, ElementwiseMappable]> {
   let summary = "integer sign extension operation";
   let description = [{
     The integer sign extension operation takes an integer input of
@@ -2987,7 +2996,7 @@ def SignExtendIOp : Std_Op<"sexti",
 // SIToFPOp
 //===----------------------------------------------------------------------===//
 
-def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
+def SIToFPOp : ArithmeticCastOp<"sitofp">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from integer type to floating-point";
   let description = [{
     Cast from a value interpreted as signed or vector of signed integers to the
@@ -3786,7 +3795,9 @@ def TransposeOp : Std_Op<"transpose", [NoSideEffect]>,
 // TruncateIOp
 //===----------------------------------------------------------------------===//
 
-def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
+def TruncateIOp : Std_Op<"trunci", [NoSideEffect,
+                                    SameOperandsAndResultShape,
+                                    ElementwiseMappable]> {
   let summary = "integer truncation operation";
   let description = [{
     The integer truncation operation takes an integer input of
@@ -3826,7 +3837,7 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> {
 // UIToFPOp
 //===----------------------------------------------------------------------===//
 
-def UIToFPOp : CastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
+def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
   let summary = "cast from unsigned integer type to floating-point";
   let description = [{
     Cast from a value interpreted as unsigned integer or vector of unsigned
@@ -4053,7 +4064,9 @@ def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
 // ZeroExtendIOp
 //===----------------------------------------------------------------------===//
 
-def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]> {
+def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect,
+                                     SameOperandsAndResultShape,
+                                     ElementwiseMappable]> {
   let summary = "integer zero extension operation";
   let description = [{
     The integer zero extension operation takes an integer input of

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 814b8573d8b3..b031769022d9 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1750,6 +1750,9 @@ def Terminator : NativeOpTrait<"IsTerminator">;
 // Op can be safely normalized in the presence of MemRefs with
 // non-identity maps.
 def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
+// Op can be systematically interconverted between scalar and vector/tensor
+// form by mapping elementwise based on the type.
+def ElementwiseMappable : NativeOpTrait<"ElementwiseMappable">;
 
 // Op's regions have a single block with the specified terminator.
 class SingleBlockImplicitTerminator<string op>

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index f1be18758756..d673b2b8161a 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -321,6 +321,7 @@ LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
 LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
 LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
 LogicalResult verifyNoRegionArguments(Operation *op);
+LogicalResult verifyElementwiseMappable(Operation *op);
 } // namespace impl
 
 /// Helper class for implementing traits.  Clients are not expected to interact
@@ -1182,6 +1183,93 @@ template <typename ConcrentType>
 struct MemRefsNormalizable
     : public TraitBase<ConcrentType, MemRefsNormalizable> {};
 
+/// This trait tags scalar ops that also can be applied to vectors/tensors, with
+/// their semantics on vectors/tensors being elementwise application.
+///
+/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
+/// trait. In particular, broadcasting behavior is not allowed. This trait
+/// describes a set of invariants that allow systematic
+/// vectorization/tensorization, and the reverse, scalarization. The properties
+/// needed for this also can be used to implement a number of
+/// transformations/analyses/interfaces.
+///
+/// An `ElementwiseMappable` op must satisfy the following properties:
+///
+/// 1. If any result is a vector (resp. tensor), then at least one operand must
+/// be a vector (resp. tensor).
+/// 2. If any operand is a vector (resp. tensor), then there must be at least
+/// one result, and all results must be vectors (resp. tensors).
+/// 3. The static types of all vector (resp. tensor) operands and results must
+/// have the same shape.
+/// 4. In the case of tensor operands, the dynamic shapes of all tensor operands
+/// must be the same, otherwise the op has undefined behavior.
+/// 5. ("systematic scalarization" property) If an op has vector/tensor
+/// operands/results, then the same op, with the operand/result types changed to
+/// their corresponding element type, shall be a verifier-valid op.
+/// 6. The semantics of the op on vectors (resp. tensors) shall be the same as
+/// applying the scalarized version of the op for each corresponding element of
+/// the vector (resp. tensor) operands in parallel.
+/// 7. ("systematic vectorization/tensorization" property) If an op has
+/// scalar operands/results, the op shall remain verifier-valid if all scalar
+/// operands are replaced with vectors/tensors of the same shape and
+/// corresponding element types.
+///
+/// Together, these properties provide an easy way for scalar operations to
+/// conveniently generalize their behavior to vectors/tensors, and systematize
+/// conversion between these forms.
+///
+/// Examples:
+/// ```
+/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32
+/// // Applying the systematic vectorization/tensorization property, this op
+/// // must also be valid:
+/// %tensor = "std.addf"(%a_tensor, %b_tensor)
+///           : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>)
+///
+/// // These properties generalize well to the cases of non-scalar operands.
+/// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val)
+///                       : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+/// // Applying the systematic vectorization / tensorization property, this
+/// // op must also be valid:
+/// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val)
+///                       : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
+///                       -> tensor<?xf32>
+/// // Applying the systematic scalarization property, this op must also
+/// // be valid.
+/// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
+///                  : (i1, f32, f32) -> f32
+/// ```
+///
+/// TODO: Avoid hardcoding vector/tensor, and generalize this to any type
+/// implementing a new "ElementwiseMappableTypeInterface" that describes types
+/// for which it makes sense to apply a scalar function to each element.
+///
+/// Rationale:
+/// - 1. and 2. guarantee a well-defined iteration space for 6.
+///   - These also exclude the cases of 0 non-scalar operands or 0 non-scalar
+///     results, which complicate a generic definition of the iteration space.
+/// - 3. guarantees that folding can be done across scalars/vectors/tensors
+///   with the same pattern, as otherwise lots of special handling of type
+///   mismatches would be needed.
+/// - 4. guarantees that no error handling cases need to be considered.
+///   - Higher-level dialects should reify any needed guards / error handling
+///   code before lowering to an ElementwiseMappable op.
+/// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar
+///   semantics and provide a constructive procedure for IR transformations
+///   to e.g. create scalar loop bodies from tensor ops.
+/// - 7. provides the reverse of 5., which when chained together allows
+///   reasoning about the relationship between the tensor and vector case.
+///   Additionally, it permits reasoning about promoting scalars to
+///   vectors/tensors via broadcasting in cases like `%select_scalar_pred`
+///   above.
+template <typename ConcreteType>
+struct ElementwiseMappable
+    : public TraitBase<ConcreteType, ElementwiseMappable> {
+  static LogicalResult verifyTrait(Operation *op) {
+    return ::mlir::OpTrait::impl::verifyElementwiseMappable(op);
+  }
+};
+
 } // end namespace OpTrait
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index c85efe83afe8..5fed2ec8e753 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1068,6 +1068,57 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
   return success();
 }
 
+/// Checks if two ShapedTypes are the same, ignoring the element type.
+static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) {
+  if (a.getTypeID() != b.getTypeID())
+    return false;
+  if (!a.hasRank())
+    return !b.hasRank();
+  return a.getShape() == b.getShape();
+}
+
+LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) {
+  auto isMappableType = [](Type type) {
+    return type.isa<VectorType, TensorType>();
+  };
+  auto resultMappableTypes = llvm::to_vector<1>(
+      llvm::make_filter_range(op->getResultTypes(), isMappableType));
+  auto operandMappableTypes = llvm::to_vector<2>(
+      llvm::make_filter_range(op->getOperandTypes(), isMappableType));
+
+  // If the op only has scalar operand/result types, then we have nothing to
+  // check.
+  if (resultMappableTypes.empty() && operandMappableTypes.empty())
+    return success();
+
+  if (!resultMappableTypes.empty() && operandMappableTypes.empty())
+    return op->emitOpError("if a result is non-scalar, then at least one "
+                           "operand must be non-scalar");
+
+  assert(!operandMappableTypes.empty());
+
+  if (resultMappableTypes.empty())
+    return op->emitOpError("if an operand is non-scalar, then there must be at "
+                           "least one non-scalar result");
+
+  if (resultMappableTypes.size() != op->getNumResults())
+    return op->emitOpError(
+        "if an operand is non-scalar, then all results must be non-scalar");
+
+  auto mustMatchType = operandMappableTypes[0].cast<ShapedType>();
+  for (auto type :
+       llvm::concat<Type>(resultMappableTypes, operandMappableTypes)) {
+    if (!areSameShapedTypeIgnoringElementType(type.cast<ShapedType>(),
+                                              mustMatchType)) {
+      return op->emitOpError() << "all non-scalar operands/results must have "
+                                  "the same shape and base type: found "
+                               << type << " and " << mustMatchType;
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BinaryOp implementation
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 3c4be54a1d43..04ecdc64351a 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -split-input-file %s -verify-diagnostics
 
 func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
-  // expected-error @+1 {{requires the same shape for all operands and results}}
+  // expected-error @+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xi64>' and 'tensor<index>'}}
   %0 = index_cast %arg0 : tensor<index> to tensor<2xi64>
   return %0 : tensor<2xi64>
 }
@@ -9,7 +9,7 @@ func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
 // -----
 
 func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
-  // expected-error @+1 {{requires the same shape for all operands and results}}
+  // expected-error @+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}}
   %0 = index_cast %arg0 : tensor<index> to i64
   return %0 : i64
 }

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index b59353aa2f7c..76aff5c6d401 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -267,7 +267,7 @@ func @func_with_ops(i1, i32, i64) {
 
 func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
 ^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
-  // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'vector<42xi1>', but got 'vector<12xi1>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<42xi32>' and 'vector<12xi1>'}}
   %r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
 }
 
@@ -275,7 +275,7 @@ func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
 
 func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
 ^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
-  // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'tensor<42xi1>', but got 'tensor<12xi1>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<42xi32>' and 'tensor<12xi1>'}}
   %r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
 }
 
@@ -685,7 +685,7 @@ func @fpext_f32_to_i32(%arg0 : f32) {
 // -----
 
 func @fpext_vec(%arg0 : vector<2xf16>) {
-  // expected-error@+1 {{requires the same shape for all operands and results}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}}
   %0 = fpext %arg0 : vector<2xf16> to vector<3xf32>
   return
 }
@@ -757,7 +757,7 @@ func @fptrunc_f32_to_i32(%arg0 : f32) {
 // -----
 
 func @fptrunc_vec(%arg0 : vector<2xf16>) {
-  // expected-error@+1 {{requires the same shape for all operands and results}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}}
   %0 = fptrunc %arg0 : vector<2xf16> to vector<3xf32>
   return
 }

diff  --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 206881a1c767..4dd6ef7dcf37 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -166,6 +166,84 @@ func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<10xf3
 
 // -----
 
+func @failedElementwiseMappable_different_rankedness(%arg0: tensor<?xf32>, %arg1: tensor<*xf32>) {
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<*xf32>' and 'tensor<?xf32>'}}
+  %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<*xf32>) -> tensor<*xf32>
+}
+
+// -----
+
+func @failedElementwiseMappable_different_rank(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) {
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<?x?xf32>' and 'tensor<?xf32>'}}
+  %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+}
+
+// -----
+
+func @failedElementwiseMappable_different_shape(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<5xf32>' and 'tensor<?xf32>'}}
+  %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<5xf32>) -> tensor<?xf32>
+}
+
+// -----
+
+func @failedElementwiseMappable_different_base_type(%arg0: vector<2xf32>, %arg1: tensor<2xf32>) {
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xf32>' and 'vector<2xf32>'}}
+  %0 = "test.elementwise_mappable"(%arg0, %arg1) : (vector<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  return
+}
+
+// -----
+
+func @failedElementwiseMappable_non_scalar_output(%arg0: vector<2xf32>) {
+  // expected-error@+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}}
+  %0 = "test.elementwise_mappable"(%arg0) : (vector<2xf32>) -> f32
+  return
+}
+
+// -----
+
+func @failedElementwiseMappable_non_scalar_result_all_scalar_input(%arg0: f32) {
+  // expected-error@+1 {{if a result is non-scalar, then at least one operand must be non-scalar}}
+  %0 = "test.elementwise_mappable"(%arg0) : (f32) -> tensor<f32>
+  return
+}
+
+// -----
+
+func @failedElementwiseMappable_mixed_scalar_non_scalar_results(%arg0: tensor<10xf32>) {
+  // expected-error@+1 {{if an operand is non-scalar, then all results must be non-scalar}}
+  %0, %1 = "test.elementwise_mappable"(%arg0) : (tensor<10xf32>) -> (f32, tensor<10xf32>)
+  return
+}
+
+// -----
+
+func @failedElementwiseMappable_zero_results(%arg0: tensor<10xf32>) {
+  // expected-error@+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}}
+  "test.elementwise_mappable"(%arg0) : (tensor<10xf32>) -> ()
+  return
+}
+
+// -----
+
+func @failedElementwiseMappable_zero_operands() {
+  // expected-error@+1 {{if a result is non-scalar, then at least one operand must be non-scalar}}
+  "test.elementwise_mappable"() : () -> (tensor<6xf32>)
+  return
+}
+
+// -----
+
+func @succeededElementwiseMappable(%arg0: vector<2xf32>) {
+  // Check that varying element types are allowed.
+  // CHECK: test.elementwise_mappable
+  %0 = "test.elementwise_mappable"(%arg0) : (vector<2xf32>) -> vector<2xf16>
+  return
+}
+
+// -----
+
 func @failedHasParent_wrong_parent() {
   "some.op"() ({
    // expected-error at +1 {{'test.child' op expects parent op 'test.parent'}}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 1ed6c545ff83..209d26c8feab 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -370,6 +370,12 @@ def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type",
   let results = (outs Variadic<AnyType>);
 }
 
+def ElementwiseMappableOp : TEST_Op<"elementwise_mappable",
+    [ElementwiseMappable]> {
+  let arguments = (ins Variadic<AnyType>);
+  let results = (outs Variadic<AnyType>);
+}
+
 def ArgAndResHaveFixedElementTypesOp :
     TEST_Op<"arg_and_res_have_fixed_element_types",
       [PredOpTrait<"fixed type combination",


        


More information about the Mlir-commits mailing list