[Mlir-commits] [mlir] [mlir] Use new VectorType wrappers in CommonTypeConstraints.td (PR #118645)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 4 06:32:40 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-arith
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
As a follow-up to #87986, this PR moves the VectorType convenience wrappers
(`FixedVectorType` and `ScalableVectorType`) from VectorTypes.h to BuiltinTypes.h
(and from the `mlir::vector` namespace into `mlir`). This allows the new wrappers
to be used in "CommonTypeConstraints.td".
---
Full diff: https://github.com/llvm/llvm-project/pull/118645.diff
4 Files Affected:
- (modified) mlir/include/mlir/IR/BuiltinTypes.h (+31)
- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4-7)
- (modified) mlir/include/mlir/IR/VectorTypes.h (-39)
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+1-3)
``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 25535408f4528a..f2bedb512c3dff 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -401,6 +401,37 @@ enum class SliceVerificationResult {
SliceVerificationResult isRankReducedType(ShapedType originalType,
ShapedType candidateReducedType);
+//===----------------------------------------------------------------------===//
+// Convenience wrappers for VectorType
+//
+// These are provided to allow idiomatic code like:
+// * isa<ScalableVectorType>(type)
+//===----------------------------------------------------------------------===//
+/// A vector type containing at least one scalable dimension.
+class ScalableVectorType : public VectorType {
+public:
+ using VectorType::VectorType;
+
+ static bool classof(Type type) {
+ auto vecTy = llvm::dyn_cast<VectorType>(type);
+ if (!vecTy)
+ return false;
+ return vecTy.isScalable();
+ }
+};
+
+/// A vector type with no scalable dimensions.
+class FixedVectorType : public VectorType {
+public:
+ using VectorType::VectorType;
+ static bool classof(Type type) {
+ auto vecTy = llvm::dyn_cast<VectorType>(type);
+ if (!vecTy)
+ return false;
+ return !vecTy.isScalable();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 7db095d0ae5af6..b9f8c1ed19470d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td"
// Explicitly disallow 0-D vectors for now until we have good enough coverage.
def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
-def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
- CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
- CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
+def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
+ CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
// Whether a type is a fixed-length VectorType.
-def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
- !::llvm::cast<VectorType>($_self).isScalable()}]>;
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>;
// Whether a type is a scalable VectorType.
def IsVectorTypeWithAnyDimScalablePred
- : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
- ::llvm::cast<VectorType>($_self).isScalable()}]>;
+ : CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;
// Whether a type is a scalable VectorType, with a single trailing scalable dimension.
// Examples:
diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h
index c209f869a579d8..1f1d0f7a306698 100644
--- a/mlir/include/mlir/IR/VectorTypes.h
+++ b/mlir/include/mlir/IR/VectorTypes.h
@@ -10,42 +10,3 @@
// * isa<vector::ScalableVectorType>(type)
//
//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_VECTORTYPES_H
-#define MLIR_IR_VECTORTYPES_H
-
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Types.h"
-
-namespace mlir {
-namespace vector {
-
-/// A vector type containing at least one scalable dimension.
-class ScalableVectorType : public VectorType {
-public:
- using VectorType::VectorType;
-
- static bool classof(Type type) {
- auto vecTy = llvm::dyn_cast<VectorType>(type);
- if (!vecTy)
- return false;
- return vecTy.isScalable();
- }
-};
-
-/// A vector type with no scalable dimensions.
-class FixedVectorType : public VectorType {
-public:
- using VectorType::VectorType;
- static bool classof(Type type) {
- auto vecTy = llvm::dyn_cast<VectorType>(type);
- if (!vecTy)
- return false;
- return !vecTy.isScalable();
- }
-};
-
-} // namespace vector
-} // namespace mlir
-
-#endif // MLIR_IR_VECTORTYPES_H
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fe7646140db7ea..5f445231b80fdf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -21,7 +21,6 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/VectorTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
@@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() {
// Note, we could relax this for vectors with 1 scalable dim, e.g.:
// * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
// However, this would most likely require updating the lowerings to LLVM.
- if (isa<vector::ScalableVectorType>(type) &&
- !isa<SplatElementsAttr>(getValue()))
+ if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
return emitOpError(
"intializing scalable vectors with elements attribute is not supported"
" unless it's a vector splat");
``````````
</details>
https://github.com/llvm/llvm-project/pull/118645
More information about the Mlir-commits
mailing list