[Mlir-commits] [mlir] [mlir][vector] Add convenience types for scalable vectors (PR #87986)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 8 06:01:46 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-arith

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

This PR adds two small convenience Vector types:

  * `ScalableVectorType` and `FixedWidthVectorType`.

The goal of these new types is two-fold:
  * enable idiomatic checks like `isa<ScalableVectorType>(...)`,
  * make the split into "Scalable" and "Fixed-width" vectors a bit more
    explicit and more visible in the code-base.


---
Full diff: https://github.com/llvm/llvm-project/pull/87986.diff


4 Files Affected:

- (added) mlir/include/mlir/Dialect/Vector/IR/VectorTypes.h (+35) 
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+3-2) 
- (modified) mlir/lib/Dialect/Vector/IR/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Vector/IR/VectorTypes.cpp (+27) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorTypes.h b/mlir/include/mlir/Dialect/Vector/IR/VectorTypes.h
new file mode 100644
index 00000000000000..384969779d8f6d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorTypes.h
@@ -0,0 +1,35 @@
+//===- VectorTypes.h - MLIR Vector Types ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
+#define MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir::vector {
+
+/// Convenience view of a `VectorType` for which `isScalable()` is true.
+///
+/// Intended for LLVM-style RTTI, e.g. `isa<vector::ScalableVectorType>(t)`.
+///
+/// NOTE(review): `classof` is defined out-of-line in VectorTypes.cpp, which is
+/// compiled into MLIRVectorDialect, yet this header is also used from
+/// ArithOps.cpp. Presumably that forces a link-time dependency on
+/// MLIRVectorDialect -- consider defining `classof` inline in this header.
+class ScalableVectorType : public VectorType {
+public:
+  // Reuse every constructor of the underlying VectorType.
+  using VectorType::VectorType;
+
+  /// Returns true iff `type` is a scalable VectorType.
+  static bool classof(Type type);
+};
+
+/// Convenience view of a `VectorType` for which `isScalable()` is false.
+///
+/// Intended for LLVM-style RTTI, e.g. `isa<vector::FixedWidthVectorType>(t)`.
+class FixedWidthVectorType : public VectorType {
+public:
+  // Reuse every constructor of the underlying VectorType.
+  using VectorType::VectorType;
+
+  /// Returns true iff `type` is a VectorType that is not scalable.
+  static bool classof(Type type);
+};
+
+} // namespace mlir::vector
+
+#endif // MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index efc4bfe622d53a..b66337cb07bacf 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -214,8 +215,8 @@ LogicalResult arith::ConstantOp::verify() {
         "value must be an integer, float, or elements attribute");
   }
 
-  auto vecType = dyn_cast<VectorType>(type);
-  if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
+  if (isa<vector::ScalableVectorType>(type) &&
+      !isa<SplatElementsAttr>(getValue()))
     return emitOpError(
         "intializing scalable vectors with elements attribute is not supported"
         " unless it's a vector splat");
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 204462ffd047c6..6638feae1e140f 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVectorDialect
   VectorOps.cpp
+  VectorTypes.cpp
   ValueBoundsOpInterfaceImpl.cpp
   ScalableValueBoundsConstraintSet.cpp
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorTypes.cpp b/mlir/lib/Dialect/Vector/IR/VectorTypes.cpp
new file mode 100644
index 00000000000000..439040e73938d8
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/IR/VectorTypes.cpp
@@ -0,0 +1,27 @@
+//===- VectorTypes.cpp - MLIR Vector Types --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// LLVM-style RTTI hook: accepts `type` only when it is a VectorType whose
+/// `isScalable()` returns true; every other type (including non-vector types)
+/// is rejected.
+bool ScalableVectorType::classof(Type type) {
+  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
+    return vectorType.isScalable();
+  return false;
+}
+
+/// LLVM-style RTTI hook: accepts `type` only when it is a VectorType whose
+/// `isScalable()` returns false; non-vector types are rejected.
+bool FixedWidthVectorType::classof(Type type) {
+  auto vectorType = llvm::dyn_cast<VectorType>(type);
+  return vectorType && !vectorType.isScalable();
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/87986


More information about the Mlir-commits mailing list