[Mlir-commits] [mlir] [mlir] Add unit test for RankedTensorType wrapper class example. (PR #99789)

Jacques Pienaar llvmlistbot at llvm.org
Mon Jul 22 12:50:01 PDT 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/99789

>From 1cac243300bebeb2c7c695b218ed57b033b89ce4 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Sat, 20 Jul 2024 22:39:19 +0000
Subject: [PATCH] [mlir] Add unit test for RankedTensorType wrapper class
 example.

Add example as unit test for creating a "RankedTensorType with encoding"
view. This view provides a more typed API to the encoding while it
allows one to avoid repeated dyn_cast queries and accessing the encoding
directly.

For users with more advanced encodings, the expectation would be a
separate attribute type, but here just StringAttr is used.
---
 mlir/unittests/IR/ShapedTypeTest.cpp | 58 ++++++++++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..7a5b0722a03ba 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SmallVector.h"
 #include "gtest/gtest.h"
 #include <cstdint>
@@ -226,4 +227,61 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
   }
 }
 
+/// Simple wrapper class over RankedTensorType to enable "isa querying" and
+/// simple accessing of a StringAttr encoding.
+class TensorWithString : public RankedTensorType {
+public:
+  using RankedTensorType::RankedTensorType;
+  /// Builds a ranked tensor type whose encoding is the StringAttr `name`.
+  static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
+                              StringRef name) {
+    return mlir::cast<TensorWithString>(RankedTensorType::get(
+        shape, elementType, StringAttr::get(elementType.getContext(), name)));
+  }
+  /// Returns the encoding's string; empty if no encoding is present.
+  StringRef getName() const {
+    if (Attribute enc = getEncoding())
+      return mlir::cast<StringAttr>(enc).getValue();
+    return {};
+  }
+  /// Matches non-null RankedTensorTypes whose encoding is a StringAttr.
+  static bool classof(Type type) {
+    if (auto rt = mlir::dyn_cast_if_present<RankedTensorType>(type))
+      return mlir::isa_and_present<StringAttr>(rt.getEncoding());
+    return false;
+  }
+};
+
+TEST(ShapedTypeTest, RankedTensorTypeView) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+  // Exercise isa/cast of TensorWithString on tensors with varied encodings.
+  Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);
+  // Ranked tensor whose encoding is present but is not a StringAttr.
+  UnitAttr unitAttr = UnitAttr::get(&context);
+  Type unitEncodingRankedTensorType =
+      RankedTensorType::get({10, 20}, f32, unitAttr);
+  // Ranked tensor whose encoding is a StringAttr.
+  StringAttr stringAttr = StringAttr::get(&context, "app");
+  Type stringEncodingRankedTensorType =
+      RankedTensorType::get({10, 20}, f32, stringAttr);
+  // Only the StringAttr-encoded tensor should match the view's classof.
+  EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));
+  EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));
+  ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));
+
+  // Cast to the TensorWithString view and read the encoding through it.
+  auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);
+  ASSERT_TRUE(mlir::isa<TensorWithString>(view));
+  EXPECT_EQ(view.getName(), "app");
+  // Verify that the view is still recognized as the base RankedTensorType.
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(view));
+  // A type created through TensorWithString::get matches both view and base.
+  Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");
+  ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
+  view = mlir::cast<TensorWithString>(viewCreated);
+  EXPECT_EQ(view.getName(), "bob");
+}
+
 } // namespace



More information about the Mlir-commits mailing list