[Mlir-commits] [mlir] c1a2985 - [mlir] NFC - Add VectorType::Builder to more easily build vector types from existing ones

Nicolas Vasilache llvmlistbot at llvm.org
Mon Nov 15 02:37:00 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-15T10:36:55Z
New Revision: c1a2985d7f4edb0d1ffeda512f84282e60eae677

URL: https://github.com/llvm/llvm-project/commit/c1a2985d7f4edb0d1ffeda512f84282e60eae677
DIFF: https://github.com/llvm/llvm-project/commit/c1a2985d7f4edb0d1ffeda512f84282e60eae677.diff

LOG: [mlir] NFC - Add VectorType::Builder to more easily build vector types from existing ones

Differential Revision: https://reviews.llvm.org/D113875

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 0e2541db9951d..82cc5840f867f 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -286,10 +286,11 @@ class RankedTensorType::Builder {
     return *this;
   }
 
-  /// Create a new RankedTensorType by erasing a dim from shape.
-  RankedTensorType dropDim(unsigned dim) {
+  /// Create a new RankedTensorType by erasing the dim at `pos` from shape.
+  RankedTensorType dropDim(unsigned pos) {
+    assert(pos < shape.size() && "overflow");
     SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
-    newShape.erase(newShape.begin() + dim);
+    newShape.erase(newShape.begin() + pos);
     return setShape(newShape);
   }
 
@@ -303,6 +304,52 @@ class RankedTensorType::Builder {
   Attribute encoding;
 };
 
+//===----------------------------------------------------------------------===//
+// VectorType
+//===----------------------------------------------------------------------===//
+
+/// This is a builder type that keeps local references to arguments. Arguments
+/// that are passed into the builder must outlive the builder.
+class VectorType::Builder {
+public:
+  /// Build from another VectorType.
+  explicit Builder(VectorType other)
+      : shape(other.getShape()), elementType(other.getElementType()) {}
+
+  /// Build from scratch.
+  Builder(ArrayRef<int64_t> shape, Type elementType)
+      : shape(shape), elementType(elementType) {}
+
+  Builder &setShape(ArrayRef<int64_t> newShape) {
+    shape = newShape;
+    return *this;
+  }
+
+  Builder &setElementType(Type newElementType) {
+    elementType = newElementType;
+    return *this;
+  }
+
+  /// Create a new VectorType by erasing the dim at `pos` from shape.
+  /// In the particular case where the vector has a single dimension that we
+  /// drop, return the scalar element type.
+  // TODO: unify once we have a VectorType that supports 0-D.
+  Type dropDim(unsigned pos) {
+    assert(pos < shape.size() && "overflow");
+    if (shape.size() == 1)
+      return elementType;
+    SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
+    newShape.erase(newShape.begin() + pos);
+    return setShape(newShape);
+  }
+
+  operator VectorType() { return VectorType::get(shape, elementType); }
+
+private:
+  ArrayRef<int64_t> shape;
+  Type elementType;
+};
+
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
 /// `originalShape` with some `1` entries erased, return the set of indices
 /// that specifies which of the entries of `originalShape` are dropped to obtain

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b3b23d5901808..c38a9c4cbe159 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -929,6 +929,10 @@ def Builtin_Vector : Builtin_Type<"Vector", [
     }]>
   ];
   let extraClassDeclaration = [{
+    /// This is a builder type that keeps local references to arguments.
+    /// Arguments that are passed into the builder must outlive the builder.
+    class Builder;
+
     /// Returns true of the given type can be used as an element of a vector
     /// type. In particular, vectors can consist of integer, index, or float
     /// primitives.

diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index bba4e4f977633..80b4e606c6ff2 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -472,8 +472,8 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
   auto rhsType = types[1].cast<VectorType>();
   auto maskElementType = parser.getBuilder().getI1Type();
   std::array<Type, 2> maskTypes = {
-      VectorType::get(lhsType.getShape(), maskElementType),
-      VectorType::get(rhsType.getShape(), maskElementType)};
+      VectorType::Builder(lhsType).setElementType(maskElementType),
+      VectorType::Builder(rhsType).setElementType(maskElementType)};
   if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
     return failure();
   return success();

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6528789810bfa..df32b15a872ad 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -79,25 +79,6 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
   return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
 }
 
-// Helper to drop dimension from vector type.
-static Type adjustType(VectorType tp, int64_t index) {
-  int64_t rank = tp.getRank();
-  Type eltType = tp.getElementType();
-  if (rank == 1) {
-    assert(index == 0 && "index for scalar result out of bounds");
-    return eltType;
-  }
-  SmallVector<int64_t, 4> adjustedShape;
-  for (int64_t i = 0; i < rank; ++i) {
-    // Omit dimension at the given index.
-    if (i == index)
-      continue;
-    // Otherwise, add dimension back.
-    adjustedShape.push_back(tp.getDimSize(i));
-  }
-  return VectorType::get(adjustedShape, eltType);
-}
-
 // Helper method to possibly drop a dimension in a load.
 // TODO
 static Value reshapeLoad(Location loc, Value val, VectorType type,
@@ -105,7 +86,7 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
                          PatternRewriter &rewriter) {
   if (index == -1)
     return val;
-  Type lowType = adjustType(type, 0);
+  Type lowType = VectorType::Builder(type).dropDim(0);
   // At extraction dimension?
   if (index == 0) {
     auto posAttr = rewriter.getI64ArrayAttr(pos);
@@ -113,7 +94,7 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
   }
   // Unroll leading dimensions.
   VectorType vType = lowType.cast<VectorType>();
-  VectorType resType = adjustType(type, index).cast<VectorType>();
+  auto resType = VectorType::Builder(type).dropDim(index).cast<VectorType>();
   Value result = rewriter.create<arith::ConstantOp>(
       loc, resType, rewriter.getZeroAttr(resType));
   for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
@@ -140,9 +121,9 @@ static Value reshapeStore(Location loc, Value val, Value result,
     return rewriter.create<vector::InsertOp>(loc, type, val, result, posAttr);
   }
   // Unroll leading dimensions.
-  Type lowType = adjustType(type, 0);
+  Type lowType = VectorType::Builder(type).dropDim(0);
   VectorType vType = lowType.cast<VectorType>();
-  Type insType = adjustType(vType, 0);
+  Type insType = VectorType::Builder(vType).dropDim(0);
   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
     auto posAttr = rewriter.getI64ArrayAttr(d);
     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, posAttr);


        


More information about the Mlir-commits mailing list