[Mlir-commits] [mlir] d2f42c7 - [mlir] Add unit test for RankedTensorType wrapper example. (#99789)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 22 13:16:26 PDT 2024


Author: Jacques Pienaar
Date: 2024-07-22T13:16:22-07:00
New Revision: d2f42c737234662d1bdc2f2e393e1b59536f8a93

URL: https://github.com/llvm/llvm-project/commit/d2f42c737234662d1bdc2f2e393e1b59536f8a93
DIFF: https://github.com/llvm/llvm-project/commit/d2f42c737234662d1bdc2f2e393e1b59536f8a93.diff

LOG: [mlir] Add unit test for RankedTensorType wrapper example. (#99789)

Add an example, as a unit test, of creating a wrapper type (a view) over a
RankedTensorType that carries an encoding. The view provides a narrower,
typed API and lets callers avoid repeated cast queries and manual access to
the encoding attribute.

Users with more advanced encodings are expected to define a dedicated
attribute type; this example simply uses StringAttr.

Added: 
    

Modified: 
    mlir/unittests/IR/ShapedTypeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..7a5b0722a03ba 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SmallVector.h"
 #include "gtest/gtest.h"
 #include <cstdint>
@@ -226,4 +227,61 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
   }
 }
 
+/// View over a RankedTensorType whose encoding is a StringAttr: supports
+/// isa/cast to the view and typed access to the encoding's string value.
+class TensorWithString : public RankedTensorType {
+public:
+  using RankedTensorType::RankedTensorType;
+  /// Creates a RankedTensorType whose encoding is a StringAttr of `name`.
+  static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
+                              StringRef name) {
+    return mlir::cast<TensorWithString>(RankedTensorType::get(
+        shape, elementType, StringAttr::get(elementType.getContext(), name)));
+  }
+  /// Returns the encoding's string value, or an empty StringRef if unencoded.
+  StringRef getName() const {
+    if (Attribute enc = getEncoding())
+      return mlir::cast<StringAttr>(enc).getValue();
+    return {};
+  }
+  /// RTTI hook for isa/cast: a ranked tensor with a StringAttr encoding.
+  static bool classof(Type type) {
+    if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))
+      return mlir::isa_and_present<StringAttr>(rt.getEncoding());
+    return false;
+  }
+};
+
+TEST(ShapedTypeTest, RankedTensorTypeView) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+  // Same 10x20xf32 shape throughout; encodings: none, UnitAttr, StringAttr.
+  Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);
+
+  UnitAttr unitAttr = UnitAttr::get(&context);
+  Type unitEncodingRankedTensorType =
+      RankedTensorType::get({10, 20}, f32, unitAttr);
+
+  StringAttr stringAttr = StringAttr::get(&context, "app");
+  Type stringEncodingRankedTensorType =
+      RankedTensorType::get({10, 20}, f32, stringAttr);
+  // Only the StringAttr-encoded type satisfies TensorWithString::classof.
+  EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));
+  EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));
+  ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));
+
+  // Cast to the view to get typed access to the encoding string.
+  auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);
+  ASSERT_TRUE(mlir::isa<TensorWithString>(view));
+  EXPECT_EQ(view.getName(), "app");
+  // The view is still a RankedTensorType, so base-type queries keep working.
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(view));
+  // Build through the view's own factory and read the name back.
+  Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");
+  ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));
+  ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
+  view = mlir::cast<TensorWithString>(viewCreated);
+  EXPECT_EQ(view.getName(), "bob");
+}
+
 } // namespace


        


More information about the Mlir-commits mailing list