[Mlir-commits] [mlir] 381a65f - [mlir] Add clone method to ShapedType
Jacques Pienaar
llvmlistbot at llvm.org
Mon Feb 15 11:04:31 PST 2021
Author: Jacques Pienaar
Date: 2021-02-15T11:04:16-08:00
New Revision: 381a65fa066171977bc9119432917a1444f99f87
URL: https://github.com/llvm/llvm-project/commit/381a65fa066171977bc9119432917a1444f99f87
DIFF: https://github.com/llvm/llvm-project/commit/381a65fa066171977bc9119432917a1444f99f87.diff
LOG: [mlir] Add clone method to ShapedType
Allow clients to create a new ShapedType of the same "container" type
but with a different element type or shape. The first use case is refining
shapes during shape inference without needing to know which kind of
ShapedType is being refined.
Differential Revision: https://reviews.llvm.org/D96682
Added:
mlir/unittests/IR/ShapedTypeTest.cpp
Modified:
mlir/include/mlir/IR/BuiltinTypes.h
mlir/lib/IR/BuiltinTypes.cpp
mlir/unittests/IR/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index d483ee4a4f2d..9064e3294e81 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -88,6 +88,11 @@ class ShapedType : public Type {
static constexpr int64_t kDynamicStrideOrOffset =
std::numeric_limits<int64_t>::min();
+  /// Return a clone of this type with the given shape and element type,
+  /// keeping the same kind of container (memref, tensor or vector). A ranked
+  /// memref keeps its layout maps and memory space. Because a shape is
+  /// supplied, cloning an unranked memref or tensor yields a ranked one.
+  ShapedType clone(ArrayRef<int64_t> shape, Type elementType);
+  /// Return a clone of this type with the given shape and the current element
+  /// type. Cloning an unranked memref or tensor yields a ranked one.
+  ShapedType clone(ArrayRef<int64_t> shape);
+  /// Return a clone of this type with the given element type and the current
+  /// shape. Unranked memrefs and tensors stay unranked.
+  ShapedType clone(Type elementType);
+
/// Return the element type.
Type getElementType() const;
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1316bfd3bb52..cedc6ad3c2d8 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -197,6 +197,75 @@ LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
// Out-of-line definitions of the static constexpr sentinels declared in
// ShapedType. NOTE(review): presumably kept so ODR-uses link under pre-C++17
// rules (before static constexpr members became implicitly inline) -- confirm.
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;
+/// Build a type of the same container kind as this one, with `shape` and
+/// `elementType`. A ranked memref keeps its layout maps and memory space; an
+/// unranked memref becomes a ranked memref in the same memory space; any
+/// tensor becomes a ranked tensor.
+ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
+  if (auto memrefTy = dyn_cast<MemRefType>())
+    return MemRefType::Builder(memrefTy)
+        .setShape(shape)
+        .setElementType(elementType);
+
+  if (auto unrankedMemrefTy = dyn_cast<UnrankedMemRefType>())
+    return MemRefType::Builder(shape, elementType)
+        .setMemorySpace(unrankedMemrefTy.getMemorySpace());
+
+  // An explicit shape is supplied, so tensors always come back ranked.
+  if (isa<TensorType>())
+    return RankedTensorType::get(shape, elementType);
+
+  if (isa<VectorType>())
+    return VectorType::get(shape, elementType);
+
+  llvm_unreachable("Unhandled ShapedType clone case");
+}
+
+/// Return a clone of this type with `shape` and the current element type.
+/// This is exactly the two-argument clone with the element type held fixed,
+/// so it delegates to it and each container kind is handled in one place
+/// only (adding a new ShapedType kind no longer needs a third dispatch).
+/// Cloning an unranked memref or tensor yields a ranked one.
+ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
+  return clone(shape, getElementType());
+}
+
+/// Return a clone of this type with `elementType` and the current shape and
+/// container kind. Unlike the shape-taking overloads, rankedness is
+/// preserved: unranked memrefs and tensors stay unranked, and an unranked
+/// memref keeps its memory space.
+ShapedType ShapedType::clone(Type elementType) {
+  if (auto other = dyn_cast<MemRefType>()) {
+    // The builder starts from `other`, so shape, layout maps and memory
+    // space carry over unchanged.
+    MemRefType::Builder b(other);
+    b.setElementType(elementType);
+    return b;
+  }
+
+  if (auto other = dyn_cast<UnrankedMemRefType>())
+    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
+
+  if (isa<TensorType>()) {
+    // Preserve rankedness: only a ranked tensor has a shape to copy.
+    if (hasRank())
+      return RankedTensorType::get(getShape(), elementType);
+    return UnrankedTensorType::get(elementType);
+  }
+
+  if (isa<VectorType>())
+    return VectorType::get(getShape(), elementType);
+
+  llvm_unreachable("Unhandled ShapedType clone case");
+}
+
Type ShapedType::getElementType() const {
  // The element type is read from this type's ImplType storage.
  auto *storage = static_cast<ImplType *>(impl);
  return storage->elementType;
}
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 0b80f11e1955..af21c1b8b43b 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_unittest(MLIRIRTests
AttributeTest.cpp
DialectTest.cpp
OperationSupportTest.cpp
+ ShapedTypeTest.cpp
)
target_link_libraries(MLIRIRTests
PRIVATE
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
new file mode 100644
index 000000000000..e3e5ffe95fe1
--- /dev/null
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -0,0 +1,129 @@
+//===- ShapedTypeTest.cpp - ShapedType unit tests -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectInterface.h"
+#include "llvm/ADT/SmallVector.h"
+#include "gtest/gtest.h"
+#include <cstdint>
+
+using namespace mlir;
+using namespace mlir::detail;
+
+namespace {
+TEST(ShapedTypeTest, CloneMemref) {
+  MLIRContext context;
+
+  Type i32 = IntegerType::get(&context, 32);
+  Type f32 = FloatType::getF32(&context);
+  int memSpace = 7;
+  Type memrefOriginalType = i32;
+  llvm::SmallVector<int64_t> memrefOriginalShape({10, 20});
+  AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);
+
+  // Test ranked memref cloning; the layout map and memory space must carry
+  // over to every clone.
+  ShapedType memrefType =
+      MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
+          .setMemorySpace(memSpace)
+          .setAffineMaps(map);
+  // Update shape.
+  llvm::SmallVector<int64_t> memrefNewShape({30, 40});
+  ASSERT_NE(memrefOriginalShape, memrefNewShape);
+  ASSERT_EQ(memrefType.clone(memrefNewShape),
+            (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
+                .setMemorySpace(memSpace)
+                .setAffineMaps(map));
+  // Update type.
+  Type memrefNewType = f32;
+  ASSERT_NE(memrefOriginalType, memrefNewType);
+  ASSERT_EQ(memrefType.clone(memrefNewType),
+            (MemRefType)MemRefType::Builder(memrefOriginalShape, memrefNewType)
+                .setMemorySpace(memSpace)
+                .setAffineMaps(map));
+  // Update both.
+  ASSERT_EQ(memrefType.clone(memrefNewShape, memrefNewType),
+            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
+                .setMemorySpace(memSpace)
+                .setAffineMaps(map));
+
+  // Test unranked memref cloning. Supplying a shape produces a ranked memref
+  // in the same memory space; changing only the element type stays unranked.
+  ShapedType unrankedMemrefType =
+      UnrankedMemRefType::get(memrefOriginalType, memSpace);
+  // Update shape.
+  ASSERT_EQ(unrankedMemrefType.clone(memrefNewShape),
+            (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
+                .setMemorySpace(memSpace));
+  // Update type.
+  ASSERT_EQ(unrankedMemrefType.clone(memrefNewType),
+            UnrankedMemRefType::get(memrefNewType, memSpace));
+  // Update both.
+  ASSERT_EQ(unrankedMemrefType.clone(memrefNewShape, memrefNewType),
+            (MemRefType)MemRefType::Builder(memrefNewShape, memrefNewType)
+                .setMemorySpace(memSpace));
+}
+
+TEST(ShapedTypeTest, CloneTensor) {
+  MLIRContext context;
+
+  Type i32 = IntegerType::get(&context, 32);
+  Type f32 = FloatType::getF32(&context);
+
+  Type tensorOriginalType = i32;
+  llvm::SmallVector<int64_t> tensorOriginalShape({10, 20});
+
+  // Test ranked tensor cloning.
+  ShapedType tensorType =
+      RankedTensorType::get(tensorOriginalShape, tensorOriginalType);
+  // Update shape.
+  llvm::SmallVector<int64_t> tensorNewShape({30, 40});
+  ASSERT_NE(tensorOriginalShape, tensorNewShape);
+  ASSERT_EQ(tensorType.clone(tensorNewShape),
+            RankedTensorType::get(tensorNewShape, tensorOriginalType));
+  // Update type.
+  Type tensorNewType = f32;
+  ASSERT_NE(tensorOriginalType, tensorNewType);
+  ASSERT_EQ(tensorType.clone(tensorNewType),
+            RankedTensorType::get(tensorOriginalShape, tensorNewType));
+  // Update both.
+  ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
+            RankedTensorType::get(tensorNewShape, tensorNewType));
+
+  // Test unranked tensor cloning. Supplying a shape produces a ranked tensor;
+  // changing only the element type stays unranked.
+  ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
+  // Update shape.
+  ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
+            RankedTensorType::get(tensorNewShape, tensorOriginalType));
+  // Update type.
+  ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
+            UnrankedTensorType::get(tensorNewType));
+  // Update both.
+  ASSERT_EQ(unrankedTensorType.clone(tensorNewShape, tensorNewType),
+            RankedTensorType::get(tensorNewShape, tensorNewType));
+}
+
+TEST(ShapedTypeTest, CloneVector) {
+  MLIRContext context;
+
+  Type i32 = IntegerType::get(&context, 32);
+  Type f32 = FloatType::getF32(&context);
+
+  Type vectorOriginalType = i32;
+  llvm::SmallVector<int64_t> vectorOriginalShape({10, 20});
+  ShapedType vectorType =
+      VectorType::get(vectorOriginalShape, vectorOriginalType);
+  // Update shape.
+  llvm::SmallVector<int64_t> vectorNewShape({30, 40});
+  ASSERT_NE(vectorOriginalShape, vectorNewShape);
+  ASSERT_EQ(vectorType.clone(vectorNewShape),
+            VectorType::get(vectorNewShape, vectorOriginalType));
+  // Update type.
+  Type vectorNewType = f32;
+  ASSERT_NE(vectorOriginalType, vectorNewType);
+  ASSERT_EQ(vectorType.clone(vectorNewType),
+            VectorType::get(vectorOriginalShape, vectorNewType));
+  // Update both.
+  ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType),
+            VectorType::get(vectorNewShape, vectorNewType));
+}
+
+} // end namespace
More information about the Mlir-commits
mailing list