[Mlir-commits] [mlir] a1aad28 - [mlir][vector] NFC: Improve vector type accessor methods

Lei Zhang llvmlistbot at llvm.org
Wed Feb 15 20:11:53 PST 2023


Author: Lei Zhang
Date: 2023-02-16T04:08:33Z
New Revision: a1aad28d297abb14e812cacf97947a0f857a2f54

URL: https://github.com/llvm/llvm-project/commit/a1aad28d297abb14e812cacf97947a0f857a2f54
DIFF: https://github.com/llvm/llvm-project/commit/a1aad28d297abb14e812cacf97947a0f857a2f54.diff

LOG: [mlir][vector] NFC: Improve vector type accessor methods

Plain `getVectorType()` can be quite confusing and error-prone
given that, well, vector ops always work on vector types, and
they commonly involve both source and result vectors. So this
commit renames various such accessor methods to make explicit
whether they return the source or the result vector type.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D144159

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
    mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 94d4b64465126..c5ebe9f4bcebc 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -321,7 +321,7 @@ def Vector_ReductionOp :
     ```
   }];
   let extraClassDeclaration = [{
-    VectorType getVectorType() {
+    VectorType getSourceVectorType() {
       return getVector().getType().cast<VectorType>();
     }
   }];
@@ -449,7 +449,7 @@ def Vector_BroadcastOp :
   }];
   let extraClassDeclaration = [{
     Type getSourceType() { return getSource().getType(); }
-    VectorType getVectorType() {
+    VectorType getResultVectorType() {
       return getVector().getType().cast<VectorType>();
     }
 
@@ -466,7 +466,7 @@ def Vector_BroadcastOp :
     /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
     /// the helper will assert. This means:
     ///   1. `dstShape` must not be empty.
-    ///   2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
+    ///   2. `broadcastedDims` must be confined to [0 .. rank(value.getResultVectorType)]
     ///   2. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
     //       must match the `value` shape.
     static Value createOrFoldBroadcastOp(
@@ -537,7 +537,7 @@ def Vector_ShuffleOp :
     VectorType getV2VectorType() {
       return getV2().getType().cast<VectorType>();
     }
-    VectorType getVectorType() {
+    VectorType getResultVectorType() {
       return getVector().getType().cast<VectorType>();
     }
   }];
@@ -584,7 +584,7 @@ def Vector_ExtractElementOp :
     OpBuilder<(ins "Value":$source)>,
   ];
   let extraClassDeclaration = [{
-    VectorType getVectorType() {
+    VectorType getSourceVectorType() {
       return getVector().getType().cast<VectorType>();
     }
   }];
@@ -619,7 +619,7 @@ def Vector_ExtractOp :
   ];
   let extraClassDeclaration = [{
     static StringRef getPositionAttrStrName() { return "position"; }
-    VectorType getVectorType() {
+    VectorType getSourceVectorType() {
       return getVector().getType().cast<VectorType>();
     }
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
@@ -996,7 +996,7 @@ def Vector_OuterProductOp :
         ? VectorType()
         : (*getAcc().begin()).getType().cast<VectorType>();
     }
-    VectorType getVectorType() {
+    VectorType getResultVectorType() {
       return getResult().getType().cast<VectorType>();
     }
     static constexpr StringRef getKindAttrStrName() {
@@ -1172,7 +1172,9 @@ def Vector_ExtractStridedSliceOp :
     static StringRef getOffsetsAttrStrName() { return "offsets"; }
     static StringRef getSizesAttrStrName() { return "sizes"; }
     static StringRef getStridesAttrStrName() { return "strides"; }
-    VectorType getVectorType(){ return getVector().getType().cast<VectorType>(); }
+    VectorType getSourceVectorType() {
+      return getVector().getType().cast<VectorType>(); 
+    }
     void getOffsets(SmallVectorImpl<int64_t> &results);
     bool hasNonUnitStrides() {
       return llvm::any_of(getStrides(), [](Attribute attr) {
@@ -2424,10 +2426,10 @@ def Vector_TransposeOp :
     OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
   ];
   let extraClassDeclaration = [{
-    VectorType getVectorType() {
+    VectorType getSourceVectorType() {
       return getVector().getType().cast<VectorType>();
     }
-    VectorType getResultType() {
+    VectorType getResultVectorType() {
       return getResult().getType().cast<VectorType>();
     }
     void getTransp(SmallVectorImpl<int64_t> &results);

diff  --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index ece17510e1365..d9533b4e16b44 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -203,7 +203,7 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
 
 /// Return true if this is a broadcast from scalar to a 2D vector.
 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
-  return broadcastOp.getVectorType().getRank() == 2;
+  return broadcastOp.getResultVectorType().getRank() == 2;
 }
 
 /// Return true if this integer extend op can be folded into a contract op.
@@ -949,7 +949,7 @@ convertExtractStridedSlice(RewriterBase &rewriter,
 
   SmallVector<int64_t> sizes;
   populateFromInt64AttrArray(op.getSizes(), sizes);
-  ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
+  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
 
   // Compute offset in vector registers. Note that the mma.sync vector registers
   // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
@@ -1045,7 +1045,7 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
   assert(broadcastSupportsMMAMatrixType(op));
 
   const char *fragType = inferFragType(op);
-  auto vecType = op.getVectorType();
+  auto vecType = op.getResultVectorType();
   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a1c1c3dca8362..159bae829133c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -939,7 +939,7 @@ class VectorShuffleOpConversion
     auto loc = shuffleOp->getLoc();
     auto v1Type = shuffleOp.getV1VectorType();
     auto v2Type = shuffleOp.getV2VectorType();
-    auto vectorType = shuffleOp.getVectorType();
+    auto vectorType = shuffleOp.getResultVectorType();
     Type llvmType = typeConverter->convertType(vectorType);
     auto maskArrayAttr = shuffleOp.getMask();
 
@@ -1002,7 +1002,7 @@ class VectorExtractElementOpConversion
   LogicalResult
   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto vectorType = extractEltOp.getVectorType();
+    auto vectorType = extractEltOp.getSourceVectorType();
     auto llvmType = typeConverter->convertType(vectorType.getElementType());
 
     // Bail if result type cannot be lowered.

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 0dbf67e0b69f8..a3a3a612ed147 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -83,7 +83,8 @@ struct VectorBroadcastConvert final
   LogicalResult
   matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resultType = getTypeConverter()->convertType(castOp.getVectorType());
+    Type resultType =
+        getTypeConverter()->convertType(castOp.getResultVectorType());
     if (!resultType)
       return failure();
 
@@ -92,10 +93,10 @@ struct VectorBroadcastConvert final
       return success();
     }
 
-    SmallVector<Value, 4> source(castOp.getVectorType().getNumElements(),
+    SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
                                  adaptor.getSource());
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
-        castOp, castOp.getVectorType(), source);
+        castOp, castOp.getResultVectorType(), source);
     return success();
   }
 };
@@ -405,7 +406,7 @@ struct VectorShuffleOpConvert final
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto oldResultType = shuffleOp.getVectorType();
+    auto oldResultType = shuffleOp.getResultVectorType();
     if (!spirv::CompositeType::isValid(oldResultType))
       return failure();
     Type newResultType = getTypeConverter()->convertType(oldResultType);

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 64125efeb21cb..8c6609d98b439 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -416,7 +416,7 @@ void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
 
 LogicalResult ReductionOp::verify() {
   // Verify for 0-D and 1-D vector.
-  int64_t rank = getVectorType().getRank();
+  int64_t rank = getSourceVectorType().getRank();
   if (rank > 1)
     return emitOpError("unsupported reduction rank: ") << rank;
 
@@ -465,7 +465,7 @@ void ReductionOp::print(OpAsmPrinter &p) {
 
 /// Returns the mask type expected by this operation.
 Type ReductionOp::getExpectedMaskType() {
-  auto vecType = getVectorType();
+  auto vecType = getSourceVectorType();
   return vecType.cloneWith(std::nullopt,
                            IntegerType::get(vecType.getContext(), /*width=*/1));
 }
@@ -515,7 +515,7 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
 }
 
 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
-  return llvm::to_vector<4>(getVectorType().getShape());
+  return llvm::to_vector<4>(getSourceVectorType().getShape());
 }
 
 namespace {
@@ -530,7 +530,7 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
     if (maskableOp.isMasked())
       return failure();
 
-    auto vectorType = reductionOp.getVectorType();
+    auto vectorType = reductionOp.getSourceVectorType();
     if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
       return failure();
 
@@ -1074,7 +1074,7 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult vector::ExtractElementOp::verify() {
-  VectorType vectorType = getVectorType();
+  VectorType vectorType = getSourceVectorType();
   if (vectorType.getRank() == 0) {
     if (getPosition())
       return emitOpError("expected position to be empty with 0-D vector");
@@ -1167,13 +1167,14 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 
 LogicalResult vector::ExtractOp::verify() {
   auto positionAttr = getPosition().getValue();
-  if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
+  if (positionAttr.size() >
+      static_cast<unsigned>(getSourceVectorType().getRank()))
     return emitOpError(
         "expected position attribute of rank smaller than vector rank");
   for (const auto &en : llvm::enumerate(positionAttr)) {
     auto attr = en.value().dyn_cast<IntegerAttr>();
     if (!attr || attr.getInt() < 0 ||
-        attr.getInt() >= getVectorType().getDimSize(en.index()))
+        attr.getInt() >= getSourceVectorType().getDimSize(en.index()))
       return emitOpError("expected position attribute #")
              << (en.index() + 1)
              << " to be a non-negative integer smaller than the corresponding "
@@ -1314,7 +1315,7 @@ class ExtractFromInsertTransposeChainState {
 
 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
     ExtractOp e)
-    : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
+    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
       extractedRank(extractOp.getPosition().size()) {
   assert(vectorRank >= extractedRank && "extracted pos overflow");
   sentinels.reserve(vectorRank - extractedRank);
@@ -1510,7 +1511,8 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   int64_t stride = 1;
   for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
     strides.push_back(stride);
-    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
+    stride *=
+        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
   }
 
   int64_t position = linearize(extractedPos, strides);
@@ -1552,7 +1554,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
     size_t lastOffset = sliceOffsets.size() - 1;
     if (sliceOffsets.back() != 0 ||
         extractStridedSliceOp.getType().getDimSize(lastOffset) !=
-            extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
+            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
       break;
     sliceOffsets.pop_back();
   }
@@ -1561,8 +1563,8 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
     destinationRank = vecType.getRank();
   // The dimensions of the result need to be untouched by the
   // extractStridedSlice op.
-  if (destinationRank >
-      extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
+  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
+                            sliceOffsets.size())
     return Value();
   auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
   assert(extractedPos.size() >= sliceOffsets.size());
@@ -1827,7 +1829,7 @@ llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
   if (!srcVectorType)
     return {};
   return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
-                                      getVectorType().getShape());
+                                      getResultVectorType().getShape());
 }
 
 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
@@ -1973,8 +1975,8 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
 
 LogicalResult BroadcastOp::verify() {
   std::pair<int, int> mismatchingDims;
-  BroadcastableToResult res =
-      isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims);
+  BroadcastableToResult res = isBroadcastableTo(
+      getSourceType(), getResultVectorType(), &mismatchingDims);
   if (res == BroadcastableToResult::Success)
     return success();
   if (res == BroadcastableToResult::SourceRankHigher)
@@ -1988,11 +1990,11 @@ LogicalResult BroadcastOp::verify() {
 }
 
 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
-  if (getSourceType() == getVectorType())
+  if (getSourceType() == getResultVectorType())
     return getSource();
   if (!adaptor.getSource())
     return {};
-  auto vectorType = getVectorType();
+  auto vectorType = getResultVectorType();
   if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
     return DenseElementsAttr::get(vectorType, adaptor.getSource());
   if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
@@ -2011,8 +2013,9 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
     if (!srcBroadcast)
       return failure();
-    rewriter.replaceOpWithNewOp<BroadcastOp>(
-        broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource());
+    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
+                                             broadcastOp.getResultVectorType(),
+                                             srcBroadcast.getSource());
     return success();
   }
 };
@@ -2035,7 +2038,7 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
 }
 
 LogicalResult ShuffleOp::verify() {
-  VectorType resultType = getVectorType();
+  VectorType resultType = getResultVectorType();
   VectorType v1Type = getV1VectorType();
   VectorType v2Type = getV2VectorType();
   // Verify ranks.
@@ -2143,7 +2146,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
     }
   }
 
-  return DenseElementsAttr::get(getVectorType(), results);
+  return DenseElementsAttr::get(getResultVectorType(), results);
 }
 
 namespace {
@@ -2764,7 +2767,7 @@ LogicalResult OuterProductOp::verify() {
   Type tRHS = getOperandTypeRHS();
   VectorType vLHS = getOperandVectorTypeLHS(),
              vRHS = tRHS.dyn_cast<VectorType>(),
-             vACC = getOperandVectorTypeACC(), vRES = getVectorType();
+             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
 
   if (vLHS.getRank() != 1)
     return emitOpError("expected 1-d vector for operand #1");
@@ -2805,7 +2808,7 @@ LogicalResult OuterProductOp::verify() {
 /// Returns the mask type expected by this operation. Mostly used for
 /// verification purposes. It requires the operation to be vectorized."
 Type OuterProductOp::getExpectedMaskType() {
-  auto vecType = this->getVectorType();
+  auto vecType = this->getResultVectorType();
   return VectorType::get(vecType.getShape(),
                          IntegerType::get(vecType.getContext(), /*width=*/1));
 }
@@ -2913,7 +2916,7 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult ExtractStridedSliceOp::verify() {
-  auto type = getVectorType();
+  auto type = getSourceVectorType();
   auto offsets = getOffsetsAttr();
   auto sizes = getSizesAttr();
   auto strides = getStridesAttr();
@@ -2944,8 +2947,8 @@ LogicalResult ExtractStridedSliceOp::verify() {
                                                     /*halfOpen=*/false)))
     return failure();
 
-  auto resultType =
-      inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
+  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
+                                                  offsets, sizes, strides);
   if (getResult().getType() != resultType)
     return emitOpError("expected result type to be ") << resultType;
 
@@ -2966,7 +2969,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
   ArrayAttr extractSizes = op.getSizes();
   auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
   while (insertOp) {
-    if (op.getVectorType().getRank() !=
+    if (op.getSourceVectorType().getRank() !=
         insertOp.getSourceVectorType().getRank())
       return failure();
     ArrayAttr insertOffsets = insertOp.getOffsets();
@@ -3020,7 +3023,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
 }
 
 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
-  if (getVectorType() == getResult().getType())
+  if (getSourceVectorType() == getResult().getType())
     return getVector();
   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
     return getResult();
@@ -5113,7 +5116,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
   if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
     if (attr.isSplat())
-      return attr.reshape(getResultType());
+      return attr.reshape(getResultVectorType());
 
   // Eliminate identity transpose ops. This happens when the dimensions of the
   // input vector remain in their original order after the transpose operation.
@@ -5131,8 +5134,8 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult vector::TransposeOp::verify() {
-  VectorType vectorType = getVectorType();
-  VectorType resultType = getResultType();
+  VectorType vectorType = getSourceVectorType();
+  VectorType resultType = getResultVectorType();
   int64_t rank = resultType.getRank();
   if (vectorType.getRank() != rank)
     return emitOpError("vector result rank mismatch: ") << rank;
@@ -5156,7 +5159,7 @@ LogicalResult vector::TransposeOp::verify() {
 }
 
 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
-  return llvm::to_vector<4>(getResultType().getShape());
+  return llvm::to_vector<4>(getResultVectorType().getShape());
 }
 
 namespace {
@@ -5215,7 +5218,7 @@ struct FoldTransposedScalarBroadcast final
     auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
     if (!srcVectorType || srcVectorType.getNumElements() == 1) {
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          transposeOp, transposeOp.getResultType(), bcastOp.getSource());
+          transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
       return success();
     }
 
@@ -5235,7 +5238,7 @@ class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
       return failure();
 
     rewriter.replaceOpWithNewOp<vector::SplatOp>(
-        transposeOp, transposeOp.getResultType(), splatOp.getInput());
+        transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
     return success();
   }
 };

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 6005f377c9adf..a242ccab1e98c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -897,7 +897,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
-    VectorType extractSrcType = extractOp.getVectorType();
+    VectorType extractSrcType = extractOp.getSourceVectorType();
     Location loc = extractOp.getLoc();
 
     // "vector.extract %v[] : vector<f32>" is an invalid op.
@@ -930,7 +930,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       SmallVector<size_t> newRetIndices;
       WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
           rewriter, warpOp, {extractOp.getVector()},
-          {extractOp.getVectorType()}, newRetIndices);
+          {extractOp.getSourceVectorType()}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
       Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
       // Extract from distributed vector.
@@ -994,7 +994,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
-    VectorType extractSrcType = extractOp.getVectorType();
+    VectorType extractSrcType = extractOp.getSourceVectorType();
     bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
     Type elType = extractSrcType.getElementType();
     VectorType distributedVecType;

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 4918abb338c35..722db6a1790fb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -48,7 +48,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
     // vector.extract_strided_slice requires the input and output vector to have
     // the same rank. Here we drop leading one dimensions from the input vector
     // type to make sure we don't cause mismatch.
-    VectorType oldSrcType = extractOp.getVectorType();
+    VectorType oldSrcType = extractOp.getSourceVectorType();
     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
 
     if (newSrcType.getRank() == oldSrcType.getRank())

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 357582b4052d9..b5102461e4c24 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -266,7 +266,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
   LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    VectorType dstType = op.getVectorType();
+    VectorType dstType = op.getResultVectorType();
     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
     Type eltType = dstType.getElementType();
 
@@ -404,8 +404,8 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
     auto loc = op.getLoc();
 
     Value input = op.getVector();
-    VectorType inputType = op.getVectorType();
-    VectorType resType = op.getResultType();
+    VectorType inputType = op.getSourceVectorType();
+    VectorType resType = op.getResultVectorType();
 
     // Set up convenience transposition table.
     SmallVector<int64_t> transp;
@@ -492,7 +492,7 @@ class TransposeOp2DToShuffleLowering
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
 
-    VectorType srcType = op.getVectorType();
+    VectorType srcType = op.getSourceVectorType();
     if (srcType.getRank() != 2)
       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
 
@@ -518,8 +518,8 @@ class TransposeOp2DToShuffleLowering
 
     Value shuffled =
         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
-                                                     shuffled);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op.getResultVectorType(), shuffled);
 
     return success();
   }
@@ -552,7 +552,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 
     VectorType lhsType = op.getOperandVectorTypeLHS();
     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
-    VectorType resType = op.getVectorType();
+    VectorType resType = op.getResultVectorType();
     Type eltType = resType.getElementType();
     bool isInt = eltType.isa<IntegerType, IndexType>();
     Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
@@ -1208,15 +1208,16 @@ struct CombineContractBroadcast
         continue;
       // contractionOp can only take vector as operands.
       auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
-      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
+      if (!srcType ||
+          srcType.getRank() == broadcast.getResultVectorType().getRank())
         continue;
       int64_t rankDiff =
-          broadcast.getVectorType().getRank() - srcType.getRank();
+          broadcast.getResultVectorType().getRank() - srcType.getRank();
       bool innerDimBroadcast = false;
       SmallVector<AffineExpr> originalDims;
       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
-        if (dim.value() !=
-            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
+        if (dim.value() != broadcast.getResultVectorType().getDimSize(
+                               rankDiff + dim.index())) {
           innerDimBroadcast = true;
           break;
         }
@@ -1232,7 +1233,7 @@ struct CombineContractBroadcast
       // of non-unit size.
       bool nonUnitDimReductionBroadcast = false;
       for (int64_t i = 0; i < rankDiff; ++i) {
-        if (broadcast.getVectorType().getDimSize(i) != 1 &&
+        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
             isReductionIterator(contractOp.getIteratorTypes()
                                     .getValue()[map.getDimPosition(i)])) {
           nonUnitDimReductionBroadcast = true;
@@ -1243,8 +1244,8 @@ struct CombineContractBroadcast
         continue;
 
       AffineMap broadcastMap =
-          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
-                         contractOp.getContext());
+          AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
+                         originalDims, contractOp.getContext());
       map = broadcastMap.compose(map);
       *operand = broadcast.getSource();
       changed = true;
@@ -1363,7 +1364,7 @@ struct ReorderElementwiseOpsOnTranspose final
       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
       if (transposeOp) {
         transposeMaps.push_back(transposeOp.getTransp());
-        srcType = transposeOp.getVectorType();
+        srcType = transposeOp.getSourceVectorType();
       } else if (!matchPattern(operand, m_Constant())) {
         return failure();
       }
@@ -2376,7 +2377,7 @@ struct BubbleDownVectorBitCastForExtract
   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
     // Only support extracting scalars for now.
-    if (extractOp.getVectorType().getRank() != 1)
+    if (extractOp.getSourceVectorType().getRank() != 1)
       return failure();
 
     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
@@ -2468,7 +2469,7 @@ struct BubbleDownBitCastForStridedSliceExtract
                      [](const APInt &val) { return !val.isOneValue(); }))
       return failure();
 
-    unsigned rank = extractOp.getVectorType().getRank();
+    unsigned rank = extractOp.getSourceVectorType().getRank();
     assert(castDstLastDim % castSrcLastDim == 0);
     int64_t expandRatio = castDstLastDim / castSrcLastDim;
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 0872d10ca204e..b2d00255b0d7f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -602,12 +602,12 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
 
   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
-    if (transposeOp.getResultType().getRank() == 0)
+    if (transposeOp.getResultVectorType().getRank() == 0)
       return failure();
     auto targetShape = getTargetShape(options, transposeOp);
     if (!targetShape)
       return failure();
-    auto originalVectorType = transposeOp.getResultType();
+    auto originalVectorType = transposeOp.getResultVectorType();
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = transposeOp.getLoc();
     ArrayRef<int64_t> originalSize = originalVectorType.getShape();

diff  --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
index 870532983168d..1558000139f20 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -252,7 +252,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 
     // Check if the source vector type is supported. AVX2 patterns can only be
     // applied to f32 vector types with two dimensions greater than one.
-    VectorType srcType = op.getVectorType();
+    VectorType srcType = op.getSourceVectorType();
     if (!srcType.getElementType().isF32())
       return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
 
@@ -287,7 +287,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       // Reshape the n-D input vector with only two dimensions greater than one
       // to a 2-D vector.
       auto flattenedType =
-          VectorType::get({n * m}, op.getVectorType().getElementType());
+          VectorType::get({n * m}, op.getSourceVectorType().getElementType());
       auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
       auto reshInput =
           ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
@@ -315,7 +315,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       // We have to transpose their dimensions and retrieve its original rank
       // (e.g., 1x8x1x4x1).
       res = ib.create<vector::ShapeCastOp>(flattenedType, res);
-      res = ib.create<vector::ShapeCastOp>(op.getResultType(), res);
+      res = ib.create<vector::ShapeCastOp>(op.getResultVectorType(), res);
       rewriter.replaceOp(op, res);
       return success();
     };


        


More information about the Mlir-commits mailing list