[Mlir-commits] [mlir] [mlir] Allow all shaped types for arith ops. (PR #99028)
Alexander Belyaev
llvmlistbot at llvm.org
Wed Jul 17 03:15:52 PDT 2024
https://github.com/pifon2a updated https://github.com/llvm/llvm-project/pull/99028
>From 4b10aca51f8c39e85b0561e13839b6acf8237d06 Mon Sep 17 00:00:00 2001
From: Alexander Belyaev <pifon at google.com>
Date: Tue, 16 Jul 2024 14:49:53 +0200
Subject: [PATCH] [mlir] Allow all shaped types for arith ops.
---
mlir/include/mlir/IR/CommonTypeConstraints.td | 9 ++++++---
mlir/test/Dialect/Arith/ops.mlir | 6 ++++++
2 files changed, 12 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index af4f13dc09360..ff736966c6212 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -636,6 +636,10 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>;
// Shaped types.
+// A shaped type (anything satisfying IsShapedTypePred) whose element type is
+// one of `allowedTypes`. This is the parameterized form of AnyShaped below:
+// ShapedTypeOf<[AnyType]> checks the same predicate and uses the same
+// description ("shaped") and C++ class (::mlir::ShapedType).
+class ShapedTypeOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsShapedTypePred, "shaped",
+ "::mlir::ShapedType">;
+
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
"::mlir::ShapedType">;
@@ -842,10 +846,9 @@ class NestedTupleOf<list<Type> allowedTypes> :
// Common type constraints
//===----------------------------------------------------------------------===//
// Type constraint for types that are "like" some type or set of types T, that is
-// they're either a T, a vector of Ts, or a tensor of Ts
+// they're either a T or a shaped type (e.g. a vector or a tensor) of Ts.
class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
- allowedType.predicate, VectorOf<[allowedType]>.predicate,
- TensorOf<[allowedType]>.predicate]>,
+ // NOTE(review): this now accepts every ShapedType of allowedType, not only
+ // the vectors and tensors matched before; IsShapedTypePred presumably also
+ // admits memrefs. Because this constraint lives in CommonTypeConstraints.td,
+ // the change reaches every user of TypeOrContainer, not just arith ops --
+ // confirm that is intended.
+ allowedType.predicate, ShapedTypeOf<[allowedType]>.predicate]>,
name>;
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a51..0e786c211431a 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -13,6 +13,12 @@ func.func @test_addi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) ->
return %0 : tensor<8x8xi64>
}
+// Exercises arith.addi on unranked tensor operands (tensor<*xi32>), which the
+// operand type constraint now admits via ShapedTypeOf.
+// CHECK-LABEL: test_addi_unranked_tensor
+func.func @test_addi_unranked_tensor(%arg0 : tensor<*xi32>, %arg1 : tensor<*xi32>) -> tensor<*xi32> {
+ %0 = arith.addi %arg0, %arg1 : tensor<*xi32>
+ return %0 : tensor<*xi32>
+}
+
// CHECK-LABEL: test_addi_vector
func.func @test_addi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
%0 = arith.addi %arg0, %arg1 : vector<8xi64>
More information about the Mlir-commits
mailing list