[Mlir-commits] [mlir] 2a71f95 - [MLIR] Allow compatible shapes in `Elementwise` operations

Frederik Gossen llvmlistbot at llvm.org
Mon Mar 15 01:56:37 PDT 2021


Author: Frederik Gossen
Date: 2021-03-15T09:56:20+01:00
New Revision: 2a71f95767490f5ac65d42bf55ad571e6fbd1123

URL: https://github.com/llvm/llvm-project/commit/2a71f95767490f5ac65d42bf55ad571e6fbd1123
DIFF: https://github.com/llvm/llvm-project/commit/2a71f95767490f5ac65d42bf55ad571e6fbd1123.diff

LOG: [MLIR] Allow compatible shapes in `Elementwise` operations

Differential Revision: https://reviews.llvm.org/D98186

Added: 
    

Modified: 
    mlir/lib/IR/Operation.cpp
    mlir/test/Dialect/Standard/invalid.mlir
    mlir/test/IR/invalid-ops.mlir
    mlir/test/IR/traits.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 0614a7bd63d3..f427e100d347 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1051,15 +1051,6 @@ LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {
   return success();
 }
 
-/// Checks if two ShapedTypes are the same, ignoring the element type.
-static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) {
-  if (a.getTypeID() != b.getTypeID())
-    return false;
-  if (!a.hasRank())
-    return !b.hasRank();
-  return a.getShape() == b.getShape();
-}
-
 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
   auto isMappableType = [](Type type) {
     return type.isa<VectorType, TensorType>();
@@ -1088,15 +1079,14 @@ LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
     return op->emitOpError(
         "if an operand is non-scalar, then all results must be non-scalar");
 
-  auto mustMatchType = operandMappableTypes[0].cast<ShapedType>();
-  for (auto type :
-       llvm::concat<Type>(resultMappableTypes, operandMappableTypes)) {
-    if (!areSameShapedTypeIgnoringElementType(type.cast<ShapedType>(),
-                                              mustMatchType)) {
-      return op->emitOpError() << "all non-scalar operands/results must have "
-                                  "the same shape and base type: found "
-                               << type << " and " << mustMatchType;
-    }
+  SmallVector<Type, 4> types = llvm::to_vector<2>(
+      llvm::concat<Type>(operandMappableTypes, resultMappableTypes));
+  TypeID expectedBaseTy = types.front().getTypeID();
+  if (!llvm::all_of(types,
+                    [&](Type t) { return t.getTypeID() == expectedBaseTy; }) ||
+      failed(verifyCompatibleShapes(types))) {
+    return op->emitOpError() << "all non-scalar operands/results must have the "
+                                "same shape and base type";
   }
 
   return success();

diff  --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index 9b986e5ef75f..c6f11a871848 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -split-input-file %s -verify-diagnostics
 
 func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
-  // expected-error @+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xi64>' and 'tensor<index>'}}
+  // expected-error @+1 {{all non-scalar operands/results must have the same shape and base type}}
   %0 = index_cast %arg0 : tensor<index> to tensor<2xi64>
   return %0 : tensor<2xi64>
 }

diff  --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index b5ee96839fbd..797c1d478364 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -236,7 +236,7 @@ func @func_with_ops(i32, i32) {
 func @func_with_ops() {
 ^bb0:
   %c = constant dense<0> : vector<42 x i32>
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<41xi1>' and 'vector<42xi32>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %r = "std.cmpi"(%c, %c) {predicate = 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<41 x i1>
 }
 
@@ -269,7 +269,7 @@ func @func_with_ops(i1, i32, i64) {
 
 func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
 ^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<42xi32>' and 'vector<12xi1>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
 }
 
@@ -277,7 +277,7 @@ func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
 
 func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
 ^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<42xi32>' and 'tensor<12xi1>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
 }
 
@@ -514,7 +514,7 @@ func @cmpf_canonical_wrong_result_type(%a : f32, %b : f32) -> f32 {
 // -----
 
 func @cmpf_result_shape_mismatch(%a : vector<42xf32>) {
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<41xi1>' and 'vector<42xf32>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %r = "std.cmpf"(%a, %a) {predicate = 0} : (vector<42 x f32>, vector<42 x f32>) -> vector<41 x i1>
 }
 
@@ -614,7 +614,7 @@ func @fpext_f32_to_i32(%arg0 : f32) {
 // -----
 
 func @fpext_vec(%arg0 : vector<2xf16>) {
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %0 = fpext %arg0 : vector<2xf16> to vector<3xf32>
   return
 }
@@ -686,7 +686,7 @@ func @fptrunc_f32_to_i32(%arg0 : f32) {
 // -----
 
 func @fptrunc_vec(%arg0 : vector<2xf16>) {
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %0 = fptrunc %arg0 : vector<2xf16> to vector<3xf32>
   return
 }

diff  --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 858f601c0211..dc9e5106a8d7 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -169,33 +169,38 @@ func @succeededSameOperandAndResultType(%t10x10 : tensor<10x10xf32>, %t1: tensor
 func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<10xf32>, %t20 : tensor<20xf32>) {
   // expected-error@+1 {{requires the same type for all operands and results}}
   "test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xf32>
+  return
 }
 
 // -----
 
 func @failedElementwiseMappable_different_rankedness(%arg0: tensor<?xf32>, %arg1: tensor<*xf32>) {
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<*xf32>' and 'tensor<?xf32>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return
 }
 
 // -----
 
 func @failedElementwiseMappable_different_rank(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) {
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<?x?xf32>' and 'tensor<?xf32>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return
 }
 
 // -----
 
-func @failedElementwiseMappable_different_shape(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<5xf32>' and 'tensor<?xf32>'}}
-  %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<5xf32>) -> tensor<?xf32>
+func @elementwiseMappable_dynamic_shapes(%arg0: tensor<?xf32>,
+    %arg1: tensor<5xf32>) {
+  %0 = "test.elementwise_mappable"(%arg0, %arg1) :
+      (tensor<?xf32>, tensor<5xf32>) -> tensor<?xf32>
+  return
 }
 
 // -----
 
 func @failedElementwiseMappable_different_base_type(%arg0: vector<2xf32>, %arg1: tensor<2xf32>) {
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xf32>' and 'vector<2xf32>'}}
+  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %0 = "test.elementwise_mappable"(%arg0, %arg1) : (vector<2xf32>, tensor<2xf32>) -> tensor<2xf32>
   return
 }


        


More information about the Mlir-commits mailing list