[Mlir-commits] [mlir] 45b25d2 - [mlir][Vector] Disable 'vector.extract' folding for unsupported 0-D vectors

Diego Caballero llvmlistbot at llvm.org
Thu Jun 1 15:23:26 PDT 2023


Author: Diego Caballero
Date: 2023-06-01T22:22:15Z
New Revision: 45b25d24f04dba7e3089453774544459f152ef95

URL: https://github.com/llvm/llvm-project/commit/45b25d24f04dba7e3089453774544459f152ef95
DIFF: https://github.com/llvm/llvm-project/commit/45b25d24f04dba7e3089453774544459f152ef95.diff

LOG: [mlir][Vector] Disable 'vector.extract' folding for unsupported 0-D vectors

The `vector.extract` folding patterns do not support 0-D vectors (in
fact, 0-D vector support could not be implemented as a folding pattern
at all, since it would require replacing `vector.extract` with a
`vector.extractelement` op). This patch makes these patterns bail out of
folding when 0-D vectors are found.

Reviewed By: nicolasvasilache, hanchung

Differential Revision: https://reviews.llvm.org/D151847

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 20c088c2acfe1..acccd66f7c03f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1441,11 +1441,28 @@ Value ExtractFromInsertTransposeChainState::fold() {
   return tryToFoldExtractOpInPlace(valueToExtractFrom);
 }
 
+/// Returns true if the operation has a 0-D vector type operand or result.
+static bool hasZeroDimVectors(Operation *op) {
+  auto hasZeroDimVectorType = [](Type type) -> bool {
+    auto vecType = dyn_cast<VectorType>(type);
+    return vecType && vecType.getRank() == 0;
+  };
+
+  return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
+         llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
+}
+
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   Operation *defOp = extractOp.getVector().getDefiningOp();
   if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
+
+  // 0-D vectors not supported.
+  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+  if (hasZeroDimVectors(defOp))
+    return Value();
+
   Value source = defOp->getOperand(0);
   if (extractOp.getType() == source.getType())
     return source;
@@ -1497,6 +1514,12 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
   if (!shapeCastOp)
     return Value();
+
+  // 0-D vectors not supported.
+  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+  if (hasZeroDimVectors(shapeCastOp))
+    return Value();
+
   // Get the nth dimension size starting from lowest dimension.
   auto getDimReverse = [](VectorType type, int64_t n) {
     return type.getShape().take_back(n + 1).front();
@@ -1559,6 +1582,12 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
       extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
   if (!extractStridedSliceOp)
     return Value();
+
+  // 0-D vectors not supported.
+  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+  if (hasZeroDimVectors(extractStridedSliceOp))
+    return Value();
+
   // Return if 'extractStridedSliceOp' has non-unit strides.
   if (extractStridedSliceOp.hasNonUnitStrides())
     return Value();
@@ -1595,18 +1624,27 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
 }
 
 /// Fold extract_op fed from a chain of insertStridedSlice ops.
-static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
-  int64_t destinationRank = llvm::isa<VectorType>(op.getType())
-                                ? llvm::cast<VectorType>(op.getType()).getRank()
-                                : 0;
-  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
+static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
+  int64_t destinationRank =
+      llvm::isa<VectorType>(extractOp.getType())
+          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
+          : 0;
+  auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
+  if (!insertOp)
+    return Value();
+
+  // 0-D vectors not supported.
+  assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
+  if (hasZeroDimVectors(insertOp))
+    return Value();
+
   while (insertOp) {
     int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                              insertOp.getSourceVectorType().getRank();
     if (destinationRank > insertOp.getSourceVectorType().getRank())
       return Value();
     auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
-    auto extractOffsets = extractVector<int64_t>(op.getPosition());
+    auto extractOffsets = extractVector<int64_t>(extractOp.getPosition());
 
     if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
           return llvm::cast<IntegerAttr>(attr).getInt() != 1;
@@ -1643,12 +1681,12 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
                                                     insertRankDiff))
           return Value();
       }
-      op.getVectorMutable().assign(insertOp.getSource());
+      extractOp.getVectorMutable().assign(insertOp.getSource());
       // OpBuilder is only used as a helper to build an I64ArrayAttr.
-      OpBuilder b(op.getContext());
-      op->setAttr(ExtractOp::getPositionAttrStrName(),
-                  b.getI64ArrayAttr(offsetDiffs));
-      return op.getResult();
+      OpBuilder b(extractOp.getContext());
+      extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
+                         b.getI64ArrayAttr(offsetDiffs));
+      return extractOp.getResult();
     }
     // If the chunk extracted is disjoint from the chunk inserted, keep
     // looking in the insert chain.

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 739ab00fa43f9..d715f9acbb3c6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -650,8 +650,7 @@ func.func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
 //       CHECK:   %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
 //       CHECK:   %[[R:.*]] = vector.extract %[[V]][1] : vector<2x4x2xf32>
 //       CHECK:   return %[[R]] : vector<4x2xf32>
-func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>,
-                             %arg1 : vector<8x4x2xf32>) -> vector<4x2xf32> {
+func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>) -> vector<4x2xf32> {
   %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
   %r = vector.extract %0[1] : vector<2x4x2xf32>
   return %r : vector<4x2xf32>
@@ -659,6 +658,18 @@ func.func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>,
 
 // -----
 
+// CHECK-LABEL: dont_fold_0d_extract_shapecast
+//       CHECK:   %[[V:.*]] = vector.shape_cast %{{.*}} : vector<f32> to vector<1xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[V]][0] : vector<1xf32>
+//       CHECK:   return %[[R]] : f32
+func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {
+  %0 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
+  %r = vector.extract %0[0] : vector<1xf32>
+  return %r : f32
+}
+
+// -----
+
 // CHECK-LABEL: dont_fold_expand_collapse
 //       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
 //       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
@@ -2159,4 +2170,3 @@ func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> {
   %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
-


        


More information about the Mlir-commits mailing list