[Mlir-commits] [mlir] 296d5cb - [mlir][BuiltinTypes] Return VectorType from VectorType::Builder conversion operator

Benjamin Maxwell llvmlistbot at llvm.org
Wed Aug 30 06:47:40 PDT 2023


Author: Benjamin Maxwell
Date: 2023-08-30T13:47:06Z
New Revision: 296d5cb60c20fe314babcd93fb5df5ecc24ae987

URL: https://github.com/llvm/llvm-project/commit/296d5cb60c20fe314babcd93fb5df5ecc24ae987
DIFF: https://github.com/llvm/llvm-project/commit/296d5cb60c20fe314babcd93fb5df5ecc24ae987.diff

LOG: [mlir][BuiltinTypes] Return VectorType from VectorType::Builder conversion operator

0-D vectors are now supported, so the special case of returning just
the element type can now be removed.

A few callers that relied on the old behaviour have been updated.

Reviewed By: awarzynski, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D159122

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index ce68fc2673dcaf..f0b19fe543a5bf 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -357,12 +357,7 @@ class VectorType::Builder {
     return *this;
   }
 
-  /// In the particular case where the vector has a single dimension that we
-  /// drop, return the scalar element type.
-  // TODO: unify once we have a VectorType that supports 0-D.
-  operator Type() {
-    if (shape.empty())
-      return elementType;
+  operator VectorType() {
     return VectorType::get(shape, elementType, scalableDims);
   }
 

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 88bda3931a5a11..af539d2c3795a0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2216,7 +2216,7 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
       return failure();
     if (mask.size() != 1)
       return failure();
-    Type resType = VectorType::Builder(v1VectorType).setShape({1});
+    VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
     if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
                                                        shuffleOp.getV1());

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 66ac5ffef3e3ed..1b3d617a79edb7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -89,21 +89,20 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
                          PatternRewriter &rewriter) {
   if (index == -1)
     return val;
-  Type lowType = VectorType::Builder(type).dropDim(0);
+  Type lowType = type.getRank() > 1 ? VectorType::Builder(type).dropDim(0)
+                                    : type.getElementType();
   // At extraction dimension?
   if (index == 0)
     return rewriter.create<vector::ExtractOp>(loc, lowType, val, pos);
   // Unroll leading dimensions.
   VectorType vType = cast<VectorType>(lowType);
-  Type resType = VectorType::Builder(type).dropDim(index);
-  auto resVectorType = cast<VectorType>(resType);
+  VectorType resType = VectorType::Builder(type).dropDim(index);
   Value result = rewriter.create<arith::ConstantOp>(
-      loc, resVectorType, rewriter.getZeroAttr(resVectorType));
-  for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
+      loc, resType, rewriter.getZeroAttr(resType));
+  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
     Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, d);
     Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
-    result =
-        rewriter.create<vector::InsertOp>(loc, resVectorType, load, result, d);
+    result = rewriter.create<vector::InsertOp>(loc, resType, load, result, d);
   }
   return result;
 }
@@ -120,13 +119,13 @@ static Value reshapeStore(Location loc, Value val, Value result,
   if (index == 0)
     return rewriter.create<vector::InsertOp>(loc, type, val, result, pos);
   // Unroll leading dimensions.
-  Type lowType = VectorType::Builder(type).dropDim(0);
-  VectorType vType = cast<VectorType>(lowType);
-  Type insType = VectorType::Builder(vType).dropDim(0);
+  VectorType lowType = VectorType::Builder(type).dropDim(0);
+  Type insType = lowType.getRank() > 1 ? VectorType::Builder(lowType).dropDim(0)
+                                       : lowType.getElementType();
   for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
-    Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, d);
+    Value ext = rewriter.create<vector::ExtractOp>(loc, lowType, result, d);
     Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, d);
-    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
+    Value sto = reshapeStore(loc, ins, ext, lowType, index - 1, pos, rewriter);
     result = rewriter.create<vector::InsertOp>(loc, type, sto, result, d);
   }
   return result;


        


More information about the Mlir-commits mailing list