[Mlir-commits] [mlir] 45b25d2 - [mlir][Vector] Disable 'vector.extract' folding for unsupported 0-D vectors
Diego Caballero
llvmlistbot at llvm.org
Thu Jun 1 15:23:26 PDT 2023
Author: Diego Caballero
Date: 2023-06-01T22:22:15Z
New Revision: 45b25d24f04dba7e3089453774544459f152ef95
URL: https://github.com/llvm/llvm-project/commit/45b25d24f04dba7e3089453774544459f152ef95
DIFF: https://github.com/llvm/llvm-project/commit/45b25d24f04dba7e3089453774544459f152ef95.diff
LOG: [mlir][Vector] Disable 'vector.extract' folding for unsupported 0-D vectors
The `vector.extract` folding patterns do not support 0-D vectors.
(In fact, 0-D vector support could not be implemented as a folding
pattern at all, because it would require replacing `vector.extract` with a
`vector.extractelement` op.) This patch makes folding bail out when 0-D
vectors are found.
Reviewed By: nicolasvasilache, hanchung
Differential Revision: https://reviews.llvm.org/D151847
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 20c088c2acfe1..acccd66f7c03f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1441,11 +1441,28 @@ Value ExtractFromInsertTransposeChainState::fold() {
return tryToFoldExtractOpInPlace(valueToExtractFrom);
}
+/// Returns true if any operand or result of `op` has a rank-0 (0-D) vector
+/// type. Non-vector types never count as 0-D here.
+static bool hasZeroDimVectors(Operation *op) {
+  auto isZeroDim = [](Type ty) -> bool {
+    if (auto vecTy = dyn_cast<VectorType>(ty))
+      return vecTy.getRank() == 0;
+    return false;
+  };
+
+  // Operands are checked first; results are only inspected if no operand hit.
+  if (llvm::any_of(op->getOperandTypes(), isZeroDim))
+    return true;
+  return llvm::any_of(op->getResultTypes(), isZeroDim);
+}
+
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return Value();
+
+ // 0-D vectors not supported.
+ assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+ if (hasZeroDimVectors(defOp))
+ return Value();
+
Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
return source;
@@ -1497,6 +1514,12 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
if (!shapeCastOp)
return Value();
+
+ // 0-D vectors not supported.
+ assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+ if (hasZeroDimVectors(shapeCastOp))
+ return Value();
+
// Get the nth dimension size starting from lowest dimension.
auto getDimReverse = [](VectorType type, int64_t n) {
return type.getShape().take_back(n + 1).front();
@@ -1559,6 +1582,12 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
if (!extractStridedSliceOp)
return Value();
+
+ // 0-D vectors not supported.
+ assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+ if (hasZeroDimVectors(extractStridedSliceOp))
+ return Value();
+
// Return if 'extractStridedSliceOp' has non-unit strides.
if (extractStridedSliceOp.hasNonUnitStrides())
return Value();
@@ -1595,18 +1624,27 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
}
/// Fold extract_op fed from a chain of insertStridedSlice ops.
-static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
- int64_t destinationRank = llvm::isa<VectorType>(op.getType())
- ? llvm::cast<VectorType>(op.getType()).getRank()
- : 0;
- auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
+static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
+ int64_t destinationRank =
+ llvm::isa<VectorType>(extractOp.getType())
+ ? llvm::cast<VectorType>(extractOp.getType()).getRank()
+ : 0;
+ auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
+ if (!insertOp)
+ return Value();
+
+ // 0-D vectors not supported.
+ assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+ if (hasZeroDimVectors(insertOp))
+ return Value();
+
while (insertOp) {
int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
insertOp.getSourceVectorType().getRank();
if (destinationRank > insertOp.getSourceVectorType().getRank())
return Value();
auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
- auto extractOffsets = extractVector<int64_t>(op.getPosition());
+ auto extractOffsets = extractVector<int64_t>(extractOp.getPosition());
if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
return llvm::cast<IntegerAttr>(attr).getInt() != 1;
@@ -1643,12 +1681,12 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
insertRankDiff))
return Value();
}
- op.getVectorMutable().assign(insertOp.getSource());
+ extractOp.getVectorMutable().assign(insertOp.getSource());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
- OpBuilder b(op.getContext());
- op->setAttr(ExtractOp::getPositionAttrStrName(),
- b.getI64ArrayAttr(offsetDiffs));
- return op.getResult();
+ OpBuilder b(extractOp.getContext());
+ extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
+ b.getI64ArrayAttr(offsetDiffs));
+ return extractOp.getResult();
}
// If the chunk extracted is disjoint from the chunk inserted, keep
// looking in the insert chain.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 739ab00fa43f9..d715f9acbb3c6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -650,8 +650,7 @@ func.func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
// CHECK: %[[R:.*]] = vector.extract %[[V]][1] : vector<2x4x2xf32>
// CHECK: return %[[R]] : vector<4x2xf32>
-func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>,
- %arg1 : vector<8x4x2xf32>) -> vector<4x2xf32> {
+func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
%r = vector.extract %0[1] : vector<2x4x2xf32>
return %r : vector<4x2xf32>
@@ -659,6 +658,18 @@ func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>,
// -----
+// The shape_cast below has a 0-D vector source (vector<f32>). The
+// extract-from-shape_cast fold therefore bails out, and both ops must
+// survive canonicalization unchanged.
+// CHECK-LABEL: dont_fold_0d_extract_shapecast
+// CHECK: %[[V:.*]] = vector.shape_cast %{{.*}} : vector<f32> to vector<1xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[V]][0] : vector<1xf32>
+// CHECK: return %[[R]] : f32
+func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
+ %0 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
+ %r = vector.extract %0[0] : vector<1xf32>
+ return %r : f32
+}
+
+// -----
+
// CHECK-LABEL: dont_fold_expand_collapse
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
@@ -2159,4 +2170,3 @@ func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> {
%0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
return %0 : vector<3x4xf32>
}
-
More information about the Mlir-commits
mailing list