[Mlir-commits] [mlir] 28d7671 - [mlir] Add two clone methods about encoding to RankedTensorType. (#127709)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 27 17:59:30 PST 2025
Author: Han-Chung Wang
Date: 2025-02-27T17:59:27-08:00
New Revision: 28d76714714a2cdcbdd62265de15115015eb9469
URL: https://github.com/llvm/llvm-project/commit/28d76714714a2cdcbdd62265de15115015eb9469
DIFF: https://github.com/llvm/llvm-project/commit/28d76714714a2cdcbdd62265de15115015eb9469.diff
LOG: [mlir] Add two clone methods about encoding to RankedTensorType. (#127709)
There are clone methods for shape and element type, but not for
encodings. The revision adds two clone methods to RankedTensorType:
- dropEncoding(): Return a clone of this type without the encoding.
- cloneWithEncoding(Attribute encoding): Return a clone of this type
with the given new encoding and the same shape and element type as this
type.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
Added:
Modified:
mlir/include/mlir/IR/BuiltinTypes.td
mlir/unittests/IR/ShapedTypeTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index e5a2ae81da0c9..af474b3e3ec47 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1035,6 +1035,17 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
/// Return a clone of this type with the given element type and the same shape,
/// built by casting the result of cloneWith(getShape(), elementType).
/// NOTE(review): whether the encoding is carried over depends on cloneWith,
/// which is not shown in this hunk — confirm before relying on it.
RankedTensorType clone(::mlir::Type elementType) {
return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
}
+
+ /// Return a clone of this type without the encoding.
+ /// The shape and element type are preserved; the result is constructed via
+ /// RankedTensorType::get without passing any encoding argument.
+ RankedTensorType dropEncoding() {
+ return RankedTensorType::get(getShape(), getElementType());
+ }
+
+ /// Return a clone of this type with the given new encoding and the same
+ /// shape and element type as this type. The current encoding is not read,
+ /// so any existing encoding is replaced rather than merged.
+ /// NOTE(review): a null `encoding` presumably produces the same type as
+ /// dropEncoding(); confirm against RankedTensorType::get's default argument.
+ RankedTensorType cloneWithEncoding(::mlir::Attribute encoding) {
+ return RankedTensorType::get(getShape(), getElementType(), encoding);
+ }
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index c2900b5aaeeeb..bc4066ed210e8 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -282,6 +282,20 @@ TEST(ShapedTypeTest, RankedTensorTypeView) {
ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
view = mlir::cast<TensorWithString>(viewCreated);
EXPECT_EQ(view.getName(), "bob");
+
+ // Verify encoding clone methods.
+ EXPECT_EQ(unitEncodingRankedTensorType,
+ cast<RankedTensorType>(noEncodingRankedTensorType)
+ .cloneWithEncoding(unitAttr));
+ EXPECT_EQ(stringEncodingRankedTensorType,
+ cast<RankedTensorType>(noEncodingRankedTensorType)
+ .cloneWithEncoding(stringAttr));
+ EXPECT_EQ(
+ noEncodingRankedTensorType,
+ cast<RankedTensorType>(unitEncodingRankedTensorType).dropEncoding());
+ EXPECT_EQ(
+ noEncodingRankedTensorType,
+ cast<RankedTensorType>(stringEncodingRankedTensorType).dropEncoding());
}
} // namespace
More information about the Mlir-commits
mailing list