[Mlir-commits] [mlir] d2f42c7 - [mlir] Add unit test for RankedTensorType wrapper example. (#99789)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 22 13:16:26 PDT 2024
Author: Jacques Pienaar
Date: 2024-07-22T13:16:22-07:00
New Revision: d2f42c737234662d1bdc2f2e393e1b59536f8a93
URL: https://github.com/llvm/llvm-project/commit/d2f42c737234662d1bdc2f2e393e1b59536f8a93
DIFF: https://github.com/llvm/llvm-project/commit/d2f42c737234662d1bdc2f2e393e1b59536f8a93.diff
LOG: [mlir] Add unit test for RankedTensorType wrapper example. (#99789)
Add example as unit test for creating a wrapper type/view for
RankedTensorType with encoding. This view provides a more restricted &
typed API while it allows one to avoid repeated casting queries and
accessing the encoding directly.
For users with more advanced encodings, the expectation would be a
separate attribute type, but here just StringAttr is used.
Added:
Modified:
mlir/unittests/IR/ShapedTypeTest.cpp
Removed:
################################################################################
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..7a5b0722a03ba 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "gtest/gtest.h"
#include <cstdint>
@@ -226,4 +227,61 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
}
}
+/// Wrapper ("view") over a RankedTensorType whose encoding is a StringAttr.
+///
+/// The class adds no data members; it only narrows the API of
+/// RankedTensorType. Supplying `classof` lets `mlir::isa` / `mlir::cast` /
+/// `mlir::dyn_cast` recognise exactly those ranked tensors that carry a
+/// StringAttr encoding, and `getName` exposes that encoding without the
+/// caller having to repeat the attribute cast.
+class TensorWithString : public RankedTensorType {
+public:
+  using RankedTensorType::RankedTensorType;
+
+  /// Creates a ranked tensor of the given `shape` and `elementType` whose
+  /// encoding is a StringAttr holding `name`. The attribute is created in the
+  /// context that owns `elementType`.
+  static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
+                              StringRef name) {
+    return mlir::cast<TensorWithString>(RankedTensorType::get(
+        shape, elementType, StringAttr::get(elementType.getContext(), name)));
+  }
+
+  /// Returns the string stored in the encoding attribute, or an empty
+  /// StringRef when the tensor has no encoding.
+  StringRef getName() const {
+    if (Attribute enc = getEncoding())
+      return mlir::cast<StringAttr>(enc).getValue();
+    return {};
+  }
+
+  /// Casting hook: true only when `type` is a non-null RankedTensorType whose
+  /// encoding is present and is a StringAttr.
+  static bool classof(Type type) {
+    // Use the `*_if_present` spelling (the `*_or_null` names are legacy
+    // aliases for it), matching the `isa_and_present` call below.
+    if (auto rt = mlir::dyn_cast_if_present<RankedTensorType>(type))
+      return mlir::isa_and_present<StringAttr>(rt.getEncoding());
+    return false;
+  }
+};
+
+// Exercises TensorWithString as a typed view over RankedTensorType: only
+// tensors whose encoding is a StringAttr may be classified as / cast to it.
+TEST(ShapedTypeTest, RankedTensorTypeView) {
+  MLIRContext ctx;
+  Type elemTy = FloatType::getF32(&ctx);
+
+  // Three tensors with the same shape and element type, differing only in
+  // their encoding: none, a UnitAttr, and a StringAttr.
+  Type plainTensor = RankedTensorType::get({10, 20}, elemTy);
+  Type unitTensor =
+      RankedTensorType::get({10, 20}, elemTy, UnitAttr::get(&ctx));
+  Type namedTensor =
+      RankedTensorType::get({10, 20}, elemTy, StringAttr::get(&ctx, "app"));
+
+  // Only the string-encoded tensor qualifies as a TensorWithString.
+  EXPECT_FALSE(mlir::isa<TensorWithString>(plainTensor));
+  EXPECT_FALSE(mlir::isa<TensorWithString>(unitTensor));
+  ASSERT_TRUE(mlir::isa<TensorWithString>(namedTensor));
+
+  // Obtain the view and read the encoding through its accessor.
+  TensorWithString view = mlir::cast<TensorWithString>(namedTensor);
+  ASSERT_TRUE(mlir::isa<TensorWithString>(view));
+  EXPECT_EQ(view.getName(), "app");
+  // The view must still be recognised as its RankedTensorType base.
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(view));
+
+  // Build a tensor through the view's own factory and check it the same way.
+  Type builtViaView = TensorWithString::get({10, 20}, elemTy, "bob");
+  ASSERT_TRUE(mlir::isa<TensorWithString>(builtViaView));
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(builtViaView));
+  view = mlir::cast<TensorWithString>(builtViaView);
+  EXPECT_EQ(view.getName(), "bob");
+}
+
} // namespace
More information about the Mlir-commits
mailing list