[Mlir-commits] [mlir] c1a2985 - [mlir] NFC - Add VectorType::Builder to more easily build vector types from existing ones
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Nov 15 02:37:00 PST 2021
Author: Nicolas Vasilache
Date: 2021-11-15T10:36:55Z
New Revision: c1a2985d7f4edb0d1ffeda512f84282e60eae677
URL: https://github.com/llvm/llvm-project/commit/c1a2985d7f4edb0d1ffeda512f84282e60eae677
DIFF: https://github.com/llvm/llvm-project/commit/c1a2985d7f4edb0d1ffeda512f84282e60eae677.diff
LOG: [mlir] NFC - Add VectorType::Builder to more easily build vector types from existing ones
Differential Revision: https://reviews.llvm.org/D113875
Added:
Modified:
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 0e2541db9951d..82cc5840f867f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -286,10 +286,11 @@ class RankedTensorType::Builder {
return *this;
}
- /// Create a new RankedTensorType by erasing a dim from shape.
- RankedTensorType dropDim(unsigned dim) {
+ /// Create a new RankedTensorType by erasing the dim at `pos` from shape.
+ RankedTensorType dropDim(unsigned pos) {
+ assert(pos < shape.size() && "overflow");
SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
- newShape.erase(newShape.begin() + dim);
+ newShape.erase(newShape.begin() + pos);
return setShape(newShape);
}
@@ -303,6 +304,52 @@ class RankedTensorType::Builder {
Attribute encoding;
};
+//===----------------------------------------------------------------------===//
+// VectorType
+//===----------------------------------------------------------------------===//
+
+/// This is a builder type that keeps local references to arguments. Arguments
+/// that are passed into the builder must outlive the builder.
+class VectorType::Builder {
+public:
+ /// Build from another VectorType.
+ explicit Builder(VectorType other)
+ : shape(other.getShape()), elementType(other.getElementType()) {}
+
+ /// Build from scratch; `shape` is referenced, not copied.
+ Builder(ArrayRef<int64_t> shape, Type elementType)
+ : shape(shape), elementType(elementType) {}
+
+ Builder &setShape(ArrayRef<int64_t> newShape) {
+ shape = newShape;
+ return *this;
+ }
+
+ Builder &setElementType(Type newElementType) {
+ elementType = newElementType;
+ return *this;
+ }
+
+ /// Return a new VectorType with the dim at `pos` erased from the shape. The
+ /// builder is left unchanged so it never references the temporary shape.
+ /// If the vector has a single dimension, return the scalar element type.
+ // TODO: unify once we have a VectorType that supports 0-D.
+ Type dropDim(unsigned pos) const {
+ assert(pos < shape.size() && "overflow");
+ if (shape.size() == 1)
+ return elementType;
+ SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
+ newShape.erase(newShape.begin() + pos);
+ return VectorType::get(newShape, elementType);
+ }
+
+ operator VectorType() const { return VectorType::get(shape, elementType); }
+
+private:
+ ArrayRef<int64_t> shape;
+ Type elementType;
+};
+
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b3b23d5901808..c38a9c4cbe159 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -929,6 +929,10 @@ def Builtin_Vector : Builtin_Type<"Vector", [
}]>
];
let extraClassDeclaration = [{
+ /// This is a builder type that keeps local references to arguments.
+ /// Arguments that are passed into the builder must outlive the builder.
+ class Builder;
+
/// Returns true of the given type can be used as an element of a vector
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index bba4e4f977633..80b4e606c6ff2 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -472,8 +472,8 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
auto rhsType = types[1].cast<VectorType>();
auto maskElementType = parser.getBuilder().getI1Type();
std::array<Type, 2> maskTypes = {
- VectorType::get(lhsType.getShape(), maskElementType),
- VectorType::get(rhsType.getShape(), maskElementType)};
+ VectorType::Builder(lhsType).setElementType(maskElementType),
+ VectorType::Builder(rhsType).setElementType(maskElementType)};
if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
return failure();
return success();
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6528789810bfa..df32b15a872ad 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -79,25 +79,6 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}
-// Helper to drop dimension from vector type.
-static Type adjustType(VectorType tp, int64_t index) {
- int64_t rank = tp.getRank();
- Type eltType = tp.getElementType();
- if (rank == 1) {
- assert(index == 0 && "index for scalar result out of bounds");
- return eltType;
- }
- SmallVector<int64_t, 4> adjustedShape;
- for (int64_t i = 0; i < rank; ++i) {
- // Omit dimension at the given index.
- if (i == index)
- continue;
- // Otherwise, add dimension back.
- adjustedShape.push_back(tp.getDimSize(i));
- }
- return VectorType::get(adjustedShape, eltType);
-}
-
// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
@@ -105,7 +86,7 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
PatternRewriter &rewriter) {
if (index == -1)
return val;
- Type lowType = adjustType(type, 0);
+ Type lowType = VectorType::Builder(type).dropDim(0);
// At extraction dimension?
if (index == 0) {
auto posAttr = rewriter.getI64ArrayAttr(pos);
@@ -113,7 +94,7 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
}
// Unroll leading dimensions.
VectorType vType = lowType.cast<VectorType>();
- VectorType resType = adjustType(type, index).cast<VectorType>();
+ auto resType = VectorType::Builder(type).dropDim(index).cast<VectorType>();
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
@@ -140,9 +121,9 @@ static Value reshapeStore(Location loc, Value val, Value result,
return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
}
// Unroll leading dimensions.
- Type lowType = adjustType(type, 0);
+ Type lowType = VectorType::Builder(type).dropDim(0);
VectorType vType = lowType.cast<VectorType>();
- Type insType = adjustType(vType, 0);
+ Type insType = VectorType::Builder(vType).dropDim(0);
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
auto posAttr = rewriter.getI64ArrayAttr(d);
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);
More information about the Mlir-commits mailing list