[Mlir-commits] [mlir] [mlir][vector] Add convenience types for scalable vectors (PR #87986)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Apr 16 08:52:21 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/87986

>From ecdb3ecd720d1a8f3e5952d14cd58fe4e7e3959b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 8 Apr 2024 11:32:13 +0100
Subject: [PATCH 1/2] [mlir][vector] Add convenience types for scalable vectors

This PR adds two small convenience Vector types:

  * `ScalableVectorType` and `FixedWidthVectorType`.

The goal of these new types is two-fold:
  * enable idiomatic checks like `isa<ScalableVectorType>(...)`,
  * make the split into "Scalable" and "Fixed-width" vectors a bit more
    explicit and more visible in the code-base.

Depends on #87999
---
 mlir/include/mlir/IR/VectorTypes.h     | 35 ++++++++++++++++++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp |  5 ++--
 mlir/lib/IR/CMakeLists.txt             |  1 +
 mlir/lib/IR/VectorTypes.cpp            | 27 ++++++++++++++++++++
 4 files changed, 66 insertions(+), 2 deletions(-)
 create mode 100644 mlir/include/mlir/IR/VectorTypes.h
 create mode 100644 mlir/lib/IR/VectorTypes.cpp

diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h
new file mode 100644
index 00000000000000..384969779d8f6d
--- /dev/null
+++ b/mlir/include/mlir/IR/VectorTypes.h
@@ -0,0 +1,35 @@
+//===- VectorTypes.h - MLIR Vector Types ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
+#define MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace vector {
+
+class ScalableVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+
+  static bool classof(Type type);
+};
+
+class FixedWidthVectorType : public VectorType {
+public:
+  using VectorType::VectorType;
+  static bool classof(Type type);
+};
+
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_IR_VECTORTYPES_H_
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 6f995b93bc3ecd..d017e44d516421 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/VectorTypes.h"
 #include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/APFloat.h"
@@ -217,8 +218,8 @@ LogicalResult arith::ConstantOp::verify() {
   // Note, we could relax this for vectors with 1 scalable dim, e.g.:
   //  * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32>
   // However, this would most likely require updating the lowerings to LLVM.
-  auto vecType = dyn_cast<VectorType>(type);
-  if (vecType && vecType.isScalable() && !isa<SplatElementsAttr>(getValue()))
+  if (isa<vector::ScalableVectorType>(type) &&
+      !isa<SplatElementsAttr>(getValue()))
     return emitOpError(
         "intializing scalable vectors with elements attribute is not supported"
         " unless it's a vector splat");
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index c38ce6c058a006..03f89db18430a8 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_library(MLIRIR
   Unit.cpp
   Value.cpp
   ValueRange.cpp
+  VectorTypes.cpp
   Verifier.cpp
   Visitors.cpp
   ${pdl_src}
diff --git a/mlir/lib/IR/VectorTypes.cpp b/mlir/lib/IR/VectorTypes.cpp
new file mode 100644
index 00000000000000..4ce1d606a3f2f4
--- /dev/null
+++ b/mlir/lib/IR/VectorTypes.cpp
@@ -0,0 +1,27 @@
+//===- VectorTypes.cpp - MLIR Vector Types --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/VectorTypes.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+bool ScalableVectorType::classof(Type type) {
+  auto vecTy = dyn_cast<VectorType>(type);
+  if (!vecTy)
+    return false;
+  return vecTy.isScalable();
+}
+
+bool FixedWidthVectorType::classof(Type type) {
+  auto vecTy = llvm::dyn_cast<VectorType>(type);
+  if (!vecTy)
+    return false;
+  return !vecTy.isScalable();
+}

>From cc17e63de4e7f41319c785bdee401e15d9c35cf8 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 16 Apr 2024 16:32:15 +0100
Subject: [PATCH 2/2] fixup! [mlir][vector] Add convenience types for scalable
 vectors

Inline implementation, add comment.
---
 mlir/include/mlir/IR/VectorTypes.h | 16 ++++++++++++++--
 mlir/lib/IR/CMakeLists.txt         |  1 -
 mlir/lib/IR/VectorTypes.cpp        | 27 ---------------------------
 3 files changed, 14 insertions(+), 30 deletions(-)
 delete mode 100644 mlir/lib/IR/VectorTypes.cpp

diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h
index 384969779d8f6d..769daa8f9417d1 100644
--- a/mlir/include/mlir/IR/VectorTypes.h
+++ b/mlir/include/mlir/IR/VectorTypes.h
@@ -16,17 +16,29 @@
 namespace mlir {
 namespace vector {
 
+/// A vector type containing at least one scalable dimension
 class ScalableVectorType : public VectorType {
 public:
   using VectorType::VectorType;
 
-  static bool classof(Type type);
+  static bool classof(Type type) {
+    auto vecTy = llvm::dyn_cast<VectorType>(type);
+    if (!vecTy)
+      return false;
+    return vecTy.isScalable();
+  }
 };
 
+/// A vector type with no scalable dimensions
 class FixedWidthVectorType : public VectorType {
 public:
   using VectorType::VectorType;
-  static bool classof(Type type);
+  static bool classof(Type type) {
+    auto vecTy = llvm::dyn_cast<VectorType>(type);
+    if (!vecTy)
+      return false;
+    return !vecTy.isScalable();
+  }
 };
 
 } // namespace vector
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 03f89db18430a8..c38ce6c058a006 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -40,7 +40,6 @@ add_mlir_library(MLIRIR
   Unit.cpp
   Value.cpp
   ValueRange.cpp
-  VectorTypes.cpp
   Verifier.cpp
   Visitors.cpp
   ${pdl_src}
diff --git a/mlir/lib/IR/VectorTypes.cpp b/mlir/lib/IR/VectorTypes.cpp
deleted file mode 100644
index 4ce1d606a3f2f4..00000000000000
--- a/mlir/lib/IR/VectorTypes.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-//===- VectorTypes.cpp - MLIR Vector Types --------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/IR/VectorTypes.h"
-#include "mlir/IR/BuiltinTypes.h"
-
-using namespace mlir;
-using namespace mlir::vector;
-
-bool ScalableVectorType::classof(Type type) {
-  auto vecTy = dyn_cast<VectorType>(type);
-  if (!vecTy)
-    return false;
-  return vecTy.isScalable();
-}
-
-bool FixedWidthVectorType::classof(Type type) {
-  auto vecTy = llvm::dyn_cast<VectorType>(type);
-  if (!vecTy)
-    return false;
-  return !vecTy.isScalable();
-}



More information about the Mlir-commits mailing list