[Mlir-commits] [mlir] [mlir] Use new VectorType wrappers in CommonTypeConstraints.td (PR #118645)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 4 06:32:40 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

As a follow-up to #87986, moves the VectorType convenience wrappers
(`FixedVectorType` and `ScalableVectorType`) from VectorTypes.h to
BuiltinTypes.h, out of the `mlir::vector` namespace and into `mlir`. This
allows us to use the new wrappers in "CommonTypeConstraints.td".


---
Full diff: https://github.com/llvm/llvm-project/pull/118645.diff


4 Files Affected:

- (modified) mlir/include/mlir/IR/BuiltinTypes.h (+31) 
- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4-7) 
- (modified) mlir/include/mlir/IR/VectorTypes.h (-39) 
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+1-3) 


``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 25535408f4528a..f2bedb512c3dff 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -401,6 +401,37 @@ enum class SliceVerificationResult {
 SliceVerificationResult isRankReducedType(ShapedType originalType,
                                           ShapedType candidateReducedType);
 
+//===----------------------------------------------------------------------===//
+// Convenience wrappers for VectorType
+//
+// These are provided to allow idiomatic code like:
+//  * isa<ScalableVectorType>(type)
+//===----------------------------------------------------------------------===//
+/// A vector type containing at least one scalable dimension.
+class ScalableVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+
+  static bool classof(Type type) {
+    auto vecTy = llvm::dyn_cast<VectorType>(type);
+    if (!vecTy)
+      return false;
+    return vecTy.isScalable();
+  }
+};
+
+/// A vector type with no scalable dimensions.
+class FixedVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+  static bool classof(Type type) {
+    auto vecTy = llvm::dyn_cast<VectorType>(type);
+    if (!vecTy)
+      return false;
+    return !vecTy.isScalable();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 7db095d0ae5af6..b9f8c1ed19470d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td"
 // Explicitly disallow 0-D vectors for now until we have good enough coverage.
 def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
                                          CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
-def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
-                                              CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
-                                              CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
+def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
+                                              CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 // TODO: Remove this when all ops support 0-D vectors.
 def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
 
 // Whether a type is a fixed-length VectorType.
-def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
-                                  !::llvm::cast<VectorType>($_self).isScalable()}]>;
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>;
 
 // Whether a type is a scalable VectorType.
 def IsVectorTypeWithAnyDimScalablePred
-        : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
-                  ::llvm::cast<VectorType>($_self).isScalable()}]>;
+        : CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;
 
 // Whether a type is a scalable VectorType, with a single trailing scalable dimension.
 // Examples:
diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h
index c209f869a579d8..1f1d0f7a306698 100644
--- a/mlir/include/mlir/IR/VectorTypes.h
+++ b/mlir/include/mlir/IR/VectorTypes.h
@@ -10,42 +10,3 @@
 //  * isa<vector::ScalableVectorType>(type)
 //
 //===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_VECTORTYPES_H
-#define MLIR_IR_VECTORTYPES_H
-
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Types.h"
-
-namespace mlir {
-namespace vector {
-
-/// A vector type containing at least one scalable dimension.
-class ScalableVectorType : public VectorType {
-public:
-  using VectorType::VectorType;
-
-  static bool classof(Type type) {
-    auto vecTy = llvm::dyn_cast<VectorType>(type);
-    if (!vecTy)
-      return false;
-    return vecTy.isScalable();
-  }
-};
-
-/// A vector type with no scalable dimensions.
-class FixedVectorType : public VectorType {
-public:
-  using VectorType::VectorType;
-  static bool classof(Type type) {
-    auto vecTy = llvm::dyn_cast<VectorType>(type);
-    if (!vecTy)
-      return false;
-    return !vecTy.isScalable();
-  }
-};
-
-} // namespace vector
-} // namespace mlir
-
-#endif // MLIR_IR_VECTORTYPES_H
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fe7646140db7ea..5f445231b80fdf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -21,7 +21,6 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/VectorTypes.h"
 #include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/APFloat.h"
@@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() {
   // Note, we could relax this for vectors with 1 scalable dim, e.g.:
   //  * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
   // However, this would most likely require updating the lowerings to LLVM.
-  if (isa<vector::ScalableVectorType>(type) &&
-      !isa<SplatElementsAttr>(getValue()))
+  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
     return emitOpError(
         "intializing scalable vectors with elements attribute is not supported"
         " unless it's a vector splat");

``````````

</details>


https://github.com/llvm/llvm-project/pull/118645


More information about the Mlir-commits mailing list