[Mlir-commits] [mlir] [mlir] Add unit test for RankedTensorType wrapper class example. (PR #99789)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jul 20 16:19:28 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jacques Pienaar (jpienaar)

<details>
<summary>Changes</summary>

Add an example, as a unit test, of creating a wrapper type/view for a RankedTensorType with an encoding. This view provides a more restricted and typed API, and lets one avoid both repeated casting queries and accessing the encoding directly.

For users with more advanced encodings, the expectation would be a separate attribute type, but here just StringAttr is used.

---
Full diff: https://github.com/llvm/llvm-project/pull/99789.diff


1 Files Affected:

- (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+60) 


``````````diff
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..66ee416d7636e 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SmallVector.h"
 #include "gtest/gtest.h"
 #include <cstdint>
@@ -226,4 +227,63 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
   }
 }
 
+/// Simple wrapper class to enable "isa querying" and simple accessing of
+/// encoding.
+class TensorWithString : public RankedTensorType {
+public:
+  using RankedTensorType::RankedTensorType;
+
+  static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
+                              StringRef name) {
+    return mlir::cast<TensorWithString>(RankedTensorType::get(
+        shape, elementType, StringAttr::get(elementType.getContext(), name)));
+  }
+
+  StringRef getName() const {
+    if (Attribute enc = getEncoding())
+      return mlir::cast<StringAttr>(enc).getValue();
+    return {};
+  }
+
+  static bool classof(Type type);
+};
+
+bool TensorWithString::classof(Type type) {
+  if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))
+    return mlir::isa_and_present<StringAttr>(rt.getEncoding());
+  return false;
+}
+
+TEST(ShapedTypeTest, RankedTensorTypeView) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+
+  Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);
+
+  UnitAttr unitAttr = UnitAttr::get(&context);
+  Type unitEncodingRankedTensorType =
+      RankedTensorType::get({10, 20}, f32, unitAttr);
+
+  StringAttr stringAttr = StringAttr::get(&context, "app");
+  Type stringEncodingRankedTensorType =
+      RankedTensorType::get({10, 20}, f32, stringAttr);
+
+  EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));
+  EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));
+  ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));
+
+  // Cast to TensorWithString view.
+  auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);
+  ASSERT_TRUE(mlir::isa<TensorWithString>(view));
+  EXPECT_EQ(view.getName(), "app");
+  // Verify one could cast view type back to base type.
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(view));
+
+  Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");
+  ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
+  view = mlir::cast<TensorWithString>(viewCreated);
+  EXPECT_EQ(view.getName(), "bob");
+}
+
 } // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/99789


More information about the Mlir-commits mailing list