[Mlir-commits] [mlir] [mlir] Use new VectorType wrappers in CommonTypeConstraints.td (PR #118645)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Dec 4 08:15:24 PST 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/118645

>From 3eb2e43e8c64c9971a28961d43661ddc514af75c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 4 Dec 2024 14:15:04 +0000
Subject: [PATCH 1/2] [mlir] Use new VectorType wrappers
 CommonTypeConstraints.td

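As a follow-on for #87986, moves the VectorType convenience wrappers
(`FixedVectorType` and `ScalableVectorType`) from VectorTypes.h
(namespace `mlir::vector`) to BuiltinTypes.h (namespace `mlir`). This
allows us to use the new wrappers in "CommonTypeConstraints.td". Users of
`vector::ScalableVectorType` / `vector::FixedVectorType` must drop the
`vector::` qualifier (see the ArithOps.cpp update below).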
---
 mlir/include/mlir/IR/BuiltinTypes.h           | 31 +++++++++++++++
 mlir/include/mlir/IR/CommonTypeConstraints.td | 11 ++----
 mlir/include/mlir/IR/VectorTypes.h            | 39 -------------------
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  4 +-
 4 files changed, 36 insertions(+), 49 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 25535408f4528a..f2bedb512c3dff 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -401,6 +401,37 @@ enum class SliceVerificationResult {
 SliceVerificationResult isRankReducedType(ShapedType originalType,
                                           ShapedType candidateReducedType);
 
+//===----------------------------------------------------------------------===//
+// Convenience wrappers for VectorType
+//
+// These enable LLVM-style isa/cast/dyn_cast checks such as:
+//  * isa<ScalableVectorType>(type) / isa<FixedVectorType>(type)
+//===----------------------------------------------------------------------===//
+/// A vector type containing at least one scalable dimension.
+class ScalableVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+
+  static bool classof(Type type) {
+    auto vecTy = llvm::dyn_cast<VectorType>(type);
+    if (!vecTy)
+      return false;
+    return vecTy.isScalable();
+  }
+};
+
+/// A vector type with no scalable dimensions.
+class FixedVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+  static bool classof(Type type) {
+    auto vecTy = llvm::dyn_cast<VectorType>(type);
+    if (!vecTy)
+      return false;
+    return !vecTy.isScalable();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Deferred Method Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 7db095d0ae5af6..b9f8c1ed19470d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -24,22 +24,19 @@ include "mlir/IR/DialectBase.td"
 // Explicitly disallow 0-D vectors for now until we have good enough coverage.
 def IsVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
                                          CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
-def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">,
-                                              CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">,
-                                              CPred<"!::llvm::cast<VectorType>($_self).isScalable()">]>;
+def IsFixedVectorOfNonZeroRankTypePred : And<[CPred<"::llvm::isa<::mlir::FixedVectorType>($_self)">,
+                                              CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>;
 
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 // TODO: Remove this when all ops support 0-D vectors.
 def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
 
 // Whether a type is a fixed-length VectorType.
-def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
-                                  !::llvm::cast<VectorType>($_self).isScalable()}]>;
+def IsFixedVectorOfAnyRankTypePred : CPred<[{::llvm::isa<::mlir::FixedVectorType>($_self)}]>;
 
 // Whether a type is a scalable VectorType.
 def IsVectorTypeWithAnyDimScalablePred
-        : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
-                  ::llvm::cast<VectorType>($_self).isScalable()}]>;
+        : CPred<[{::llvm::isa<::mlir::ScalableVectorType>($_self)}]>;
 
 // Whether a type is a scalable VectorType, with a single trailing scalable dimension.
 // Examples:
diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h
index c209f869a579d8..1f1d0f7a306698 100644
--- a/mlir/include/mlir/IR/VectorTypes.h
+++ b/mlir/include/mlir/IR/VectorTypes.h
@@ -10,42 +10,3 @@
 //  * isa<vector::ScalableVectorType>(type)
 //
 //===----------------------------------------------------------------------===//
-
-#ifndef MLIR_IR_VECTORTYPES_H
-#define MLIR_IR_VECTORTYPES_H
-
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Types.h"
-
-namespace mlir {
-namespace vector {
-
-/// A vector type containing at least one scalable dimension.
-class ScalableVectorType : public VectorType {
-public:
-  using VectorType::VectorType;
-
-  static bool classof(Type type) {
-    auto vecTy = llvm::dyn_cast<VectorType>(type);
-    if (!vecTy)
-      return false;
-    return vecTy.isScalable();
-  }
-};
-
-/// A vector type with no scalable dimensions.
-class FixedVectorType : public VectorType {
-public:
-  using VectorType::VectorType;
-  static bool classof(Type type) {
-    auto vecTy = llvm::dyn_cast<VectorType>(type);
-    if (!vecTy)
-      return false;
-    return !vecTy.isScalable();
-  }
-};
-
-} // namespace vector
-} // namespace mlir
-
-#endif // MLIR_IR_VECTORTYPES_H
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fe7646140db7ea..5f445231b80fdf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -21,7 +21,6 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/IR/VectorTypes.h"
 #include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/APFloat.h"
@@ -226,8 +225,7 @@ LogicalResult arith::ConstantOp::verify() {
   // Note, we could relax this for vectors with 1 scalable dim, e.g.:
   //  * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
   // However, this would most likely require updating the lowerings to LLVM.
-  if (isa<vector::ScalableVectorType>(type) &&
-      !isa<SplatElementsAttr>(getValue()))
+  if (isa<ScalableVectorType>(type) && !isa<SplatElementsAttr>(getValue()))
     return emitOpError(
         "intializing scalable vectors with elements attribute is not supported"
         " unless it's a vector splat");

>From 448f30bd44230344c911d28dcadfb584f86072f8 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 4 Dec 2024 16:15:07 +0000
Subject: [PATCH 2/2] fixup! [mlir] Use new VectorType wrappers
 CommonTypeConstraints.td

Add missing empty line
---
 mlir/include/mlir/IR/BuiltinTypes.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f2bedb512c3dff..7f9c470ffec304 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -424,6 +424,7 @@ class ScalableVectorType : public VectorType {
 class FixedVectorType : public VectorType {
 public:
   using VectorType::VectorType;
+
   static bool classof(Type type) {
     auto vecTy = llvm::dyn_cast<VectorType>(type);
     if (!vecTy)



More information about the Mlir-commits mailing list