[Mlir-commits] [mlir] [mlir] Use new VectorType wrappers CommonTypeConstraints.td (PR #118645)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Dec 4 08:15:24 PST 2024
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/118645
>From 3eb2e43e8c64c9971a28961d43661ddc514af75c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 4 Dec 2024 14:15:04 +0000
Subject: [PATCH 1/2] [mlir] Use new VectorType wrappers
CommonTypeConstraints.td
As a follow-up to #87986, moves the VectorType convenience wrappers
(`FixedVectorType` and `ScalableVectorType`) from the `mlir::vector`
namespace in VectorTypes.h to the `mlir` namespace in BuiltinTypes.h.
This allows us to use the new wrappers in "CommonTypeConstraints.td".
---
mlir/include/mlir/IR/BuiltinTypes.h | 31 +++++++++++++++
mlir/include/mlir/IR/CommonTypeConstraints.td | 11 ++----
mlir/include/mlir/IR/VectorTypes.h | 39 -------------------
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 4 +-
4 files changed, 36 insertions(+), 49 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 25535408f4528a..f2bedb512c3dff 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -401,6 +401,37 @@ enum class SliceVerificationResult {
SliceVerificationResult isRankReducedType(ShapedType originalType,
ShapedType candidateReducedType);
+//===----------------------------------------------------------------------===//
+// Convenience wrappers for VectorType
+//
+// These are provided to allow idiomatic code like:
+// * isa<ScalableVectorType>(type)
+//===----------------------------------------------------------------------===//
+/// A VectorType that has at least one scalable dimension.
+class ScalableVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+
+  /// True iff `type` is a VectorType for which `isScalable()` holds.
+  static bool classof(Type type) {
+    if (auto vecTy = llvm::dyn_cast<VectorType>(type))
+      return vecTy.isScalable();
+    return false;
+  }
+};
+
+/// A VectorType with no scalable dimensions, i.e. `isScalable()` is false.
+class FixedVectorType : public VectorType {
+public:
+ using VectorType::VectorType;
+ static bool classof(Type type) {
+ auto vecTy = llvm::dyn_cast<VectorType>(type);
+ if (!vecTy)
+ return false;
+ return !vecTy.isScalable();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 7db095d0ae5af6..b9f8c1ed19470d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td"
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
-def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
- CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
- CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
+def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
+                                              CPred<"::llvm::cast<::mlir::FixedVectorType>($_self).getRank() > 0">]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
// Whether a type is a fixed-length VectorType.
-def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
- !::llvm::cast<VectorType>($_self).isScalable()}]>;
+def IsFixedVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">;
// Whether a type is a scalable VectorType.
def IsVectorTypeWithAnyDimScalablePred
- : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
- ::llvm::cast<VectorType>($_self).isScalable()}]>;
+ : CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;
// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
// Examples:
diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h
index c209f869a579d8..1f1d0f7a306698 100644
--- a/mlir/include/mlir/IR/VectorTypes.h
+++ b/mlir/include/mlir/IR/VectorTypes.h
@@ -10,42 +10,3 @@
// * isa<vector::ScalableVectorType>(type)
//
//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_VECTORTYPES_H
-#define MLIR_IR_VECTORTYPES_H
-
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Types.h"
-
-namespace mlir {
-namespace vector {
-
-/// A vector type containing at least one scalable dimension.
-class ScalableVectorType : public VectorType {
-public:
- using VectorType::VectorType;
-
- static bool classof(Type type) {
- auto vecTy = llvm::dyn_cast<VectorType>(type);
- if (!vecTy)
- return false;
- return vecTy.isScalable();
- }
-};
-
-/// A vector type with no scalable dimensions.
-class FixedVectorType : public VectorType {
-public:
- using VectorType::VectorType;
- static bool classof(Type type) {
- auto vecTy = llvm::dyn_cast<VectorType>(type);
- if (!vecTy)
- return false;
- return !vecTy.isScalable();
- }
-};
-
-} // namespace vector
-} // namespace mlir
-
-#endif // MLIR_IR_VECTORTYPES_H
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fe7646140db7ea..5f445231b80fdf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -21,7 +21,6 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/VectorTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
@@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() {
// Note, we could relax this for vectors with 1 scalable dim, e.g.:
// * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
// However, this would most likely require updating the lowerings to LLVM.
- if (isa<vector::ScalableVectorType>(type) &&
- !isa<SplatElementsAttr>(getValue()))
+ if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
return emitOpError(
"intializing scalable vectors with elements attribute is not supported"
" unless it's a vector splat");
>From 448f30bd44230344c911d28dcadfb584f86072f8 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 4 Dec 2024 16:15:07 +0000
Subject: [PATCH 2/2] fixup! [mlir] Use new VectorType wrappers
CommonTypeConstraints.td
Add missing empty line
---
mlir/include/mlir/IR/BuiltinTypes.h | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f2bedb512c3dff..7f9c470ffec304 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -424,6 +424,7 @@ class ScalableVectorType : public VectorType {
class FixedVectorType : public VectorType {
public:
using VectorType::VectorType;
+
static bool classof(Type type) {
auto vecTy = llvm::dyn_cast<VectorType>(type);
if (!vecTy)
More information about the Mlir-commits
mailing list