[Mlir-commits] [mlir] c1a4cd5 - [mlir][linalg] refactor the result handling during vectorization.

Tobias Gysi llvmlistbot at llvm.org
Mon Mar 8 23:27:07 PST 2021


Author: Tobias Gysi
Date: 2021-03-09T07:11:57Z
New Revision: c1a4cd551f1c577008c33d78972929ba6593efcc

URL: https://github.com/llvm/llvm-project/commit/c1a4cd551f1c577008c33d78972929ba6593efcc
DIFF: https://github.com/llvm/llvm-project/commit/c1a4cd551f1c577008c33d78972929ba6593efcc.diff

LOG: [mlir][linalg] refactor the result handling during vectorization.

Return the vectorization results through a vector passed by reference instead of wrapping them in a returned structure.

Differential Revision: https://reviews.llvm.org/D98182

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5c0d1dc3a2fa..8f422d284df6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -263,12 +263,8 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                    OperationFolder *folder = nullptr);
 
 /// Emit a suitable vector form for a Linalg op with fully static shape.
-struct VectorizedLinalgOp {
-  SmallVector<Value> tensorResults;
-  VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default;
-};
-Optional<VectorizedLinalgOp> vectorizeLinalgOp(OpBuilder &builder,
-                                               Operation *op);
+LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
+                                SmallVectorImpl<Value> &newResults);
 
 /// Emits a loop nest of `LoopTy` with the proper body for `op`.
 template <typename LoopTy>

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dd92ccd838cd..7f604807030d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -468,11 +468,11 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
-  Optional<VectorizedLinalgOp> res = vectorizeLinalgOp(rewriter, op);
-  if (!res)
+  SmallVector<Value> newResults;
+  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
     return failure();
-  if (!res->tensorResults.empty())
-    rewriter.replaceOp(op, res->tensorResults);
+  if (!newResults.empty())
+    rewriter.replaceOp(op, newResults);
   else
     rewriter.eraseOp(op);
   return success();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f471ab0ebd75..48b6165d7b68 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -139,16 +139,16 @@ using CustomVectorizationHook = std::function<VectorizationResult(
     Operation *, const BlockAndValueMapping &)>;
 
 /// Helper function to vectorize the terminator of a `linalgOp`. New result
-/// vector values are appended to `results`.
-/// Return VectorizationStatus::NoReplace to signal the vectorization algorithm
-/// that it should not try to map produced operations: this is the purpose of
-/// the `results` argument to capture such values and make them available for
-/// RAUW to the vectorization algorithm.
-/// This function is meant to be used as a CustomVectorizationHook.
+/// vector values are appended to `newResults`. Return
+/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it
+/// should not try to map produced operations and instead return the results
+/// using the `newResults` vector making them available to the
+/// vectorization algorithm for RAUW. This function is meant to be used as a
+/// CustomVectorizationHook.
 static VectorizationResult
 vectorizeLinalgYield(OpBuilder &builder, Operation *op,
                      const BlockAndValueMapping &bvm, LinalgOp linalgOp,
-                     SmallVectorImpl<Value> &results) {
+                     SmallVectorImpl<Value> &newResults) {
   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
   if (!yieldOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
@@ -156,10 +156,10 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op,
     // TODO: Scan for an opportunity for reuse.
     // TODO: use a map.
     Value vectorValue = bvm.lookup(outputs.value());
-    Value result = buildVectorWrite(builder, vectorValue,
-                                    linalgOp.getOutput(outputs.index()));
-    if (result)
-      results.push_back(result);
+    Value newResult = buildVectorWrite(builder, vectorValue,
+                                       linalgOp.getOutput(outputs.index()));
+    if (newResult)
+      newResults.push_back(newResult);
   }
   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
 }
@@ -248,8 +248,8 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
 ///   TODO: Reuse opportunities for RAR dependencies.
 ///   4. Register CustomVectorizationHook for YieldOp to capture the results.
 ///   5. Iteratively call vectorizeOneOp on the region operations.
-static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
-    OpBuilder &builder, LinalgOp linalgOp,
+LogicalResult vectorizeAsLinalgGeneric(
+    OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
     ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
   // 1. Certain Linalg ops do not have a region but only a region builder.
   // If so, build the region so we can vectorize.
@@ -290,11 +290,10 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
   }
 
   // 4. Register CustomVectorizationHook for yieldOp.
-  SmallVector<Value> results;
   CustomVectorizationHook vectorizeYield =
       [&](Operation *op,
           const BlockAndValueMapping &bvm) -> VectorizationResult {
-    return vectorizeLinalgYield(builder, op, bvm, linalgOp, results);
+    return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
   };
   // Append the vectorizeYield hook.
   auto hooks = llvm::to_vector<4>(customVectorizationHooks);
@@ -305,7 +304,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
     VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
-      return llvm::None;
+      return failure();
     }
     if (result.status == VectorizationStatus::NewOp) {
       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
@@ -314,7 +313,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
     }
   }
 
-  return VectorizedLinalgOp{{results}};
+  return success();
 }
 
 /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -355,8 +354,8 @@ static bool isElementwise(Operation *op) {
   return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
 }
 
-static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
-                                                         LinalgOp linalgOp) {
+static LogicalResult vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp,
+                                          SmallVectorImpl<Value> &newResults) {
   assert(isaContractionOpInterface(linalgOp) &&
          "expected vectorizeContraction preconditions to be met");
   Location loc = linalgOp.getLoc();
@@ -383,7 +382,8 @@ static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
         linalgOp.indexing_maps(), linalgOp.iterator_types());
     return VectorizationResult{VectorizationStatus::NewOp, contract};
   };
-  return vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
+  return vectorizeAsLinalgGeneric(builder, linalgOp, newResults,
+                                  {vectorizeContraction});
 }
 
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@@ -400,19 +400,20 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   return success(isaContractionOpInterface(linalgOp));
 }
 
-Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
-                                                             Operation *op) {
+LogicalResult
+mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op,
+                                SmallVectorImpl<Value> &newResults) {
   if (failed(vectorizeLinalgOpPrecondition(op)))
-    return llvm::None;
+    return failure();
 
   edsc::ScopedContext scope(builder, op->getLoc());
   if (isElementwise(op)) {
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                       << "Vectorize linalg op as a generic: " << *op);
-    return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
+    return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op), newResults);
   }
 
-  return vectorizeContraction(builder, cast<LinalgOp>(op));
+  return vectorizeContraction(builder, cast<LinalgOp>(op), newResults);
 }
 
 //----------------------------------------------------------------------------//


        


More information about the Mlir-commits mailing list