[Mlir-commits] [mlir] [mlir][arith] Disallow casting tensor dimensions (PR #93349)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 24 14:17:40 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
Tighten the verifier for arith cast ops to disallow changing tensor dimensions, e.g., from static to dynamic or vice versa. After this change:
* `arith.cast_op %x : tensor<4xi32> to tensor<4xf32>` remains valid
* `arith.cast_op %x : tensor<4xi32> to tensor<?xf32>` becomes invalid
* `arith.cast_op %x : tensor<?xi32> to tensor<4xf32>` becomes invalid
This is mostly to simplify the op semantics. See the discussion thread for more context: https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/63.
---
Full diff: https://github.com/llvm/llvm-project/pull/93349.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+15-2)
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (-8)
- (modified) mlir/test/Dialect/Arith/invalid.mlir (+25-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4e4c6fd601777..bdf264aec1d5d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -83,12 +83,25 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
attr-dict `:` type($result) }];
}
+// Checks that ranked tensor inputs and outputs have identical shapes. This is
+// stricter than `SameOperandsAndResultShape`, which only requires tensor
+// dimensions to be 'compatible' (e.g., dynamic vs. static ones). Unranked
+// tensors have no shape to query and, like non-tensor types, map to `{}`.
+def SameInputOutputTensorDims : PredOpTrait<
+  "input and output have the same tensor dimensions",
+  AllMatchSameOperatorPred<["in", "out"],
+    "(::llvm::isa<::mlir::RankedTensorType>($_self.getType()) ?"
+    " ::llvm::cast<::mlir::RankedTensorType>($_self.getType()).getShape() :"
+    " ::llvm::ArrayRef<int64_t>{})">>;
+
// Base class for arithmetic cast operations. Requires a single operand and
-// result. If either is a shaped type, then the other must be of the same shape.
+// result. If either is a shaped type, then the other must be of the same
+// shape. In the case of tensor types, this also includes the corresponding
+// operand/result dimensions being equal.
class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
list<Trait> traits = []> :
Arith_Op<mnemonic, traits # [Pure, SameOperandsAndResultShape,
- DeclareOpInterfaceMethods<CastOpInterface>]>,
+ SameInputOutputTensorDims, DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins From:$in)>,
Results<(outs To:$out)> {
let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1a387c20c4b29..e4f95bb0545a2 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2950,14 +2950,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
return %ext : tensor<i16>
}
-// Just checks that this doesn't crash.
-// CHECK-LABEL: @signedExtendSplatAsDynamicShape
-func.func @signedExtendSplatAsDynamicShape() -> tensor<?xi64> {
- %splat = arith.constant dense<5> : tensor<2xi16>
- %extsplat = arith.extsi %splat : tensor<2xi16> to tensor<?xi64>
- return %extsplat : tensor<?xi64>
-}
-
// CHECK-LABEL: @extsi_i0
// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
// CHECK: return %[[ZERO]] : i16
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index ada849220bb83..a3cfb6baa2e1d 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -1,13 +1,21 @@
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
func.func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
- // expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
+ // expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
%0 = arith.index_cast %arg0 : tensor<index> to tensor<2xi64>
return %0 : tensor<2xi64>
}
// -----
+func.func @test_index_cast_shape_dim_error(%input : tensor<2xindex>) -> tensor<?xi64> {
+  // expected-error@+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
+  %cast = arith.index_cast %input : tensor<2xindex> to tensor<?xi64>
+  return %cast : tensor<?xi64>
+}
+
+// -----
+
func.func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
// expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
%0 = arith.index_cast %arg0 : tensor<index> to i64
@@ -655,6 +663,14 @@ func.func @extsi_scalable_to_fl(%arg0 : vector<[4]xi32>) {
// -----
+func.func @extsi_tensor_dim(%arg0 : tensor<4xi32>) {
+  // expected-error@+1 {{'arith.extsi' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.extsi %arg0 : tensor<4xi32> to tensor<?xi64>
+  return
+}
+
+// -----
+
func.func @extf_scalable_to_fl(%arg0 : vector<[4]xf32>) {
// expected-error@+1 {{'arith.extf' op requires the same shape for all operands and results}}
%0 = arith.extf %arg0 : vector<[4]xf32> to vector<4xf64>
@@ -703,6 +719,14 @@ func.func @bitcast_scalable_to_fl(%arg0 : vector<[4]xf32>) {
// -----
+func.func @bitcast_tensor_dim(%arg0 : tensor<4xf32>) {
+  // expected-error@+1 {{'arith.bitcast' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.bitcast %arg0 : tensor<4xf32> to tensor<?xi32>
+  return
+}
+
+// -----
+
func.func @trunci_fl_to_scalable(%arg0 : vector<4xi32>) {
// expected-error@+1 {{'arith.trunci' op requires the same shape for all operands and results}}
%0 = arith.trunci %arg0 : vector<4xi32> to vector<[4]xi8>
``````````
</details>
https://github.com/llvm/llvm-project/pull/93349
More information about the Mlir-commits
mailing list