[Mlir-commits] [mlir] b4bcef0 - [mlir][vector] Fix bug in extractFromBroadcast folding

Thomas Raoux llvmlistbot at llvm.org
Fri Apr 15 12:21:51 PDT 2022


Author: Thomas Raoux
Date: 2022-04-15T19:21:45Z
New Revision: b4bcef05b7eff074e2db89bbeb856e344f29d45d

URL: https://github.com/llvm/llvm-project/commit/b4bcef05b7eff074e2db89bbeb856e344f29d45d
DIFF: https://github.com/llvm/llvm-project/commit/b4bcef05b7eff074e2db89bbeb856e344f29d45d.diff

LOG: [mlir][vector] Fix bug in extractFromBroadcast folding

A vector.extract was incorrectly folded when its source came from a
broadcast that both increased the rank and broadcast the inner
dimension.

Differential Revision: https://reviews.llvm.org/D123867

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index af174601da570..68cce9b89917d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1292,20 +1292,25 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   };
   unsigned broadcastSrcRank = getRank(source.getType());
   unsigned extractResultRank = getRank(extractOp.getType());
-  if (extractResultRank < broadcastSrcRank) {
-    auto extractPos = extractVector<int64_t>(extractOp.getPosition());
-    unsigned rankDiff = broadcastSrcRank - extractResultRank;
-    extractPos.erase(
-        extractPos.begin(),
-        std::next(extractPos.begin(), extractPos.size() - rankDiff));
-    extractOp.setOperand(source);
-    // OpBuilder is only used as a helper to build an I64ArrayAttr.
-    OpBuilder b(extractOp.getContext());
-    extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
-                       b.getI64ArrayAttr(extractPos));
-    return extractOp.getResult();
-  }
-  return Value();
+  if (extractResultRank >= broadcastSrcRank)
+    return Value();
+  // Check that the dimension of the result haven't been broadcasted.
+  auto extractVecType = extractOp.getType().dyn_cast<VectorType>();
+  auto broadcastVecType = source.getType().dyn_cast<VectorType>();
+  if (extractVecType && broadcastVecType &&
+      extractVecType.getShape() !=
+          broadcastVecType.getShape().take_back(extractResultRank))
+    return Value();
+  auto extractPos = extractVector<int64_t>(extractOp.getPosition());
+  unsigned rankDiff = broadcastSrcRank - extractResultRank;
+  extractPos.erase(extractPos.begin(),
+                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
+  extractOp.setOperand(source);
+  // OpBuilder is only used as a helper to build an I64ArrayAttr.
+  OpBuilder b(extractOp.getContext());
+  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
+                     b.getI64ArrayAttr(extractPos));
+  return extractOp.getResult();
 }
 
 // Fold extractOp with source coming from ShapeCast op.

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 608cb026c43ea..37c660e6b3224 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -521,6 +521,17 @@ func @fold_extract_broadcast(%a : f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast_negative
+//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
+//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
+func @fold_extract_broadcast_negative(%a : vector<1x1xf32>) -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
+  %r = vector.extract %b[0, 0] : vector<1x1x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_splat
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32


        


More information about the Mlir-commits mailing list