[Mlir-commits] [mlir] [mlir][vector] Add convenience types for scalable vectors (PR #87986)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Apr 8 08:11:59 PDT 2024
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/87986
>From 14103cbf7dce52b46b9324b1d896a07a905a99d0 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 21 Mar 2024 19:01:12 +0000
Subject: [PATCH 1/2] [mlir][arith] Refine the verifier for arith.constant
Disallows initialization of scalable vectors with an attribute of
arbitrary values, e.g.:
```mlir
%c = arith.constant dense<[0, 1]> : vector<[2] x i32>
```
Initialization using vector splats remains allowed (i.e. when all the
init values are identical):
```mlir
%c = arith.constant dense<[1, 1]> : vector<[2] x i32>
```
Note: This is a re-upload of #86178
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 9 +++++++++
mlir/test/Dialect/Arith/invalid.mlir | 18 ++++++++++++++++++
mlir/test/Dialect/Vector/linearize.mlir | 11 -----------
3 files changed, 27 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 1d68a4f7292b53..6f995b93bc3ecd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -213,6 +213,15 @@ LogicalResult arith::ConstantOp::verify() {
return emitOpError(
"value must be an integer, float, or elements attribute");
}
+
+ // Note, we could relax this for vectors with 1 scalable dim, e.g.:
+ // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
+ // However, this would most likely require updating the lowerings to LLVM.
+ auto vecType = dyn_cast<VectorType>(type);
+ if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
+ return emitOpError(
+ "intializing scalable vectors with elements attribute is not supported"
+ " unless it's a vector splat");
return success();
}
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index 6d8ac0ada52be3..ada849220bb839 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -64,6 +64,24 @@ func.func @constant_out_of_range() {
// -----
+func.func @constant_invalid_scalable_1d_vec_initialization() {
+^bb0:
+ // expected-error at +1 {{'arith.constant' op intializing scalable vectors with elements attribute is not supported unless it's a vector splat}}
+ %c = arith.constant dense<[0, 1]> : vector<[2] x i32>
+ return
+}
+
+// -----
+
+func.func @constant_invalid_scalable_2d_vec_initialization() {
+^bb0:
+ // expected-error at +1 {{'arith.constant' op intializing scalable vectors with elements attribute is not supported unless it's a vector splat}}
+ %c = arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
+ return
+}
+
+// -----
+
func.func @constant_wrong_type() {
^bb:
%x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 212541c79565b6..22be78cd682057 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -153,14 +153,3 @@ func.func @test_0d_vector() -> vector<f32> {
// ALL: return %[[CST]]
return %0 : vector<f32>
}
-
-// -----
-
-func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
- // expected-error at +1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
- %0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32>
- %1 = math.sin %arg0 : vector<2x[2]xf32>
- %2 = arith.addf %0, %1 : vector<2x[2]xf32>
-
- return %2 : vector<2x[2]xf32>
-}
>From 796ebb38bb24be42921fcd30b913539b2f102b63 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 8 Apr 2024 11:32:13 +0100
Subject: [PATCH 2/2] [mlir][vector] Add convenience types for scalable vectors
This PR adds two small convenience Vector types:
* `ScalableVectorType` and `FixedWidthVectorType`.
The goal of these new types is two-fold:
* enable idiomatic checks like `isa<ScalableVectorType>(...)`,
* make the split into "Scalable" and "Fixed-width" vectors a bit more
explicit and more visible in the code-base.
Depends on #87999
---
mlir/include/mlir/IR/VectorTypes.h | 35 ++++++++++++++++++++++++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 4 +--
mlir/lib/IR/CMakeLists.txt | 1 +
mlir/lib/IR/VectorTypes.cpp | 27 ++++++++++++++++++++
4 files changed, 65 insertions(+), 2 deletions(-)
create mode 100644 mlir/include/mlir/IR/VectorTypes.h
create mode 100644 mlir/lib/IR/VectorTypes.cpp
diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h
new file mode 100644
index 00000000000000..384969779d8f6d
--- /dev/null
+++ b/mlir/include/mlir/IR/VectorTypes.h
@@ -0,0 +1,35 @@
+//===- VectorTypes.h - MLIR Vector Types ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
+#define MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace vector {
+
+class ScalableVectorType : public VectorType {
+public:
+ using VectorType::VectorType;
+
+ static bool classof(Type type);
+};
+
+class FixedWidthVectorType : public VectorType {
+public:
+ using VectorType::VectorType;
+ static bool classof(Type type);
+};
+
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 6f995b93bc3ecd..6f6c972887cdc5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/VectorTypes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -217,8 +218,7 @@ LogicalResult arith::ConstantOp::verify() {
// Note, we could relax this for vectors with 1 scalable dim, e.g.:
// * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
// However, this would most likely require updating the lowerings to LLVM.
- auto vecType = dyn_cast<VectorType>(type);
- if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
+ if (llvm::isa<vector::ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
return emitOpError(
"intializing scalable vectors with elements attribute is not supported"
" unless it's a vector splat");
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index c38ce6c058a006..03f89db18430a8 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_library(MLIRIR
Unit.cpp
Value.cpp
ValueRange.cpp
+ VectorTypes.cpp
Verifier.cpp
Visitors.cpp
${pdl_src}
diff --git a/mlir/lib/IR/VectorTypes.cpp b/mlir/lib/IR/VectorTypes.cpp
new file mode 100644
index 00000000000000..cd84691fac5ada
--- /dev/null
+++ b/mlir/lib/IR/VectorTypes.cpp
@@ -0,0 +1,27 @@
+//===- VectorTypes.cpp - MLIR Vector Types --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/VectorTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+bool ScalableVectorType::classof(Type type) {
+ auto vecTy = llvm::dyn_cast<VectorType>(type);
+ if (!vecTy)
+ return false;
+ return vecTy.isScalable();
+}
+
+bool FixedWidthVectorType::classof(Type type) {
+ auto vecTy = llvm::dyn_cast<VectorType>(type);
+ if (!vecTy)
+ return false;
+ return !vecTy.isScalable();
+}
More information about the Mlir-commits
mailing list