[Mlir-commits] [mlir] 92e83af - [mlir][vector] Fold extractOp coming from broadcastOp

Thomas Raoux llvmlistbot at llvm.org
Tue Oct 6 10:28:10 PDT 2020


Author: Thomas Raoux
Date: 2020-10-06T10:27:39-07:00
New Revision: 92e83afe44fbfd81ffd428bb41b7f760eee712f9

URL: https://github.com/llvm/llvm-project/commit/92e83afe44fbfd81ffd428bb41b7f760eee712f9
DIFF: https://github.com/llvm/llvm-project/commit/92e83afe44fbfd81ffd428bb41b7f760eee712f9.diff

LOG: [mlir][vector] Fold extractOp coming from broadcastOp

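Fold an ExtractOp whose vector operand comes from a BroadcastOp. When the extract result has the broadcast source type, it is replaced by the source. When the result has a lower rank than the source, the extract is rewritten to read from the source directly. This is useful to incrementally convert degenerate one-element vectors into scalars.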

Differential Revision: https://reviews.llvm.org/D88751

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 672ad4058309a..b71102cde1cf6 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -812,6 +812,37 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
   return Value();
 }
 
+/// Fold an ExtractOp whose vector operand is produced by a BroadcastOp.
+///
+/// * If the extract result type equals the broadcast source type, the
+///   broadcast source is returned directly.
+/// * If the extract result has a lower rank than the broadcast source, the
+///   ExtractOp is updated in place to read from the broadcast source: the
+///   leading positions that index dimensions added by the broadcast are
+///   dropped, and positions that index source unit dimensions stretched by
+///   the broadcast are reset to 0. This is only done when the trailing source
+///   dimensions forming the extract result were not stretched; otherwise the
+///   rewritten extract would produce a different result type.
+/// Returns a null Value when no folding applies.
+static Value foldExtractFromBroadcast(ExtractOp extractOp) {
+  auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
+  if (!broadcastOp)
+    return Value();
+  Type sourceType = broadcastOp.getSourceType();
+  if (extractOp.getType() == sourceType)
+    return broadcastOp.source();
+  auto getRank = [](Type type) -> unsigned {
+    auto vecType = type.dyn_cast<VectorType>();
+    return vecType ? vecType.getRank() : 0;
+  };
+  unsigned broadcastSrcRank = getRank(sourceType);
+  unsigned extractResultRank = getRank(extractOp.getType());
+  // TODO: In case the rank of the broadcast source is not greater than the
+  // rank of the extract result (or a result dimension was stretched, see
+  // below), this can be combined into a new broadcast op. This needs to be
+  // added as a canonicalization pattern if needed.
+  if (extractResultRank >= broadcastSrcRank)
+    return Value();
+
+  // broadcastSrcRank > extractResultRank >= 0, so the source is a vector.
+  ArrayRef<int64_t> srcShape = sourceType.cast<VectorType>().getShape();
+  // The dimensions forming the extract result must come unchanged from the
+  // source; if the broadcast stretched any of them, reading from the source
+  // would yield a different type.
+  auto resultVecType = extractOp.getType().dyn_cast<VectorType>();
+  if (resultVecType &&
+      resultVecType.getShape() != srcShape.take_back(extractResultRank))
+    return Value();
+
+  auto extractPos = extractVector<int64_t>(extractOp.position());
+  ArrayRef<int64_t> dstShape = extractOp.getVectorType().getShape();
+  // Number of leading result dimensions created by the broadcast.
+  unsigned leadingDims = dstShape.size() - broadcastSrcRank;
+  // A position indexing a unit dimension of the source (possibly stretched by
+  // the broadcast) must read element 0 of the source.
+  for (unsigned i = leadingDims, e = extractPos.size(); i < e; ++i)
+    if (srcShape[i - leadingDims] == 1)
+      extractPos[i] = 0;
+  // The leading positions index dimensions that do not exist in the source.
+  extractPos.erase(extractPos.begin(),
+                   std::next(extractPos.begin(), leadingDims));
+  extractOp.setOperand(broadcastOp.source());
+  // OpBuilder is only used as a helper to build an I64ArrayAttr.
+  OpBuilder b(extractOp.getContext());
+  extractOp.setAttr(ExtractOp::getPositionAttrName(),
+                    b.getI64ArrayAttr(extractPos));
+  return extractOp.getResult();
+}
+
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
@@ -819,6 +850,8 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
     return getResult();
   if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
     return val;
+  if (auto val = foldExtractFromBroadcast(*this))
+    return val;
   return OpFoldResult();
 }
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9c36f7684baf9..2f927a1bbc810 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -348,6 +348,54 @@ func @fold_extract_transpose(
 
 // -----
 
+// Extracting a scalar from a broadcast of that same scalar type folds to the
+// broadcast source operand.
+// CHECK-LABEL: fold_extract_broadcast
+//  CHECK-SAME:   %[[A:.*]]: f32
+//       CHECK:   return %[[A]] : f32
+func @fold_extract_broadcast(%a : f32) -> f32 {
+  %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
+// Extracting a vector whose type equals the broadcast source type folds to
+// the broadcast source operand.
+// CHECK-LABEL: fold_extract_broadcast_vector
+//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
+//       CHECK:   return %[[A]] : vector<4xf32>
+func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1] : vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// Extracting a scalar from a broadcast of a lower-rank vector is rewritten
+// into an extract from the broadcast source, keeping only the positions that
+// index source dimensions. Uses a unique function name / label so it cannot be
+// confused with @fold_extract_broadcast.
+// CHECK-LABEL: fold_extract_broadcast_scalar_from_vector
+//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[A]][2] : vector<4xf32>
+//       CHECK:   return %[[R]] : f32
+func @fold_extract_broadcast_scalar_from_vector(%a : vector<4xf32>) -> f32 {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
+// Negative test: the broadcast source (f32, rank 0) does not have a greater
+// rank than the vector.extract result (vector<4xf32>, rank 1), so
+// extract(broadcast) is left unfolded (it could only become a new broadcast).
+// CHECK-LABEL: fold_extract_broadcast_negative
+//       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<1x2x4xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[B]][0, 1] : vector<1x2x4xf32>
+//       CHECK:   return %[[R]] : vector<4xf32>
+func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
+  %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1] : vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = constant 0 : index


        


More information about the Mlir-commits mailing list