[llvm-branch-commits] [mlir] 3e62997 - [mlir] Fix arith verifier for tensor with encoding
Tobias Hieta via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Aug 10 00:09:53 PDT 2023
Author: Thomas Raoux
Date: 2023-08-10T09:06:20+02:00
New Revision: 3e62997c4c0fb48ce43496be80013dbf98b96a6e
URL: https://github.com/llvm/llvm-project/commit/3e62997c4c0fb48ce43496be80013dbf98b96a6e
DIFF: https://github.com/llvm/llvm-project/commit/3e62997c4c0fb48ce43496be80013dbf98b96a6e.diff
LOG: [mlir] Fix arith verifier for tensor with encoding
The verifiers for some arith ops did not take into account that ranked
tensor types can have encodings.
Differential Revision: https://reviews.llvm.org/D156557
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/invalid.mlir
mlir/test/Dialect/Arith/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 1c41818c318d24..eb73787b354778 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -130,13 +130,10 @@ namespace {
/// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
- if (auto tensorType = llvm::dyn_cast<RankedTensorType>(type))
- return RankedTensorType::get(tensorType.getShape(), i1Type);
+ // cloneWith with no new shape swaps only the element type, so the ranked
+ // tensor encoding is carried over to the i1 result instead of being dropped.
+ if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
+ return shapedType.cloneWith(std::nullopt, i1Type);
+ // NOTE(review): UnrankedTensorType presumably also matches ShapedType above,
+ // which would make this branch unreachable -- confirm before relying on it.
if (llvm::isa<UnrankedTensorType>(type))
return UnrankedTensorType::get(i1Type);
- if (auto vectorType = llvm::dyn_cast<VectorType>(type))
- return VectorType::get(vectorType.getShape(), i1Type,
- vectorType.getScalableDims());
return i1Type;
}
@@ -1150,9 +1147,21 @@ static Type getTypeIfLikeOrMemRef(Type type) {
type_list<ElementTypes...>());
}
+/// Return false if both types are ranked tensors with mismatching encodings.
+/// Return true otherwise, including when either type is not a ranked tensor.
+static bool hasSameEncoding(Type typeA, Type typeB) {
+ auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
+ auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
+ // Encodings only exist on ranked tensors; anything else has nothing to compare.
+ if (!rankedTensorA || !rankedTensorB)
+ return true;
+ return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
+}
+
+/// A cast is valid only with exactly one input and one output whose types
+/// have matching ranked-tensor encodings (if any) and compatible shapes.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
- return inputs.size() == 1 && outputs.size() == 1 &&
- succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
+ if (inputs.size() != 1 || outputs.size() != 1)
+ return false;
+ // Check encodings before shapes: verifyCompatibleShapes ignores encodings.
+ if (!hasSameEncoding(inputs.front(), outputs.front()))
+ return false;
+ return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 9f131e5afab058..6d8ac0ada52be3 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -206,6 +206,15 @@ func.func @func_with_ops() {
// -----
+func.func @func_with_ops() {
+^bb0:
+ %c = arith.constant dense<0> : tensor<42 x i32, "foo">
+ // The result encoding "bar" differs from the operand encoding "foo".
+ // expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}}
+ %r = "arith.cmpi"(%c, %c) {predicate = 0} : (tensor<42 x i32, "foo">, tensor<42 x i32, "foo">) -> tensor<42 x i1, "bar">
+}
+
+// -----
+
func.func @invalid_cmp_shape(%idx : () -> ()) {
// expected-error at +1 {{'lhs' must be signless-integer-like, but got '() -> ()'}}
%cmp = arith.cmpi eq, %idx, %idx : () -> ()
@@ -420,6 +429,14 @@ func.func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) {
// -----
+// NOTE(review): the name is copied from the preceding vector test; this case
+// checks that extf rejects an operand/result tensor encoding mismatch.
+func.func @fpext_vec_f32_to_i32(%arg0 : tensor<2xf32, "foo">) {
+ // expected-error@+1 {{op operand type 'tensor<2xf32, "foo">' and result type 'tensor<2xf64, "bar">' are cast incompatible}}
+ %0 = arith.extf %arg0 : tensor<2xf32, "foo"> to tensor<2xf64, "bar">
+ return
+}
+
+// -----
+
func.func @fptrunc_f16_to_f32(%arg0 : f16) {
// expected-error at +1 {{are cast incompatible}}
%0 = arith.truncf %arg0 : f16 to f32
@@ -769,3 +786,12 @@ func.func @disallow_zero_rank_tensor_with_unranked_tensor(%arg0 : tensor<i1>, %a
%0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2x?xi64>
return %0 : tensor<2x?xi64>
}
+
+// -----
+
+func.func @select_tensor_encoding(
+ %arg0 : tensor<8xi1, "bar">, %arg1 : tensor<8xi32, "foo">, %arg2 : tensor<8xi32, "foo">) -> tensor<8xi32, "foo"> {
+ // The condition encoding "bar" differs from the value encoding "foo".
+ // expected-error @+1 {{'arith.select' op expected condition type to have the same shape as the result type}}
+ %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
+ return %0 : tensor<8xi32, "foo">
+}
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 25c015bfa3ccc5..faa138a170ddfa 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -637,6 +637,12 @@ func.func @test_extf_tensor(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf64> {
return %0 : tensor<8x8xf64>
}
+// Operand and result share the encoding "foo", so the extf cast is accepted.
+// CHECK-LABEL: test_extf_tensor_encoding
+func.func @test_extf_tensor_encoding(%arg0 : tensor<8x8xf32, "foo">) -> tensor<8x8xf64, "foo"> {
+ %0 = arith.extf %arg0 : tensor<8x8xf32, "foo"> to tensor<8x8xf64, "foo">
+ return %0 : tensor<8x8xf64, "foo">
+}
+
// CHECK-LABEL: test_extf_vector
func.func @test_extf_vector(%arg0 : vector<8xf32>) -> vector<8xf64> {
%0 = arith.extf %arg0 : vector<8xf32> to vector<8xf64>
@@ -950,6 +956,12 @@ func.func @test_cmpi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) ->
return %0 : tensor<8x8xi1>
}
+// cmpi on "foo"-encoded operands produces an i1 tensor with the same encoding.
+// CHECK-LABEL: test_cmpi_tensor_encoding
+func.func @test_cmpi_tensor_encoding(%arg0 : tensor<8x8xi64, "foo">, %arg1 : tensor<8x8xi64, "foo">) -> tensor<8x8xi1, "foo"> {
+ %0 = arith.cmpi slt, %arg0, %arg1 : tensor<8x8xi64, "foo">
+ return %0 : tensor<8x8xi1, "foo">
+}
+
// CHECK-LABEL: test_cmpi_vector
func.func @test_cmpi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi1> {
%0 = arith.cmpi ult, %arg0, %arg1 : vector<8xi64>
@@ -1103,3 +1115,18 @@ func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
return
}
+
+// Baseline: select on ranked tensors without any encoding.
+// CHECK-LABEL: @select_tensor
+func.func @select_tensor(%arg0 : tensor<8xi1>, %arg1 : tensor<8xi32>, %arg2 : tensor<8xi32>) -> tensor<8xi32> {
+ // CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1>, tensor<8xi32>
+ %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1>, tensor<8xi32>
+ return %0 : tensor<8xi32>
+}
+
+// The condition and the values share the encoding "foo", so select is accepted.
+// CHECK-LABEL: @select_tensor_encoding
+func.func @select_tensor_encoding(
+ %arg0 : tensor<8xi1, "foo">, %arg1 : tensor<8xi32, "foo">, %arg2 : tensor<8xi32, "foo">) -> tensor<8xi32, "foo"> {
+ // CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1, "foo">, tensor<8xi32, "foo">
+ %0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "foo">, tensor<8xi32, "foo">
+ return %0 : tensor<8xi32, "foo">
+}
More information about the llvm-branch-commits
mailing list