[Mlir-commits] [mlir] [mlir] Add unit test for RankedTensorType wrapper class example. (PR #99789)
Jacques Pienaar
llvmlistbot at llvm.org
Sat Jul 20 16:19:00 PDT 2024
https://github.com/jpienaar created https://github.com/llvm/llvm-project/pull/99789
Add example as unit test for creating a wrapper type/view for RankedTensorType with encoding. This view provides a more restricted & typed API while also allowing one to avoid repeated casting queries and to access the encoding directly.
For users with more advanced encodings, the expectation would be a separate attribute type, but here just StringAttr is used.
>From a485a7d05da7a82360b6a14abf3444a60c46f18f Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Sat, 20 Jul 2024 22:39:19 +0000
Subject: [PATCH] [mlir] Add unit test for RankedTensorType wrapper class
example.
Add example as unit test for creating a "RankedTensorType with encoding"
view. This view provides a more typed API to the encoding while also
allowing one to avoid repeated dyn_cast queries and to access the
encoding directly.
For users with more advanced encodings, the expectation would be a
separate attribute type, but here just StringAttr is used.
---
mlir/unittests/IR/ShapedTypeTest.cpp | 60 ++++++++++++++++++++++++++++
1 file changed, 60 insertions(+)
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..66ee416d7636e 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "gtest/gtest.h"
#include <cstdint>
@@ -226,4 +227,63 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
}
}
+/// Typed view over a RankedTensorType whose encoding is a StringAttr.
+/// `isa`/`cast`/`dyn_cast` to this class succeed only for ranked tensor
+/// types that carry a StringAttr encoding. `getName()` then exposes that
+/// string without any casting at the call site.
+class TensorWithString : public RankedTensorType {
+public:
+  using RankedTensorType::RankedTensorType;
+
+  /// Builds a ranked tensor type with the given shape and element type. Its
+  /// encoding is a StringAttr holding `name`, and it is returned as this view.
+  static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
+                              StringRef name) {
+    auto encoding = StringAttr::get(elementType.getContext(), name);
+    auto tensorType = RankedTensorType::get(shape, elementType, encoding);
+    return mlir::cast<TensorWithString>(tensorType);
+  }
+
+  /// Returns the string stored in the encoding. Returns an empty StringRef
+  /// when the underlying tensor type has no encoding at all.
+  StringRef getName() const {
+    Attribute encoding = getEncoding();
+    if (!encoding)
+      return {};
+    return mlir::cast<StringAttr>(encoding).getValue();
+  }
+
+  /// Accepts only ranked tensor types whose encoding is present and is a
+  /// StringAttr. Null types and every other type are rejected.
+  static bool classof(Type type) {
+    auto ranked = mlir::dyn_cast_or_null<RankedTensorType>(type);
+    return ranked && mlir::isa_and_present<StringAttr>(ranked.getEncoding());
+  }
+};
+
+/// Exercises the TensorWithString view in four ways:
+///  - isa-queries against tensors with no encoding, a non-string encoding
+///    and a string encoding;
+///  - casting an existing type into the view;
+///  - checking the view still answers as a RankedTensorType;
+///  - building a view directly through `get`.
+TEST(ShapedTypeTest, RankedTensorTypeView) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+  // A named array (not an ArrayRef over a braced list) so the storage
+  // outlives every call that borrows it.
+  const int64_t shape[] = {10, 20};
+
+  // Three tensor types that differ only in their encoding attribute.
+  Type plainTensor = RankedTensorType::get(shape, f32);
+  Type unitTensor =
+      RankedTensorType::get(shape, f32, UnitAttr::get(&context));
+  Type stringTensor =
+      RankedTensorType::get(shape, f32, StringAttr::get(&context, "app"));
+
+  // Only the StringAttr-encoded tensor should be recognized by the view.
+  EXPECT_FALSE(mlir::isa<TensorWithString>(plainTensor));
+  EXPECT_FALSE(mlir::isa<TensorWithString>(unitTensor));
+  ASSERT_TRUE(mlir::isa<TensorWithString>(stringTensor));
+
+  // Casting an existing type into the view exposes its encoding string.
+  // The view must also still be usable as its RankedTensorType base.
+  TensorWithString view = mlir::cast<TensorWithString>(stringTensor);
+  ASSERT_TRUE(mlir::isa<TensorWithString>(view));
+  EXPECT_EQ(view.getName(), "app");
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(view));
+
+  // A type created through the view's own factory round-trips the same way.
+  Type built = TensorWithString::get(shape, f32, "bob");
+  ASSERT_TRUE(mlir::isa<TensorWithString>(built));
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(built));
+  view = mlir::cast<TensorWithString>(built);
+  EXPECT_EQ(view.getName(), "bob");
+}
+
} // namespace
More information about the Mlir-commits
mailing list