[Mlir-commits] [mlir] 5bfe4b9 - [mlir][arith] Disallow casting tensor dimensions (#93349)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 28 17:04:44 PDT 2024
Author: Jakub Kuderski
Date: 2024-05-28T20:04:41-04:00
New Revision: 5bfe4b93e15ad38f211c5dec64be0eeaa4c8e914
URL: https://github.com/llvm/llvm-project/commit/5bfe4b93e15ad38f211c5dec64be0eeaa4c8e914
DIFF: https://github.com/llvm/llvm-project/commit/5bfe4b93e15ad38f211c5dec64be0eeaa4c8e914.diff
LOG: [mlir][arith] Disallow casting tensor dimensions (#93349)
Tighten the verifier for arith cast ops to disallow changing tensor
dimensions, e.g., static to dynamic. After this change:
* `arith.cast_op %x : tensor<4xi32> to tensor<4xf32>` remains valid
* `arith.cast_op %x : tensor<4xi32> to tensor<?xf32>` becomes invalid
* `arith.cast_op %x : tensor<?xi32> to tensor<4xf32>` becomes invalid
This is mostly to simplify the op semantics. See the discussion thread
for more context:
https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/63.
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/test/Dialect/Arith/canonicalize.mlir
mlir/test/Dialect/Arith/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 46248dad3be9e..81ed0f924a2e2 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -83,12 +83,25 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
attr-dict `:` type($result) }];
}
+// Checks that tensor inputs and outputs have identical shapes. This is stricter
+// than the verification done in `SameOperandsAndResultShape` that allows for
+// tensor dimensions to be 'compatible' (e.g., dynamic dimensions being
+// compatible with static ones).
+def SameInputOutputTensorDims : PredOpTrait<
+ "input and output have the same tensor dimensions",
+ AllMatchSameOperatorPred<["in", "out"],
+ "(::llvm::isa<::mlir::TensorType>($_self.getType()) ?"
+ " ::llvm::cast<::mlir::TensorType>($_self.getType()).getShape() :"
+ " ::llvm::ArrayRef<int64_t>{})">>;
+
// Base class for arithmetic cast operations. Requires a single operand and
-// result. If either is a shaped type, then the other must be of the same shape.
+// result. If either is a shaped type, then the other must be of the same
+// shape. In the case of tensor types, this also includes the corresponding
+// operand/result dimensions being equal.
class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
list<Trait> traits = []> :
Arith_Op<mnemonic, traits # [Pure, SameOperandsAndResultShape,
- DeclareOpInterfaceMethods<CastOpInterface>]>,
+ SameInputOutputTensorDims, DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins From:$in)>,
Results<(outs To:$out)> {
let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
@@ -1231,7 +1244,7 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
def Arith_TruncFOp :
Arith_Op<"truncf",
- [Pure, SameOperandsAndResultShape,
+ [Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in,
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1a387c20c4b29..e4f95bb0545a2 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2950,14 +2950,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
return %ext : tensor<i16>
}
-// Just checks that this doesn't crash.
-// CHECK-LABEL: @signedExtendSplatAsDynamicShape
-func.func @signedExtendSplatAsDynamicShape() -> tensor<?xi64> {
- %splat = arith.constant dense<5> : tensor<2xi16>
- %extsplat = arith.extsi %splat : tensor<2xi16> to tensor<?xi64>
- return %extsplat : tensor<?xi64>
-}
-
// CHECK-LABEL: @extsi_i0
// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
// CHECK: return %[[ZERO]] : i16
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index ada849220bb83..652aa738ad392 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -1,13 +1,21 @@
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
func.func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
- // expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
+ // expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
%0 = arith.index_cast %arg0 : tensor<index> to tensor<2xi64>
return %0 : tensor<2xi64>
}
// -----
+func.func @test_index_cast_shape_dim_error(%arg0 : tensor<2xindex>) -> tensor<?xi64> {
+ // expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
+ %0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<?xi64>
+ return %0 : tensor<?xi64>
+}
+
+// -----
+
func.func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
// expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
%0 = arith.index_cast %arg0 : tensor<index> to i64
@@ -655,6 +663,14 @@ func.func @extsi_scalable_to_fl(%arg0 : vector<[4]xi32>) {
// -----
+func.func @extsi_tensor_dim(%arg0 : tensor<4xi32>) {
+ // expected-error at +1 {{'arith.extsi' op failed to verify that input and output have the same tensor dimensions}}
+ %0 = arith.extsi %arg0 : tensor<4xi32> to tensor<?xi64>
+ return
+}
+
+// -----
+
func.func @extf_scalable_to_fl(%arg0 : vector<[4]xf32>) {
// expected-error at +1 {{'arith.extf' op requires the same shape for all operands and results}}
%0 = arith.extf %arg0 : vector<[4]xf32> to vector<4xf64>
@@ -703,6 +719,22 @@ func.func @bitcast_scalable_to_fl(%arg0 : vector<[4]xf32>) {
// -----
+func.func @bitcast_tensor_dim(%arg0 : tensor<4xf32>) {
+ // expected-error at +1 {{'arith.bitcast' op failed to verify that input and output have the same tensor dimensions}}
+ %0 = arith.bitcast %arg0 : tensor<4xf32> to tensor<?xi32>
+ return
+}
+
+// -----
+
+func.func @bitcast_tensor_dim(%arg0 : tensor<?xf32>) {
+ // expected-error at +1 {{'arith.bitcast' op failed to verify that input and output have the same tensor dimensions}}
+ %0 = arith.bitcast %arg0 : tensor<?xf32> to tensor<4xi32>
+ return
+}
+
+// -----
+
func.func @trunci_fl_to_scalable(%arg0 : vector<4xi32>) {
// expected-error at +1 {{'arith.trunci' op requires the same shape for all operands and results}}
%0 = arith.trunci %arg0 : vector<4xi32> to vector<[4]xi8>
@@ -719,6 +751,14 @@ func.func @truncf_fl_to_scalable(%arg0 : vector<4xf64>) {
// -----
+func.func @truncf_tensor_dim(%arg0 : tensor<4xf64>) {
+ // expected-error at +1 {{'arith.truncf' op failed to verify that input and output have the same tensor dimensions}}
+ %0 = arith.truncf %arg0 : tensor<4xf64> to tensor<?xf32>
+ return
+}
+
+// -----
+
func.func @extui_fl_to_scalable(%arg0 : vector<4xi32>) {
// expected-error at +1 {{'arith.extui' op requires the same shape for all operands and results}}
%0 = arith.extui %arg0 : vector<4xi32> to vector<[4]xi64>
More information about the Mlir-commits
mailing list