[llvm-branch-commits] [mlir] 3e62997 - [mlir] Fix arith verifier for tensor with encoding

Tobias Hieta via llvm-branch-commits llvm-branch-commits@lists.llvm.org
Thu Aug 10 00:09:53 PDT 2023


Author: Thomas Raoux
Date: 2023-08-10T09:06:20+02:00
New Revision: 3e62997c4c0fb48ce43496be80013dbf98b96a6e

URL: https://github.com/llvm/llvm-project/commit/3e62997c4c0fb48ce43496be80013dbf98b96a6e
DIFF: https://github.com/llvm/llvm-project/commit/3e62997c4c0fb48ce43496be80013dbf98b96a6e.diff

LOG: [mlir] Fix arith verifier for tensor with encoding

The verifiers for some arith ops were not considering that ranked
tensor types can have encodings.

Differential Revision: https://reviews.llvm.org/D156557
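
In short: a RankedTensorType may carry an optional encoding attribute.
Verifiers that rebuilt the expected i1 result type with
RankedTensorType::get(shape, i1Type) dropped that encoding, so correctly
encoded IR could fail to verify; the cast verifier, conversely, compared
only shapes, so a cast could silently change the encoding. A minimal
sketch of both cases, modeled on the tests added below ("foo" and "bar"
are arbitrary attributes standing in for real encodings):

    // Verifies after this change: the expected result type is now
    // tensor<8x8xi1, "foo">, keeping the operand's encoding.
    %0 = arith.cmpi slt, %arg0, %arg1 : tensor<8x8xi64, "foo">

    // Rejected after this change: a cast may not change the encoding.
    %1 = arith.extf %arg2 : tensor<2xf32, "foo"> to tensor<2xf64, "bar">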

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/invalid.mlir
    mlir/test/Dialect/Arith/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 1c41818c318d24..eb73787b354778 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -130,13 +130,10 @@ namespace {
 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
 static Type getI1SameShape(Type type) {
   auto i1Type = IntegerType::get(type.getContext(), 1);
-  if (auto tensorType = llvm::dyn_cast<RankedTensorType>(type))
-    return RankedTensorType::get(tensorType.getShape(), i1Type);
+  if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
+    return shapedType.cloneWith(std::nullopt, i1Type);
   if (llvm::isa<UnrankedTensorType>(type))
     return UnrankedTensorType::get(i1Type);
-  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
-    return VectorType::get(vectorType.getShape(), i1Type,
-                           vectorType.getScalableDims());
   return i1Type;
 }
 
@@ -1150,9 +1147,21 @@ static Type getTypeIfLikeOrMemRef(Type type) {
                            type_list<ElementTypes...>());
 }
 
+/// Return false if both types are ranked tensors with mismatching encodings.
+static bool hasSameEncoding(Type typeA, Type typeB) {
+  auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
+  auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
+  if (!rankedTensorA || !rankedTensorB)
+    return true;
+  return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
+}
+
 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
-  return inputs.size() == 1 && outputs.size() == 1 &&
-         succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
+  if (inputs.size() != 1 || outputs.size() != 1)
+    return false;
+  if (!hasSameEncoding(inputs.front(), outputs.front()))
+    return false;
+  return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
 }
 
 //===----------------------------------------------------------------------===//

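Two notes on the ArithOps.cpp change above. First,
ShapedType::cloneWith(std::nullopt, i1Type) returns the same kind of
shaped type with only the element type replaced, so a RankedTensorType
keeps its encoding and a VectorType keeps its shape and scalable dims;
that is why the dedicated vector branch could be deleted. (Since
UnrankedTensorType is itself a ShapedType, the remaining unranked branch
is likely unreachable now, but harmless.) A before/after sketch of what
getI1SameShape produces for an encoded operand:

    // input:   tensor<8x8xi64, "foo">
    // before:  tensor<8x8xi1>          (RankedTensorType::get drops "foo")
    // after:   tensor<8x8xi1, "foo">   (cloneWith keeps shape and encoding)

Second, hasSameEncoding deliberately returns true when either type is
not a ranked tensor, so the new check in areValidCastInputsAndOutputs
only tightens verification for ranked-tensor-to-ranked-tensor casts and
leaves scalar, vector, and unranked-tensor casts as before.
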
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 9f131e5afab058..6d8ac0ada52be3 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -206,6 +206,15 @@ func.func @func_with_ops() {
 
 // -----
 
+func.func @func_with_ops() {
+^bb0:
+  %c = arith.constant dense<0> : tensor<42 x i32, "foo">
+  // expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}}
+  %r = "arith.cmpi"(%c, %c) {predicate = 0} : (tensor<42 x i32, "foo">, tensor<42 x i32, "foo">) -> tensor<42 x i1, "bar">
+}
+
+// -----
+
 func.func @invalid_cmp_shape(%idx : () -> ()) {
   // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
   %cmp = arith.cmpi eq, %idx, %idx : () -> ()
@@ -420,6 +429,14 @@ func.func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) {
 
 // -----
 
+func.func @fpext_vec_f32_to_i32(%arg0 : tensor<2xf32, "foo">) {
+  // expected-error@+1 {{op operand type 'tensor<2xf32, "foo">' and result type 'tensor<2xf64, "bar">' are cast incompatible}}
+  %0 = arith.extf %arg0 : tensor<2xf32, "foo"> to tensor<2xf64, "bar">
+  return
+}
+
+// -----
+
 func.func @fptrunc_f16_to_f32(%arg0 : f16) {
   // expected-error@+1 {{are cast incompatible}}
   %0 = arith.truncf %arg0 : f16 to f32
@@ -769,3 +786,12 @@ func.func @disallow_zero_rank_tensor_with_unranked_tensor(%arg0 : tensor<i1>, %a
   %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2x?xi64>
   return %0 : tensor<2x?xi64>
 }
+
+// -----
+
+func.func @select_tensor_encoding(
+  %arg0 : tensor<8xi1, "bar">, %arg1 : tensor<8xi32, "foo">, %arg2 : tensor<8xi32, "foo">) -> tensor<8xi32, "foo"> {
+  // expected-error @+1 {{'arith.select' op expected condition type to have the same shape as the result type}}
+  %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
+  return %0 : tensor<8xi32, "foo">
+}

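All three negative tests follow the same pattern: one type carries the
"bar" encoding while the other carries "foo", and the verifier now
reports the mismatch. For comparison, a sketch of the extf case with
matching encodings, which does verify (the 8x8 variant is added to
ops.mlir below):

    // OK: operand and result carry the same encoding.
    %0 = arith.extf %arg0 : tensor<2xf32, "foo"> to tensor<2xf64, "foo">
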
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 25c015bfa3ccc5..faa138a170ddfa 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -637,6 +637,12 @@ func.func @test_extf_tensor(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf64> {
   return %0 : tensor<8x8xf64>
 }
 
+// CHECK-LABEL: test_extf_tensor_encoding
+func.func @test_extf_tensor_encoding(%arg0 : tensor<8x8xf32, "foo">) -> tensor<8x8xf64, "foo"> {
+  %0 = arith.extf %arg0 : tensor<8x8xf32, "foo"> to tensor<8x8xf64, "foo">
+  return %0 : tensor<8x8xf64, "foo">
+}
+
 // CHECK-LABEL: test_extf_vector
 func.func @test_extf_vector(%arg0 : vector<8xf32>) -> vector<8xf64> {
   %0 = arith.extf %arg0 : vector<8xf32> to vector<8xf64>
@@ -950,6 +956,12 @@ func.func @test_cmpi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) ->
   return %0 : tensor<8x8xi1>
 }
 
+// CHECK-LABEL: test_cmpi_tensor_encoding
+func.func @test_cmpi_tensor_encoding(%arg0 : tensor<8x8xi64, "foo">, %arg1 : tensor<8x8xi64, "foo">) -> tensor<8x8xi1, "foo"> {
+  %0 = arith.cmpi slt, %arg0, %arg1 : tensor<8x8xi64, "foo">
+  return %0 : tensor<8x8xi1, "foo">
+}
+
 // CHECK-LABEL: test_cmpi_vector
 func.func @test_cmpi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi1> {
   %0 = arith.cmpi ult, %arg0, %arg1 : vector<8xi64>
@@ -1103,3 +1115,18 @@ func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
 
   return
 }
+
+// CHECK-LABEL: @select_tensor
+func.func @select_tensor(%arg0 : tensor<8xi1>, %arg1 : tensor<8xi32>, %arg2 : tensor<8xi32>) -> tensor<8xi32> {
+  // CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1>, tensor<8xi32>
+  %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1>, tensor<8xi32>
+  return %0 : tensor<8xi32>
+}
+
+// CHECK-LABEL: @select_tensor_encoding
+func.func @select_tensor_encoding(
+  %arg0 : tensor<8xi1, "foo">, %arg1 : tensor<8xi32, "foo">, %arg2 : tensor<8xi32, "foo">) -> tensor<8xi32, "foo"> {
+  // CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1, "foo">, tensor<8xi32, "foo">
+  %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "foo">, tensor<8xi32, "foo">
+  return %0 : tensor<8xi32, "foo">
+}

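For anyone reproducing locally, both test files use the standard MLIR
lit setup; the RUN lines sit at the top of each file and are not part
of this diff, but they follow the usual pattern:

    // invalid.mlir: each "// -----" starts a fresh split-input case.
    // RUN: mlir-opt -split-input-file %s -verify-diagnostics
    // ops.mlir: round-trip the IR and FileCheck the result.
    // RUN: mlir-opt %s | mlir-opt | FileCheck %s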
