[Mlir-commits] [mlir] [mlir] Allow all shaped types for arith ops. (PR #99028)

Alexander Belyaev llvmlistbot at llvm.org
Wed Jul 17 03:15:52 PDT 2024


https://github.com/pifon2a updated https://github.com/llvm/llvm-project/pull/99028

From 4b10aca51f8c39e85b0561e13839b6acf8237d06 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev <pifon at google.com>
Date: Tue, 16 Jul 2024 14:49:53 +0200
Subject: [PATCH] [mlir] Allow all shaped types for arith ops.

---
 mlir/include/mlir/IR/CommonTypeConstraints.td | 9 ++++++---
 mlir/test/Dialect/Arith/ops.mlir              | 6 ++++++
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index af4f13dc09360..ff736966c6212 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -636,6 +636,10 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>;
 
 // Shaped types.
 
+class ShapedTypeOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsShapedTypePred, "shaped",
+                      "::mlir::ShapedType">;
+
 def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
                                    "::mlir::ShapedType">;
 
@@ -842,10 +846,9 @@ class NestedTupleOf<list<Type> allowedTypes> :
 // Common type constraints
 //===----------------------------------------------------------------------===//
 // Type constraint for types that are "like" some type or set of types T, that is
-// they're either a T, a vector of Ts, or a tensor of Ts
+// they're either a T or a shaped type (e.g. a vector or tensor) whose elements are Ts.
 class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
-  allowedType.predicate, VectorOf<[allowedType]>.predicate,
-  TensorOf<[allowedType]>.predicate]>,
+  allowedType.predicate, ShapedTypeOf<[allowedType]>.predicate]>,
   name>;
 
 // Temporary constraint to allow gradual transition to supporting 0-D vectors.
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a51..0e786c211431a 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -13,6 +13,12 @@ func.func @test_addi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) ->
   return %0 : tensor<8x8xi64>
 }
 
+// CHECK-LABEL: test_addi_unranked_tensor
+func.func @test_addi_unranked_tensor(%arg0 : tensor<*xi32>, %arg1 : tensor<*xi32>) -> tensor<*xi32> {
+  %0 = arith.addi %arg0, %arg1 : tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
+
 // CHECK-LABEL: test_addi_vector
 func.func @test_addi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
   %0 = arith.addi %arg0, %arg1 : vector<8xi64>



More information about the Mlir-commits mailing list