[Mlir-commits] [mlir] 0fcbbde - [mlir][Linalg] NFC - Refactor vectorization to be more composable

Nicolas Vasilache llvmlistbot at llvm.org
Fri Feb 5 04:03:57 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-05T12:03:14Z
New Revision: 0fcbbde2c7b02b89503d5d1b631229d64eab7104

URL: https://github.com/llvm/llvm-project/commit/0fcbbde2c7b02b89503d5d1b631229d64eab7104
DIFF: https://github.com/llvm/llvm-project/commit/0fcbbde2c7b02b89503d5d1b631229d64eab7104.diff

LOG: [mlir][Linalg] NFC - Refactor vectorization to be more composable

Differential Revision: https://reviews.llvm.org/D96116

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 16203e5459b9..942581b4bbaf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -31,13 +31,6 @@ struct LinalgTilingOptions;
 //===----------------------------------------------------------------------===//
 using LinalgLoops = SmallVector<Operation *, 4>;
 
-struct TiledLinalgOp {
-  LinalgOp op;
-  SmallVector<Operation *, 8> loops;
-  SmallVector<Value, 4> tensorResults;
-  TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
-};
-
 /// Populates patterns for vectorization of all ConvN-D ops.
 void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@@ -63,6 +56,12 @@ void populateLinalgBufferizePatterns(MLIRContext *context,
 /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
 /// integers, in the range 0..`tileSizes.size()` without duplications
 /// (i.e. `[1,1,2]` is an invalid permutation).
+struct TiledLinalgOp {
+  LinalgOp op;
+  SmallVector<Operation *, 8> loops;
+  SmallVector<Value, 4> tensorResults;
+  TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
+};
 Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
                                      const LinalgTilingOptions &options);
 
@@ -264,7 +263,12 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                    OperationFolder *folder = nullptr);
 
 /// Emit a suitable vector form for a Linalg op with fully static shape.
-void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
+struct VectorizedLinalgOp {
+  SmallVector<Value> tensorResults;
+  VectorizedLinalgOp &operator=(const VectorizedLinalgOp &) = default;
+};
+Optional<VectorizedLinalgOp> vectorizeLinalgOp(OpBuilder &builder,
+                                               Operation *op);
 
 /// Emits a loop nest of `LoopTy` with the proper body for `op`.
 template <typename LoopTy>

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8dac82a57de5..b80b6fb090e7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -468,10 +468,13 @@ LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
-  if (failed(vectorizeLinalgOpPrecondition(op)))
+  Optional<VectorizedLinalgOp> res = vectorizeLinalgOp(rewriter, op);
+  if (!res)
     return failure();
-  vectorizeLinalgOp(rewriter, op);
-  rewriter.eraseOp(op);
+  if (!res->tensorResults.empty())
+    rewriter.replaceOp(op, res->tensorResults);
+  else
+    rewriter.eraseOp(op);
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6e5b49125845..a9a43e194d75 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -248,8 +248,7 @@ vectorizeOneOp(OpBuilder &builder, Operation *op,
 ///   TODO: Reuse opportunities for RAR dependencies.
 ///   4. Register CustomVectorizationHook for YieldOp to capture the results.
 ///   5. Iteratively call vectorizeOneOp on the region operations.
-///   6. RAUW the linalg op by the results captured vectorizing the YieldOp.
-static LogicalResult vectorizeAsLinalgGeneric(
+static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
     OpBuilder &builder, LinalgOp linalgOp,
     ArrayRef<CustomVectorizationHook> customVectorizationHooks = {}) {
   // 1. Certain Linalg ops do not have a region but only a region builder.
@@ -306,7 +305,7 @@ static LogicalResult vectorizeAsLinalgGeneric(
     VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: failed to vectorize: " << op);
-      return failure();
+      return llvm::None;
     }
     if (result.status == VectorizationStatus::NewOp) {
       LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: new vector op: "
@@ -315,10 +314,7 @@ static LogicalResult vectorizeAsLinalgGeneric(
     }
   }
 
-  // 6. RAUW the linalg op by the results captured vectorizing the YieldOp.
-  if (!results.empty())
-    linalgOp->replaceAllUsesWith(results);
-  return success();
+  return VectorizedLinalgOp{{results}};
 }
 
 /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -357,7 +353,8 @@ static bool isElementwise(Operation *op) {
   return hasOnlyScalarElementwiseOp(genericOp.getRegion());
 }
 
-static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) {
+static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
+                                                         LinalgOp linalgOp) {
   assert(isaContractionOpInterface(linalgOp) &&
          "expected vectorizeContraction preconditions to be met");
   Location loc = linalgOp.getLoc();
@@ -384,11 +381,7 @@ static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) {
         linalgOp.indexing_maps(), linalgOp.iterator_types());
     return VectorizationResult{VectorizationStatus::NewOp, contract};
   };
-  auto status =
-      vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
-  (void)status;
-  assert(succeeded(status) &&
-         "Unexpected vectorization failed despite preconditions");
+  return vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
 }
 
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
@@ -408,8 +401,10 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   return success(isaContractionOpInterface(linalgOp));
 }
 
-void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
-  assert(succeeded(vectorizeLinalgOpPrecondition(op)));
+Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
+                                                             Operation *op) {
+  if (failed(vectorizeLinalgOpPrecondition(op)))
+    return llvm::None;
 
   edsc::ScopedContext scope(builder, op->getLoc());
   // In the case of 0-D memrefs, return null and special case to scalar load or
@@ -418,8 +413,10 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
     // Vectorize fill as a vector.broadcast.
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                       << "Rewrite linalg.fill as vector.broadcast: " << *op);
-    buildVectorWrite(builder, fillOp.value(), fillOp.output());
-    return;
+    VectorizedLinalgOp res;
+    if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output()))
+      res.tensorResults.push_back(v);
+    return res;
   }
   if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
     // Vectorize copy as a vector.transfer_read+vector.transfer_write.
@@ -428,21 +425,26 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
                          "vector.transfer_write: "
                       << *op);
     Value vector = buildVectorRead(builder, copyOp.input());
-    buildVectorWrite(builder, vector, copyOp.output());
-    return;
+    VectorizedLinalgOp res;
+    if (Value v = buildVectorWrite(builder, vector, copyOp.output()))
+      res.tensorResults.push_back(v);
+    return res;
   }
-
   if (isElementwise(op)) {
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
-                      << "Rewrite linalg op as vector.transfer_read + " << *op);
-    auto status = vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
-    (void)status;
-    assert(succeeded(status) &&
-           "Unexpected vectorization failed despite preconditions");
-    return;
+                      << "Vectorize linalg op as a generic: " << *op);
+    return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
   }
 
-  vectorizeContraction(builder, cast<LinalgOp>(op));
+  // TODO: as soon as Copy and FillOp. get a region builder, replace all the
+  // above by:
+  // if (isa<FillOp, CopyOp>(op) || isElementwise(op)) {
+  //   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
+  //                     << "Vectorize linalg op as a generic: " << *op);
+  //   return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
+  // }
+
+  return vectorizeContraction(builder, cast<LinalgOp>(op));
 }
 
 //----------------------------------------------------------------------------//

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 3904353287c5..12841a4b6803 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file -debug-only=linalg-vectorization
+
+//| FileCheck %s
 
 // -----
 


        


More information about the Mlir-commits mailing list