[Mlir-commits] [mlir] [mlir] Allow all shaped types for arith ops. (PR #99028)
llvmlistbot at llvm.org
Tue Jul 16 05:53:10 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-ods
Author: Alexander Belyaev (pifon2a)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/99028.diff
2 Files Affected:
- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+5-2)
- (modified) mlir/test/Dialect/Arith/ops.mlir (+6)
``````````diff
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index af4f13dc09360..9414c61feb365 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -636,6 +636,10 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>;
// Shaped types.
+class ShapedTypeOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsShapedTypePred, "shaped",
+ "::mlir::ShapedType">;
+
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
"::mlir::ShapedType">;
@@ -844,8 +848,7 @@ class NestedTupleOf<list<Type> allowedTypes> :
// Type constraint for types that are "like" some type or set of types T, that is
// they're either a T, a vector of Ts, or a tensor of Ts
class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
- allowedType.predicate, VectorOf<[allowedType]>.predicate,
- TensorOf<[allowedType]>.predicate]>,
+ allowedType.predicate, ShapedTypeOf<[allowedType]>.predicate]>,
name>;
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a51..0e786c211431a 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -13,6 +13,12 @@ func.func @test_addi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) ->
return %0 : tensor<8x8xi64>
}
+// CHECK-LABEL: test_addi_unranked_tensor
+func.func @test_addi_unranked_tensor(%arg0 : tensor<*xi32>, %arg1 : tensor<*xi32>) -> tensor<*xi32> {
+ %0 = arith.addi %arg0, %arg1 : tensor<*xi32>
+ return %0 : tensor<*xi32>
+}
+
// CHECK-LABEL: test_addi_vector
func.func @test_addi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
%0 = arith.addi %arg0, %arg1 : vector<8xi64>
``````````
</details>
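To make the effect of the `TypeOrContainer` change concrete, here is a small Python model of the constraint before and after the patch. It is a sketch under simplifying assumptions, not MLIR code: the `Type` dataclass, `SHAPED_KINDS`, and the helper functions are hypothetical stand-ins (not MLIR's Python bindings), signless integer element types are reduced to a fixed list, and only vector, tensor, and memref are modeled as shaped types. In this model, a memref of `i32` satisfies the new constraint but not the old one, while a vector of `f32` satisfies neither.

```python
from dataclasses import dataclass
from typing import Callable, Optional

# Hypothetical, simplified stand-in for an MLIR type. This is NOT the MLIR
# Python API; it only carries enough structure to model the ODS predicates.
@dataclass(frozen=True)
class Type:
    kind: str                         # e.g. "i32", "f32", "vector", "tensor", "memref"
    element: Optional["Type"] = None  # element type, set only for container kinds

Pred = Callable[[Type], bool]

# Assumption: in this model these are the only kinds implementing ShapedType.
SHAPED_KINDS = {"vector", "tensor", "memref"}

def is_signless_int_or_index(t: Type) -> bool:
    # Simplified element-type check (real MLIR allows any integer width).
    return t.kind in {"i1", "i8", "i16", "i32", "i64", "index"}

def container_of(kinds: set, elem_pred: Pred) -> Pred:
    # Models ShapedContainerType: the container check AND an element-type check.
    return lambda t: t.kind in kinds and t.element is not None and elem_pred(t.element)

def type_or_container_old(elem_pred: Pred) -> Pred:
    # Before the patch: T, a vector of T, or a tensor of T.
    vec = container_of({"vector"}, elem_pred)
    ten = container_of({"tensor"}, elem_pred)
    return lambda t: elem_pred(t) or vec(t) or ten(t)

def type_or_container_new(elem_pred: Pred) -> Pred:
    # After the patch: T, or any shaped type whose element type is T.
    shaped = container_of(SHAPED_KINDS, elem_pred)
    return lambda t: elem_pred(t) or shaped(t)

old = type_or_container_old(is_signless_int_or_index)
new = type_or_container_new(is_signless_int_or_index)

i32 = Type("i32")
tensor_i32 = Type("tensor", i32)
memref_i32 = Type("memref", i32)
vector_f32 = Type("vector", Type("f32"))

print(old(memref_i32), new(memref_i32))  # False True
```

The design point the model captures is that `ShapedTypeOf` tests the `::mlir::ShapedType` interface (via `IsShapedTypePred`) instead of enumerating concrete container types, so any type implementing that interface with an allowed element type now passes the constraint.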
https://github.com/llvm/llvm-project/pull/99028