[Mlir-commits] [mlir] [mlir][arith] Disallow casting tensor dimensions (PR #93349)

Jakub Kuderski llvmlistbot at llvm.org
Fri May 24 14:17:11 PDT 2024


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/93349

Tighten the verifier for arith cast ops to disallow changing tensor dimensions, e.g., static to dynamic. After this change (with `arith.cast_op` standing for any arith cast op, such as `arith.extsi` or `arith.bitcast`):
* `arith.cast_op %x : tensor<4xi32> to tensor<4xf32>` remains valid
* `arith.cast_op %x : tensor<4xi32> to tensor<?xf32>` becomes invalid
* `arith.cast_op %x : tensor<?xi32> to tensor<4xf32>` becomes invalid

This is mostly to simplify the op semantics. See the discussion thread for more context: https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/63.

>From b5248f87062ba2612ff94c3953ccb6d6f12972f1 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 24 May 2024 17:09:43 -0400
Subject: [PATCH] [mlir][arith] Disallow casting tensor dimensions

Tighten the verifier for arith cast ops to disallow changing tensor
dimensions, e.g., static to dynamic. After this change (with
`arith.cast_op` standing for any arith cast op):
* `arith.cast_op %x : tensor<4xi32> to tensor<4xf32>` remains valid
* `arith.cast_op %x : tensor<4xi32> to tensor<?xf32>` becomes invalid
* `arith.cast_op %x : tensor<?xi32> to tensor<4xf32>` becomes invalid

This is mostly to simplify the op semantics. See the discussion thread
for more context: https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/63.
---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 17 ++++++++++--
 mlir/test/Dialect/Arith/canonicalize.mlir     |  8 ------
 mlir/test/Dialect/Arith/invalid.mlir          | 26 ++++++++++++++++++-
 3 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4e4c6fd601777..bdf264aec1d5d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -83,12 +83,25 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
                           attr-dict `:` type($result) }];
 }
 
+// Checks that tensor input and output have identical shapes. This is stricter
+// than `SameOperandsAndResultShape`, which accepts 'compatible' dimensions
+// (e.g., dynamic vs. static). Only ranked tensors have a shape to compare;
+// unranked tensors and non-tensor types are treated as having an empty shape.
+def SameInputOutputTensorDims : PredOpTrait<
+    "input and output have the same tensor dimensions",
+    AllMatchSameOperatorPred<["in", "out"],
+      "(::llvm::isa<::mlir::RankedTensorType>($_self.getType()) ?"
+      " ::llvm::cast<::mlir::RankedTensorType>($_self.getType()).getShape()"
+      " : ::llvm::ArrayRef<int64_t>{})">>;
+
 // Base class for arithmetic cast operations. Requires a single operand and
-// result. If either is a shaped type, then the other must be of the same shape.
+// result. If either is a shaped type, then the other must be of the same
+// shape. For ranked tensor types, the operand and result dimensions must also
+// be identical (static vs. dynamic mismatches are rejected).
 class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
                    list<Trait> traits = []> :
     Arith_Op<mnemonic, traits # [Pure, SameOperandsAndResultShape,
-      DeclareOpInterfaceMethods<CastOpInterface>]>,
+      SameInputOutputTensorDims, DeclareOpInterfaceMethods<CastOpInterface>]>,
     Arguments<(ins From:$in)>,
     Results<(outs To:$out)> {
   let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1a387c20c4b29..e4f95bb0545a2 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2950,14 +2950,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
   return %ext : tensor<i16>
 }
 
-// Just checks that this doesn't crash.
-// CHECK-LABEL: @signedExtendSplatAsDynamicShape
-func.func @signedExtendSplatAsDynamicShape() -> tensor<?xi64> {
-  %splat = arith.constant dense<5> : tensor<2xi16>
-  %extsplat = arith.extsi %splat : tensor<2xi16> to tensor<?xi64>
-  return %extsplat : tensor<?xi64>
-}
-
 // CHECK-LABEL: @extsi_i0
 //       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i16
 //       CHECK:   return %[[ZERO]] : i16
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index ada849220bb83..a3cfb6baa2e1d 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -1,13 +1,21 @@
 // RUN: mlir-opt -split-input-file %s -verify-diagnostics
 
 func.func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
-  // expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
+  // expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
   %0 = arith.index_cast %arg0 : tensor<index> to tensor<2xi64>
   return %0 : tensor<2xi64>
 }
 
 // -----
 
+func.func @test_index_cast_shape_dim_error(%arg0 : tensor<2xindex>) -> tensor<?xi64> {
+  // expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<?xi64>
+  return %0 : tensor<?xi64>
+}
+
+// -----
+
 func.func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
   // expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
   %0 = arith.index_cast %arg0 : tensor<index> to i64
@@ -655,6 +663,14 @@ func.func @extsi_scalable_to_fl(%arg0 : vector<[4]xi32>) {
 
 // -----
 
+func.func @extsi_tensor_dim(%arg0 : tensor<4xi32>) {
+  // expected-error@+1 {{'arith.extsi' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.extsi %arg0 : tensor<4xi32> to tensor<?xi64>
+  return
+}
+
+// -----
+
 func.func @extf_scalable_to_fl(%arg0 : vector<[4]xf32>) {
   // expected-error@+1 {{'arith.extf' op requires the same shape for all operands and results}}
   %0 = arith.extf %arg0 : vector<[4]xf32> to vector<4xf64>
@@ -703,6 +719,14 @@ func.func @bitcast_scalable_to_fl(%arg0 : vector<[4]xf32>) {
 
 // -----
 
+func.func @bitcast_tensor_dim(%arg0 : tensor<4xf32>) {
+  // expected-error@+1 {{'arith.bitcast' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.bitcast %arg0 : tensor<4xf32> to tensor<?xi32>
+  return
+}
+
+// -----
+
 func.func @trunci_fl_to_scalable(%arg0 : vector<4xi32>) {
   // expected-error@+1 {{'arith.trunci' op requires the same shape for all operands and results}}
   %0 = arith.trunci %arg0 : vector<4xi32> to vector<[4]xi8>



More information about the Mlir-commits mailing list