[Mlir-commits] [mlir] 0bd83f9 - Revert "Don't attempt to create vectors with complex element types."
Johannes Reifferscheid
llvmlistbot at llvm.org
Thu Jan 12 02:35:17 PST 2023
Author: Johannes Reifferscheid
Date: 2023-01-12T11:35:02+01:00
New Revision: 0bd83f949bb871a40eddc5eae12e561c33762fca
URL: https://github.com/llvm/llvm-project/commit/0bd83f949bb871a40eddc5eae12e561c33762fca
DIFF: https://github.com/llvm/llvm-project/commit/0bd83f949bb871a40eddc5eae12e561c33762fca.diff
LOG: Revert "Don't attempt to create vectors with complex element types."
This reverts commit 91181db6d6fd896f01e1e89786d6d7d3d09a911e.
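For context: the reverted commit guarded VectorType construction behind VectorType::isValidElementType, because complex types (e.g. complex<f32>) are not valid vector element types in MLIR. A minimal standalone sketch of that check — the includes and main() harness here are illustrative additions; only VectorType::isValidElementType and VectorType::get appear in the hunks below:

    #include <cassert>
    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"

    using namespace mlir;

    int main() {
      MLIRContext ctx;
      Type f32 = Float32Type::get(&ctx);
      Type cf32 = ComplexType::get(f32);
      // f32 is a valid vector element type; complex<f32> is not, which is
      // why the reverted change returned failure() before VectorType::get.
      assert(VectorType::isValidElementType(f32));
      assert(!VectorType::isValidElementType(cf32));
      return 0;
    }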
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 90ed3ae229c1d..a17663b101912 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -44,20 +44,6 @@ using namespace mlir::linalg;
static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
LinalgOp convOp);
-/// Return the given vector type if `elementType` is valid.
-static FailureOr<VectorType> getVectorType(ArrayRef<int64_t> shape,
- Type elementType) {
- if (!VectorType::isValidElementType(elementType)) {
- return failure();
- }
- return VectorType::get(shape, elementType);
-}
-
-/// Cast the given type to a vector type if its element type is valid.
-static FailureOr<VectorType> getVectorType(ShapedType type) {
- return getVectorType(type.getShape(), type.getElementType());
-}
-
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than one instance exists.
template <typename OpType>
@@ -445,7 +431,8 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
vector::BroadcastableToResult::Success)
return value;
Location loc = b.getInsertionPoint()->getLoc();
- return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
+ return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
+ value);
}
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
@@ -898,20 +885,18 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
}
auto readType =
- getVectorType(readVecShape, getElementTypeOrSelf(opOperand->get()));
- if (!succeeded(readType))
- return failure();
+ VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get()));
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
Operation *read = rewriter.create<vector::TransferReadOp>(
- loc, *readType, opOperand->get(), indices, readMap);
+ loc, readType, opOperand->get(), indices, readMap);
read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
Value readValue = read->getResult(0);
// 3.b. If masked, set in-bounds to true. Masking guarantees that the access
// will be in-bounds.
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
- SmallVector<bool> inBounds(readType->getRank(), true);
+ SmallVector<bool> inBounds(readType.getRank(), true);
cast<vector::TransferReadOp>(maskOp.getMaskableOp())
.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
}
@@ -1150,22 +1135,21 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
- auto readType = getVectorType(srcType);
- auto writeType = getVectorType(dstType);
- if (!(succeeded(readType) && succeeded(writeType)))
- return failure();
+ auto readType =
+ VectorType::get(srcType.getShape(), getElementTypeOrSelf(srcType));
+ auto writeType =
+ VectorType::get(dstType.getShape(), getElementTypeOrSelf(dstType));
Location loc = copyOp->getLoc();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> indices(srcType.getRank(), zero);
Value readValue = rewriter.create<vector::TransferReadOp>(
- loc, *readType, copyOp.getSource(), indices,
+ loc, readType, copyOp.getSource(), indices,
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (readValue.getType().cast<VectorType>().getRank() == 0) {
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
- readValue =
- rewriter.create<vector::BroadcastOp>(loc, *writeType, readValue);
+ readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
}
Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
loc, readValue, copyOp.getTarget(), indices,
@@ -1216,10 +1200,6 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
auto sourceType = padOp.getSourceType();
auto resultType = padOp.getResultType();
- // Complex is not a valid vector element type.
- if (!VectorType::isValidElementType(sourceType.getElementType()))
- return failure();
-
// Copy cannot be vectorized if pad value is non-constant and source shape
// is dynamic. In case of a dynamic source shape, padding must be appended
// by TransferReadOp, but TransferReadOp supports only constant padding.
@@ -1568,17 +1548,15 @@ struct PadOpVectorizationWithInsertSlicePattern
if (insertOp.getDest() == padOp.getResult())
return failure();
- auto vecType = getVectorType(padOp.getType());
- if (!succeeded(vecType))
- return failure();
- unsigned vecRank = vecType->getRank();
+ auto vecType = VectorType::get(padOp.getType().getShape(),
+ padOp.getType().getElementType());
+ unsigned vecRank = vecType.getRank();
unsigned tensorRank = insertOp.getType().getRank();
// Check if sizes match: Insert the entire tensor into most minor dims.
// (No permutations allowed.)
SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
- expectedSizes.append(vecType->getShape().begin(),
- vecType->getShape().end());
+ expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
if (!llvm::all_of(
llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
@@ -1594,7 +1572,7 @@ struct PadOpVectorizationWithInsertSlicePattern
SmallVector<Value> readIndices(
vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
auto read = rewriter.create<vector::TransferReadOp>(
- padOp.getLoc(), *vecType, padOp.getSource(), readIndices, padValue);
+ padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
// Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
// specified offsets. Write is fully in-bounds because an InsertSliceOp's
@@ -1902,7 +1880,7 @@ struct Conv1DGenerator
auto rhsRank = rhsShapedType.getRank();
switch (oper) {
case Conv:
- if (rhsRank != 2 && rhsRank != 3)
+ if (rhsRank != 2 && rhsRank != 3)
return;
break;
case Pool:
@@ -2004,24 +1982,22 @@ struct Conv1DGenerator
Type lhsEltType = lhsShapedType.getElementType();
Type rhsEltType = rhsShapedType.getElementType();
Type resEltType = resShapedType.getElementType();
- auto lhsType = getVectorType(lhsShape, lhsEltType);
- auto rhsType = getVectorType(rhsShape, rhsEltType);
- auto resType = getVectorType(resShape, resEltType);
- if (!(succeeded(lhsType) && succeeded(rhsType) && succeeded(resType)))
- return failure();
+ auto lhsType = VectorType::get(lhsShape, lhsEltType);
+ auto rhsType = VectorType::get(rhsShape, rhsEltType);
+ auto resType = VectorType::get(resShape, resEltType);
// Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
// 0].
Value lhs = rewriter.create<vector::TransferReadOp>(
- loc, *lhsType, lhsShaped, ValueRange{zero, zero, zero});
+ loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
// Read rhs slice of size {kw, c, f} @ [0, 0, 0].
// This is needed only for Conv.
Value rhs = nullptr;
if (oper == Conv)
rhs = rewriter.create<vector::TransferReadOp>(
- loc, *rhsType, rhsShaped, ValueRange{zero, zero, zero});
+ loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
// Read res slice of size {n, w, f} @ [0, 0, 0].
Value res = rewriter.create<vector::TransferReadOp>(
- loc, *resType, resShaped, ValueRange{zero, zero, zero});
+ loc, resType, resShaped, ValueRange{zero, zero, zero});
// The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
// {n,w,f}. To reuse the base pattern vectorization case, we do pre
@@ -2188,28 +2164,26 @@ struct Conv1DGenerator
Type lhsEltType = lhsShapedType.getElementType();
Type rhsEltType = rhsShapedType.getElementType();
Type resEltType = resShapedType.getElementType();
- auto lhsType = getVectorType(
+ VectorType lhsType = VectorType::get(
{nSize,
// iw = ow * sw + kw * dw - 1
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
cSize},
lhsEltType);
- auto rhsType = getVectorType({kwSize, cSize}, rhsEltType);
- auto resType = getVectorType({nSize, wSize, cSize}, resEltType);
- if (!(succeeded(lhsType) && succeeded(rhsType) && succeeded(resType)))
- return failure();
+ VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+ VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
// 0].
Value lhs = rewriter.create<vector::TransferReadOp>(
- loc, *lhsType, lhsShaped, ValueRange{zero, zero, zero});
+ loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
// Read rhs slice of size {kw, c} @ [0, 0].
- Value rhs = rewriter.create<vector::TransferReadOp>(
- loc, *rhsType, rhsShaped, ValueRange{zero, zero});
+ Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+ ValueRange{zero, zero});
// Read res slice of size {n, w, c} @ [0, 0, 0].
Value res = rewriter.create<vector::TransferReadOp>(
- loc, *resType, resShaped, ValueRange{zero, zero, zero});
+ loc, resType, resShaped, ValueRange{zero, zero, zero});
//===------------------------------------------------------------------===//
// Begin vector-only rewrite part
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
index c572b02be4ee6..91d822d804c05 100644
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -850,18 +850,3 @@ func.func @pooling_ncw_sum_memref_2_3_2_1(%input: memref<4x2x5xf32>, %filter: me
// CHECK: %[[V7:.+]] = arith.addf %[[V5]], %[[V6]] : vector<4x3x2xf32>
// CHECK: %[[V8:.+]] = vector.transpose %[[V7]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
// CHECK: vector.transfer_write %[[V8:.+]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
-
-
-// -----
-
-func.func @pooling_ncw_sum_memref_complex(%input: memref<4x2x5xcomplex<f32>>,
- %filter: memref<2xcomplex<f32>>, %output: memref<4x2x3xcomplex<f32>>) {
- linalg.pooling_ncw_sum
- {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
- ins(%input, %filter : memref<4x2x5xcomplex<f32>>, memref<2xcomplex<f32>>)
- outs(%output : memref<4x2x3xcomplex<f32>>)
- return
-}
-
-// Regression test: just check that this lowers successfully
-// CHECK-LABEL: @pooling_ncw_sum_memref_complex
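As a side note on the depthwise hunk above: the restored lhsType width follows the formula in its source comment. A quick compile-time check of the "16 convolved with 3 (@stride 1 dilation 1) -> 14" example from that comment — variable names mirror the source, the snippet itself is illustrative:

    // 14 outputs with a size-3 kernel at stride 1, dilation 1 need 16 inputs.
    constexpr int wSize = 14, kwSize = 3, strideW = 1, dilationW = 1;
    constexpr int iw =
        ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1;
    static_assert(iw == 16, "iw = ow * sw + kw * dw - 1 for sw = dw = 1");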