[Mlir-commits] [mlir] [mlir] Allow all shaped types for arith ops. (PR #99028)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 16 05:53:10 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir-ods

Author: Alexander Belyaev (pifon2a)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/99028.diff


2 Files Affected:

- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+5-2) 
- (modified) mlir/test/Dialect/Arith/ops.mlir (+6) 


``````````diff
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index af4f13dc09360..9414c61feb365 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -636,6 +636,10 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>;
 
 // Shaped types.
 
+class ShapedTypeOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsShapedTypePred, "shaped",
+                      "::mlir::ShapedType">;
+
 def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
                                    "::mlir::ShapedType">;
 
@@ -844,8 +848,7 @@ class NestedTupleOf<list<Type> allowedTypes> :
 // Type constraint for types that are "like" some type or set of types T, that is
 // they're either a T, a vector of Ts, or a tensor of Ts
 class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
-  allowedType.predicate, VectorOf<[allowedType]>.predicate,
-  TensorOf<[allowedType]>.predicate]>,
+  allowedType.predicate, ShapedTypeOf<[allowedType]>.predicate]>,
   name>;
 
 // Temporary constraint to allow gradual transition to supporting 0-D vectors.
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a51..0e786c211431a 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -13,6 +13,12 @@ func.func @test_addi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) ->
   return %0 : tensor<8x8xi64>
 }
 
+// CHECK-LABEL: test_addi_unranked_tensor
+func.func @test_addi_unranked_tensor(%arg0 : tensor<*xi32>, %arg1 : tensor<*xi32>) -> tensor<*xi32> {
+  %0 = arith.addi %arg0, %arg1 : tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
+
 // CHECK-LABEL: test_addi_vector
 func.func @test_addi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
   %0 = arith.addi %arg0, %arg1 : vector<8xi64>

``````````

</details>


https://github.com/llvm/llvm-project/pull/99028


More information about the Mlir-commits mailing list