[Mlir-commits] [mlir] 0bd83f9 - Revert "Don't attempt to create vectors with complex element types."

Johannes Reifferscheid llvmlistbot at llvm.org
Thu Jan 12 02:35:17 PST 2023


Author: Johannes Reifferscheid
Date: 2023-01-12T11:35:02+01:00
New Revision: 0bd83f949bb871a40eddc5eae12e561c33762fca

URL: https://github.com/llvm/llvm-project/commit/0bd83f949bb871a40eddc5eae12e561c33762fca
DIFF: https://github.com/llvm/llvm-project/commit/0bd83f949bb871a40eddc5eae12e561c33762fca.diff

LOG: Revert "Don't attempt to create vectors with complex element types."

This reverts commit 91181db6d6fd896f01e1e89786d6d7d3d09a911e.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 90ed3ae229c1d..a17663b101912 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -44,20 +44,6 @@ using namespace mlir::linalg;
 static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
                                                    LinalgOp convOp);
 
-/// Return the given vector type if `elementType` is valid.
-static FailureOr<VectorType> getVectorType(ArrayRef<int64_t> shape,
-                                           Type elementType) {
-  if (!VectorType::isValidElementType(elementType)) {
-    return failure();
-  }
-  return VectorType::get(shape, elementType);
-}
-
-/// Cast the given type to a vector type if its element type is valid.
-static FailureOr<VectorType> getVectorType(ShapedType type) {
-  return getVectorType(type.getShape(), type.getElementType());
-}
-
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
 template <typename OpType>
@@ -445,7 +431,8 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
       vector::BroadcastableToResult::Success)
     return value;
   Location loc = b.getInsertionPoint()->getLoc();
-  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
+                                                    value);
 }
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
@@ -898,20 +885,18 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     }
 
     auto readType =
-        getVectorType(readVecShape, getElementTypeOrSelf(opOperand->get()));
-    if (!succeeded(readType))
-      return failure();
+        VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get()));
     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
 
     Operation *read = rewriter.create<vector::TransferReadOp>(
-        loc, *readType, opOperand->get(), indices, readMap);
+        loc, readType, opOperand->get(), indices, readMap);
     read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
     Value readValue = read->getResult(0);
 
     // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
     // will be in-bounds.
     if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
-      SmallVector<bool> inBounds(readType->getRank(), true);
+      SmallVector<bool> inBounds(readType.getRank(), true);
       cast<vector::TransferReadOp>(maskOp.getMaskableOp())
           .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
     }
@@ -1150,22 +1135,21 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
     return failure();
 
-  auto readType = getVectorType(srcType);
-  auto writeType = getVectorType(dstType);
-  if (!(succeeded(readType) && succeeded(writeType)))
-    return failure();
+  auto readType =
+      VectorType::get(srcType.getShape(), getElementTypeOrSelf(srcType));
+  auto writeType =
+      VectorType::get(dstType.getShape(), getElementTypeOrSelf(dstType));
 
   Location loc = copyOp->getLoc();
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   SmallVector<Value> indices(srcType.getRank(), zero);
 
   Value readValue = rewriter.create<vector::TransferReadOp>(
-      loc, *readType, copyOp.getSource(), indices,
+      loc, readType, copyOp.getSource(), indices,
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (readValue.getType().cast<VectorType>().getRank() == 0) {
     readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
-    readValue =
-        rewriter.create<vector::BroadcastOp>(loc, *writeType, readValue);
+    readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
   }
   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
       loc, readValue, copyOp.getTarget(), indices,
@@ -1216,10 +1200,6 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
     auto sourceType = padOp.getSourceType();
     auto resultType = padOp.getResultType();
 
-    // Complex is not a valid vector element type.
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
-
     // Copy cannot be vectorized if pad value is non-constant and source shape
     // is dynamic. In case of a dynamic source shape, padding must be appended
     // by TransferReadOp, but TransferReadOp supports only constant padding.
@@ -1568,17 +1548,15 @@ struct PadOpVectorizationWithInsertSlicePattern
     if (insertOp.getDest() == padOp.getResult())
       return failure();
 
-    auto vecType = getVectorType(padOp.getType());
-    if (!succeeded(vecType))
-      return failure();
-    unsigned vecRank = vecType->getRank();
+    auto vecType = VectorType::get(padOp.getType().getShape(),
+                                   padOp.getType().getElementType());
+    unsigned vecRank = vecType.getRank();
     unsigned tensorRank = insertOp.getType().getRank();
 
     // Check if sizes match: Insert the entire tensor into most minor dims.
     // (No permutations allowed.)
     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
-    expectedSizes.append(vecType->getShape().begin(),
-                         vecType->getShape().end());
+    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
     if (!llvm::all_of(
             llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
               return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
@@ -1594,7 +1572,7 @@ struct PadOpVectorizationWithInsertSlicePattern
     SmallVector<Value> readIndices(
         vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
     auto read = rewriter.create<vector::TransferReadOp>(
-        padOp.getLoc(), *vecType, padOp.getSource(), readIndices, padValue);
+        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
 
     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
     // specified offsets. Write is fully in-bounds because a InsertSliceOp's
@@ -1902,7 +1880,7 @@ struct Conv1DGenerator
     auto rhsRank = rhsShapedType.getRank();
     switch (oper) {
     case Conv:
-      if (rhsRank != 2 && rhsRank != 3)
+      if (rhsRank != 2 && rhsRank!= 3)
         return;
       break;
     case Pool:
@@ -2004,24 +1982,22 @@ struct Conv1DGenerator
     Type lhsEltType = lhsShapedType.getElementType();
     Type rhsEltType = rhsShapedType.getElementType();
     Type resEltType = resShapedType.getElementType();
-    auto lhsType = getVectorType(lhsShape, lhsEltType);
-    auto rhsType = getVectorType(rhsShape, rhsEltType);
-    auto resType = getVectorType(resShape, resEltType);
-    if (!(succeeded(lhsType) && succeeded(rhsType) && succeeded(resType)))
-      return failure();
+    auto lhsType = VectorType::get(lhsShape, lhsEltType);
+    auto rhsType = VectorType::get(rhsShape, rhsEltType);
+    auto resType = VectorType::get(resShape, resEltType);
     // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, *lhsType, lhsShaped, ValueRange{zero, zero, zero});
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
     // This is needed only for Conv.
     Value rhs = nullptr;
     if (oper == Conv)
       rhs = rewriter.create<vector::TransferReadOp>(
-          loc, *rhsType, rhsShaped, ValueRange{zero, zero, zero});
+          loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
     // Read res slice of size {n, w, f} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
-        loc, *resType, resShaped, ValueRange{zero, zero, zero});
+        loc, resType, resShaped, ValueRange{zero, zero, zero});
 
     // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
     // {n,w,f}. To reuse the base pattern vectorization case, we do pre
@@ -2188,28 +2164,26 @@ struct Conv1DGenerator
     Type lhsEltType = lhsShapedType.getElementType();
     Type rhsEltType = rhsShapedType.getElementType();
     Type resEltType = resShapedType.getElementType();
-    auto lhsType = getVectorType(
+    VectorType lhsType = VectorType::get(
         {nSize,
          // iw = ow * sw + kw *  dw - 1
          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
          cSize},
         lhsEltType);
-    auto rhsType = getVectorType({kwSize, cSize}, rhsEltType);
-    auto resType = getVectorType({nSize, wSize, cSize}, resEltType);
-    if (!(succeeded(lhsType) && succeeded(rhsType) && succeeded(resType)))
-      return failure();
+    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
 
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, *lhsType, lhsShaped, ValueRange{zero, zero, zero});
+        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c} @ [0, 0].
-    Value rhs = rewriter.create<vector::TransferReadOp>(
-        loc, *rhsType, rhsShaped, ValueRange{zero, zero});
+    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                        ValueRange{zero, zero});
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
-        loc, *resType, resShaped, ValueRange{zero, zero, zero});
+        loc, resType, resShaped, ValueRange{zero, zero, zero});
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part

diff  --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index c572b02be4ee6..91d822d804c05 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -850,18 +850,3 @@ func.func @pooling_ncw_sum_memref_2_3_2_1(%input: memref<4x2x5xf32>, %filter: me
 // CHECK: %[[V7:.+]] = arith.addf %[[V5]], %[[V6]] : vector<4x3x2xf32>
 // CHECK: %[[V8:.+]] = vector.transpose %[[V7]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
 // CHECK: vector.transfer_write %[[V8:.+]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
-
-
-// -----
-
-func.func @pooling_ncw_sum_memref_complex(%input: memref<4x2x5xcomplex<f32>>,
-     %filter: memref<2xcomplex<f32>>, %output: memref<4x2x3xcomplex<f32>>) {
-  linalg.pooling_ncw_sum
-    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
-    ins(%input, %filter : memref<4x2x5xcomplex<f32>>, memref<2xcomplex<f32>>)
-    outs(%output : memref<4x2x3xcomplex<f32>>)
-  return
-}
-
-// Regression test: just check that this lowers successfully
-// CHECK-LABEL: @pooling_ncw_sum_memref_complex


        


More information about the Mlir-commits mailing list