[Mlir-commits] [mlir] [mlir][arith] Refine the verifier for arith.constant (PR #87999)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Apr 8 07:31:24 PDT 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/87999

Disallows initialization of scalable vectors with an elements attribute
holding arbitrary (non-splat) values, e.g.:
```mlir
  %c = arith.constant dense<[0, 1]> : vector<[2] x i32>
```

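Initialization using vector splats (i.e. when all the init values are
identical) remains allowed. A single repeated value is well defined
whatever the runtime length of the vector (a multiple of `vscale`):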
```mlir
  %c = arith.constant dense<[1, 1]> : vector<[2] x i32>
```
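To make the rule concrete, here is a minimal, self-contained C++ sketch of the check. It does not use MLIR: `VecDim`, `isSplat`, `isScalable` and `verifyConstant` are hypothetical stand-ins for `VectorType`, `SplatElementsAttr` and `arith::ConstantOp::verify()`, and the initializer is assumed to be a flat, row-major list of `int64_t` values.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one dimension of an MLIR vector type:
// `scalable` marks a "[N]" dimension, whose runtime length is N * vscale.
struct VecDim {
  int64_t size;
  bool scalable;
};

// True if all initializer values are identical (a splat). An empty list is
// trivially a splat.
bool isSplat(const std::vector<int64_t> &values) {
  for (const int64_t v : values)
    if (v != values.front())
      return false;
  return true;
}

// True if at least one dimension is scalable.
bool isScalable(const std::vector<VecDim> &dims) {
  for (const VecDim &d : dims)
    if (d.scalable)
      return true;
  return false;
}

// Models the new rule: a scalable vector may only be initialized with a
// splat; fixed-width vectors are unaffected. `values` is the initializer
// flattened in row-major order.
bool verifyConstant(const std::vector<VecDim> &dims,
                    const std::vector<int64_t> &values) {
  return !isScalable(dims) || isSplat(values);
}
```

With this model, the first example above corresponds to `verifyConstant({{2, true}}, {0, 1})`, which is rejected, while the splat `{1, 1}` is accepted. Fixed-width shapes such as `{{2, false}}` accept any initializer.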

Note: This is a re-upload of #86178


From 14103cbf7dce52b46b9324b1d896a07a905a99d0 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 21 Mar 2024 19:01:12 +0000
Subject: [PATCH] [mlir][arith] Refine the verifier for arith.constant

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp  |  9 +++++++++
 mlir/test/Dialect/Arith/invalid.mlir    | 18 ++++++++++++++++++
 mlir/test/Dialect/Vector/linearize.mlir | 11 -----------
 3 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 1d68a4f7292b53..6f995b93bc3ecd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -213,6 +213,15 @@ LogicalResult arith::ConstantOp::verify() {
     return emitOpError(
         "value must be an integer, float, or elements attribute");
   }
+
+  // Note, we could relax this for vectors with 1 scalable dim, e.g.:
+  //  * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
+  // However, this would most likely require updating the lowerings to LLVM.
+  auto vecType = dyn_cast<VectorType>(type);
+  if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
+    return emitOpError(
+        "initializing scalable vectors with elements attribute is not supported"
+        " unless it's a vector splat");
   return success();
 }
 
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 6d8ac0ada52be3..ada849220bb839 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -64,6 +64,24 @@ func.func @constant_out_of_range() {
 
 // -----
 
+func.func @constant_invalid_scalable_1d_vec_initialization() {
+^bb0:
+  // expected-error@+1 {{'arith.constant' op initializing scalable vectors with elements attribute is not supported unless it's a vector splat}}
+  %c = arith.constant dense<[0, 1]> : vector<[2] x i32>
+  return
+}
+
+// -----
+
+func.func @constant_invalid_scalable_2d_vec_initialization() {
+^bb0:
+  // expected-error@+1 {{'arith.constant' op initializing scalable vectors with elements attribute is not supported unless it's a vector splat}}
+  %c = arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
+  return
+}
+
+// -----
+
 func.func @constant_wrong_type() {
 ^bb:
   %x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 212541c79565b6..22be78cd682057 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -153,14 +153,3 @@ func.func @test_0d_vector() -> vector<f32> {
   // ALL: return %[[CST]]
   return %0 : vector<f32>
 }
-
-// -----
-
-func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
-  // expected-error@+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
-  %0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32>
-  %1 = math.sin %arg0 : vector<2x[2]xf32>
-  %2 = arith.addf %0, %1 : vector<2x[2]xf32>
-
-  return %2 : vector<2x[2]xf32>
-}
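The comment in the verifier suggests the check could be relaxed for vectors with a single scalable dimension, such as `dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>`, where every row is constant along the scalable dimension. This patch deliberately does not do that (the comment notes it would most likely require updating the lowerings to LLVM). The C++ sketch below only illustrates what such a relaxed rule might check. It is not MLIR code: `VecDim`, `isSplat` and `verifyConstantRelaxed` are hypothetical, and the initializer is assumed to be a flat, row-major list whose size matches the unscaled shape.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one dimension of an MLIR vector type.
struct VecDim {
  int64_t size;
  bool scalable;
};

// True if all initializer values are identical (a splat).
bool isSplat(const std::vector<int64_t> &values) {
  for (const int64_t v : values)
    if (v != values.front())
      return false;
  return true;
}

// Hypothetical relaxed rule: splats are always accepted; with exactly one
// scalable dimension, any initializer that is constant along that dimension
// is accepted too. Assumes `values` is row-major and its size equals the
// product of the (unscaled) dimension sizes.
bool verifyConstantRelaxed(const std::vector<VecDim> &dims,
                           const std::vector<int64_t> &values) {
  std::size_t scalableIdx = 0;
  int numScalable = 0;
  for (std::size_t i = 0; i < dims.size(); ++i) {
    if (dims[i].scalable) {
      scalableIdx = i;
      ++numScalable;
    }
  }
  if (numScalable == 0)
    return true; // fixed-width vector: any initializer is fine
  if (numScalable > 1)
    return isSplat(values); // fall back to the splat-only rule

  // Row-major stride and size of the single scalable dimension.
  int64_t stride = 1;
  for (std::size_t i = scalableIdx + 1; i < dims.size(); ++i)
    stride *= dims[i].size;
  const int64_t dimSize = dims[scalableIdx].size;

  const int64_t numValues = static_cast<int64_t>(values.size());
  for (int64_t flat = 0; flat < numValues; ++flat) {
    // Index of this element along the scalable dimension.
    const int64_t pos = (flat / stride) % dimSize;
    // Every element must equal the one at index 0 along that dimension.
    if (pos != 0 && values[static_cast<std::size_t>(flat)] !=
                        values[static_cast<std::size_t>(flat - pos * stride)])
      return false;
  }
  return true;
}
```

Because a splat is constant along every dimension, this rule accepts everything the splat-only check accepts, plus per-row initializers like `[[3, 3], [1, 1]]` for `vector<2 x [2] x i32>`.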