[Mlir-commits] [mlir] 0e2f9b6 - Fix tile-and-pad when padding doesn't span all dimensions

Ahmed Taei llvmlistbot at llvm.org
Thu Apr 15 20:18:00 PDT 2021


Author: Ahmed Taei
Date: 2021-04-15T20:17:40-07:00
New Revision: 0e2f9b61fd9a30b152e9c80178b3bcc4b171b416

URL: https://github.com/llvm/llvm-project/commit/0e2f9b61fd9a30b152e9c80178b3bcc4b171b416
DIFF: https://github.com/llvm/llvm-project/commit/0e2f9b61fd9a30b152e9c80178b3bcc4b171b416.diff

LOG: Fix tile-and-pad when padding doesn't span all dimensions

Without this, tile-and-pad will never terminate if padding fails.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D97720

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 10e96d2b04076..c51c92930ab41 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -257,11 +257,8 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
     return failure();
 
   // Setup RAII guard to return properly.
-  bool succeeded = true;
   LinalgOp tiledOp = res->op;
   auto guard = llvm::make_scope_exit([&]() {
-    if (!succeeded)
-      return;
     // Return relevant information to derived pattern.
     result = *res;
     // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary.
@@ -278,7 +275,6 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
   // Try to pad on the fly by rewriting res->op as a padded op.
   if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
     // Set so RAII guard does not propagate TiledLinalgOp to `result`.
-    succeeded = false;
     return failure();
   }
 

diff  --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
index fab8c80acd3ff..de2aeee460380 100644
--- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-and-pad-pattern -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-and-pad-pattern tile-sizes-for-padding=2,3,4" -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-and-pad-pattern tile-sizes-for-padding=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE
 
 // CHECK-LABEL: func @matmul_tensors(
 // CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
@@ -39,3 +40,10 @@ func @matmul_tensors(
 //      CHECK: return %[[TD0]] : tensor<?x?xi32>
   return %0 : tensor<?x?xi32>
 }
+
+// CHECK-1DIM-TILE: func @matmul_tensors(
+// CHECK-1DIM-TILE:    %[[TA:[0-9a-z]+]]: tensor<?x?xi8>
+// CHECK-1DIM-TILE:    %[[TB:[0-9a-z]+]]: tensor<?x?xi8>
+// CHECK-1DIM-TILE:    %[[TC:[0-9a-z]+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
+// CHECK-1DIM-TILE-NOT: scf.for
+// CHECK-1DIM-TILE: linalg.matmul_i8_i8_i32 ins(%[[TA]], %[[TB]] : tensor<?x?xi8>, tensor<?x?xi8>) outs(%[[TC]] : tensor<?x?xi32>) -> tensor<?x?xi32>

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index fd8fb3bc6eff2..a6fe895035d20 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -87,6 +87,10 @@ struct TestLinalgTransforms
   Option<int> testHoistPadding{*this, "test-hoist-padding",
                                llvm::cl::desc("Test hoist padding"),
                                llvm::cl::init(0)};
+  ListOption<int64_t> tileSizesForPadding{
+      *this, "tile-sizes-for-padding",
+      llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
+      llvm::cl::MiscFlags::CommaSeparated};
 };
 } // end anonymous namespace
 
@@ -522,12 +526,12 @@ static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
   return b.create<ConstantOp>(op.getOwner()->getLoc(), t, b.getZeroAttr(t));
 }
 
-static void applyTileAndPadPattern(FuncOp funcOp) {
+static void applyTileAndPadPattern(FuncOp funcOp, ArrayRef<int64_t> tileSizes) {
   MLIRContext *context = funcOp.getContext();
   RewritePatternSet tilingPattern(context);
   auto linalgTilingOptions =
       linalg::LinalgTilingOptions()
-          .setTileSizes({2, 3, 4})
+          .setTileSizes(tileSizes)
           .setPaddingValueComputationFunction(getNeutralOfLinalgOp);
   tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>(
       context, linalgTilingOptions,
@@ -570,7 +574,7 @@ void TestLinalgTransforms::runOnFunction() {
   if (testAffineMinSCFCanonicalizationPatterns)
     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
   if (testTileAndPadPattern)
-    return applyTileAndPadPattern(getFunction());
+    return applyTileAndPadPattern(getFunction(), tileSizesForPadding);
   if (testHoistPadding) {
     getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
       (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding);


        


More information about the Mlir-commits mailing list