[Mlir-commits] [mlir] 789c88e - [mlir] Fix unintentional mutation by VectorType/RankedTensorType::Builder dropDim

Nicolas Vasilache llvmlistbot at llvm.org
Mon Nov 22 02:55:29 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-22T10:51:50Z
New Revision: 789c88e80e878ed866a2d8cfe29c7fd36082274c

URL: https://github.com/llvm/llvm-project/commit/789c88e80e878ed866a2d8cfe29c7fd36082274c
DIFF: https://github.com/llvm/llvm-project/commit/789c88e80e878ed866a2d8cfe29c7fd36082274c.diff

LOG: [mlir] Fix unintentional mutation by VectorType/RankedTensorType::Builder dropDim

Differential Revision: https://reviews.llvm.org/D113933

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Vector/VectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index b4ce23a72915c..f3d2c24073dc6 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -283,12 +283,14 @@ class RankedTensorType::Builder {
     return *this;
   }
 
-  /// Create a new RankedTensor by erasing a dim from shape @pos.
-  RankedTensorType dropDim(unsigned pos) {
+  /// Erase a dim from shape @pos.
+  Builder &dropDim(unsigned pos) {
     assert(pos < shape.size() && "overflow");
-    SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
-    newShape.erase(newShape.begin() + pos);
-    return setShape(newShape);
+    if (storage.empty())
+      storage.append(shape.begin(), shape.end());
+    storage.erase(storage.begin() + pos);
+    shape = {storage.data(), storage.size()};
+    return *this;
   }
 
   operator RankedTensorType() {
@@ -297,6 +299,8 @@ class RankedTensorType::Builder {
 
 private:
   ArrayRef<int64_t> shape;
+  // Owning shape data for copy-on-write operations.
+  SmallVector<int64_t> storage;
   Type elementType;
   Attribute encoding;
 };
@@ -327,23 +331,29 @@ class VectorType::Builder {
     return *this;
   }
 
-  /// Create a new VectorType by erasing a dim from shape @pos.
+  /// Erase a dim from shape @pos.
+  Builder &dropDim(unsigned pos) {
+    assert(pos < shape.size() && "overflow");
+    if (storage.empty())
+      storage.append(shape.begin(), shape.end());
+    storage.erase(storage.begin() + pos);
+    shape = {storage.data(), storage.size()};
+    return *this;
+  }
+
   /// In the particular case where the vector has a single dimension that we
   /// drop, return the scalar element type.
   // TODO: unify once we have a VectorType that supports 0-D.
-  Type dropDim(unsigned pos) {
-    assert(pos < shape.size() && "overflow");
-    if (shape.size() == 1)
+  operator Type() {
+    if (shape.empty())
       return elementType;
-    SmallVector<int64_t, 4> newShape(shape.begin(), shape.end());
-    newShape.erase(newShape.begin() + pos);
-    return setShape(newShape);
+    return VectorType::get(shape, elementType);
   }
 
-  operator VectorType() { return VectorType::get(shape, elementType); }
-
 private:
   ArrayRef<int64_t> shape;
+  // Owning shape data for copy-on-write operations.
+  SmallVector<int64_t> storage;
   Type elementType;
 };
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 657f2b7605589..36bb0171823f7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -876,9 +876,12 @@ struct DownscaleSizeOneWindowed2DConvolution final
     // Get new shapes and types for all operands by removing the size-1
     // dimension.
     using RTTBuilder = RankedTensorType::Builder;
-    auto newInputType = RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-    auto newFilterType = RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
-    auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+    RankedTensorType newInputType =
+        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+    RankedTensorType newFilterType =
+        RTTBuilder(filterType).dropDim((removeH ? 0 : 1));
+    RankedTensorType newOutputType =
+        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
 
     // Rank-reduce operands.
     Location loc = convOp.getLoc();
@@ -948,9 +951,12 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
     // Get new shapes and types for all operands by removing the size-1
     // dimension.
     using RTTBuilder = RankedTensorType::Builder;
-    auto newInputType = RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-    auto newKernelType = RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
-    auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+    RankedTensorType newInputType =
+        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+    RankedTensorType newKernelType =
+        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+    RankedTensorType newOutputType =
+        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
 
     // Rank-reduce operands.
     Location loc = convOp.getLoc();

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 37f3c31e6a48d..5760e80bfcaff 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -94,15 +94,16 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
   }
   // Unroll leading dimensions.
   VectorType vType = lowType.cast<VectorType>();
-  auto resType = VectorType::Builder(type).dropDim(index).cast<VectorType>();
+  Type resType = VectorType::Builder(type).dropDim(index);
+  auto resVectorType = resType.cast<VectorType>();
   Value result = rewriter.create<arith::ConstantOp>(
-      loc, resType, rewriter.getZeroAttr(resType));
-  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
+      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
+  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
     auto posAttr = rewriter.getI64ArrayAttr(d);
     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, posAttr);
     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
-    result =
-        rewriter.create<vector::InsertOp>(loc, resType, load, result, posAttr);
+    result = rewriter.create<vector::InsertOp>(loc, resVectorType, load, result,
+                                               posAttr);
   }
   return result;
 }


        


More information about the Mlir-commits mailing list