[Mlir-commits] [mlir] 35df2f6 - Refactor GenericPadTensorOpVectorizationPattern

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jul 7 04:45:46 PDT 2021


Author: Yi Zhang
Date: 2021-07-07T11:44:32Z
New Revision: 35df2f6fbd1ae2e6f9313454e5446212fcbcf90a

URL: https://github.com/llvm/llvm-project/commit/35df2f6fbd1ae2e6f9313454e5446212fcbcf90a
DIFF: https://github.com/llvm/llvm-project/commit/35df2f6fbd1ae2e6f9313454e5446212fcbcf90a.diff

LOG: Refactor GenericPadTensorOpVectorizationPattern

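Refactor the original code so that, by default, a PadTensorOp is
rewritten into a sequence of InitTensorOp, FillOp (or a
tensor::GenerateOp when the padding value is not constant) and
InsertSliceOp without vectorization. `GenericPadTensorOpVectorizationPattern`
supplies a custom OptimizeCopyFn that vectorizes the copying step.
The new GeneralizePadTensorOpPattern is also added to the Linalg
bufferization pass.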

Reviewed By: silvas, nicolasvasilache, springerm

Differential Revision: https://reviews.llvm.org/D105293

Added: 
    mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5f533df13741..a9724943f9be 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -883,6 +883,28 @@ struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
+using OptimizeCopyFn =
+    std::function<LogicalResult(PatternRewriter &, PadTensorOp, Value)>;
+
+/// Rewrite a PadTensorOp into InitTensorOp, FillOp (tensor::GenerateOp if the
+/// padding value is not a constant) and InsertSliceOp. A non-null
+/// `optimizeCopyFn` may emit the source copy instead of the InsertSliceOp.
+struct GeneralizePadTensorOpPattern : public OpRewritePattern<PadTensorOp> {
+  GeneralizePadTensorOpPattern(MLIRContext *context,
+                               OptimizeCopyFn optimizeCopyFn = nullptr,
+                               PatternBenefit benefit = 1)
+      : OpRewritePattern<PadTensorOp>(context, benefit),
+        optimizeCopyFn(std::move(optimizeCopyFn)) {}
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override;
+
+protected:
+  OptimizeCopyFn optimizeCopyFn; // Optional; called with the filled dest.
+  Value createFillOrGenerateOp(PatternRewriter &rewriter, PadTensorOp padOp,
+                               Value dest,
+                               const SmallVector<Value> &dynSizes) const;
+};
+
 /// Populates `patterns` with patterns that vectorize linalg.pad_tensor.
 /// These patterns are meant to apply in a complementary fashion. Benefits
 /// are used to encode a certain ordering of pattern application. To avoid

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index fba709a87152..e269b41674a3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -334,6 +334,7 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
     target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
 
     RewritePatternSet patterns(&context);
+    patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
     populateLinalgBufferizePatterns(typeConverter, patterns);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 6dba56494d5e..82d259ed0572 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -699,6 +699,95 @@ LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
   return success();
 }
 
+/// Fill `dest` with padOp's padding value using a FillOp if that value is a
+/// constant. Otherwise, build a tensor::GenerateOp from padOp's region.
+Value GeneralizePadTensorOpPattern::createFillOrGenerateOp(
+    PatternRewriter &rewriter, PadTensorOp padOp, Value dest,
+    const SmallVector<Value> &dynSizes) const {
+  auto padValue = padOp.getConstantPaddingValue();
+  if (padValue)
+    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
+
+  // Non-constant padding: lower to tensor::GenerateOp; `dest` is unused here.
+  auto generateOp = rewriter.create<tensor::GenerateOp>(
+      padOp.getLoc(), padOp.getResultType(), dynSizes);
+  // Copy padOp's region into the new op.
+  BlockAndValueMapping bvm;
+  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
+  // The clone still ends in linalg::YieldOp; replace it with tensor::YieldOp.
+  OpBuilder::InsertionGuard guard(rewriter);
+  auto yieldOp =
+      dyn_cast<linalg::YieldOp>(generateOp.getRegion().front().getTerminator());
+  assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
+  assert(yieldOp.values().size() == 1);
+  rewriter.setInsertionPoint(yieldOp);
+  rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
+  return generateOp;
+}
+
+LogicalResult
+GeneralizePadTensorOpPattern::matchAndRewrite(PadTensorOp padOp,
+                                              PatternRewriter &rewriter) const {
+  // Given an OpFoldResult, return an index-typed value.
+  auto getIdxValue = [&](OpFoldResult ofr) {
+    if (auto val = ofr.dyn_cast<Value>())
+      return val;
+    return rewriter
+        .create<ConstantIndexOp>(
+            padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
+        .getResult();
+  };
+
+  auto resultType = padOp.getResultType();
+  // Compute size of InitTensorOp. Any combination of static/dynamic is
+  // supported.
+  SmallVector<Value> dynSizes;
+  SmallVector<int64_t> staticSizes;
+  for (unsigned dim = 0, e = resultType.getRank(); dim < e; ++dim) {
+    if (resultType.isDynamicDim(dim)) {
+      auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
+                                                          padOp.source(), dim);
+      // Dynamic result size = source size + low padding + high padding.
+      auto plusLow = rewriter.createOrFold<AddIOp>(
+          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
+      auto plusHigh = rewriter.createOrFold<AddIOp>(
+          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
+      dynSizes.push_back(plusHigh);
+    }
+    staticSizes.push_back(resultType.getDimSize(dim));
+  }
+
+  // Init tensor and fill it with padding.
+  Value init = rewriter.create<InitTensorOp>(
+      padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
+  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);
+
+  // Try to optimize the copy of the source; on success the pattern is done.
+  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
+    return success();
+
+  // No copy optimization applied. Copy the PadOp source into `fill` with an
+  // InsertSliceOp placed at the low-padding offsets.
+  auto sourceType = padOp.getSourceType();
+  // Compute size of source of PadTensorOp.
+  SmallVector<OpFoldResult> srcSizes;
+  for (unsigned dim = 0, e = sourceType.getRank(); dim < e; ++dim) {
+    if (sourceType.isDynamicDim(dim)) {
+      srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
+          padOp.getLoc(), padOp.source(), dim));
+    } else {
+      srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
+    }
+  }
+  // Strides of InsertSliceOp are all 1.
+  SmallVector<OpFoldResult> strides(sourceType.getRank(),
+                                    rewriter.getIndexAttr(1));
+  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+      padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
+
+  return success();
+}
+
 /// Given an OpFoldResult, return a Value. If the OpFoldResult is an Attribute,
 /// it must be of type Integer.
 static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2c315706d702..f83892005e1b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -682,104 +682,15 @@ static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
 /// If there is enough static type information, TransferReadOps and
 /// TransferWriteOps may be generated instead of InsertSliceOps.
 struct GenericPadTensorOpVectorizationPattern
-    : public OpRewritePattern<PadTensorOp> {
-  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(PadTensorOp padOp,
-                                PatternRewriter &rewriter) const final {
-    // Given an OpFoldResult, return an index-typed value.
-    auto getIdxValue = [&](OpFoldResult ofr) {
-      if (auto val = ofr.dyn_cast<Value>())
-        return val;
-      return rewriter.create<ConstantIndexOp>(
-          padOp.getLoc(), getIntFromAttr(ofr.get<Attribute>())).getResult();
-    };
-
-    auto resultType = padOp.getResultType();
-    // Compute size of InitTensorOp. Any combination of static/dynamic is
-    // supported.
-    SmallVector<Value> dynSizes;
-    SmallVector<int64_t> staticSizes;
-    for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
-      if (resultType.isDynamicDim(dim)) {
-        auto srcSize = rewriter.createOrFold<tensor::DimOp>(
-            padOp.getLoc(), padOp.source(), dim);
-        // Add low and high padding value.
-        auto plusLow = rewriter.createOrFold<AddIOp>(
-            padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
-        auto plusHigh = rewriter.createOrFold<AddIOp>(
-            padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
-        dynSizes.push_back(plusHigh);
-      }
-      staticSizes.push_back(resultType.getDimSize(dim));
-    }
-
-    // Init tensor and fill it with padding.
-    Value init = rewriter.create<InitTensorOp>(
-        padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
-    Value fill = tryVectorizeFill(rewriter, padOp, init, dynSizes);
-
-    // Try vectorizing the copy of source.
-    if (tryVectorizeCopy(rewriter, padOp, fill).succeeded())
-      return success();
-
-    // Neither source type nor PadTensorOp result type have static shape. Such
-    // PadTensorOps cannot be vectorized. Generate a InsertSliceOp instead
-    // for copying the PadOp source.
-
-    auto sourceType = padOp.getSourceType();
-    // Compute size of source of PadTensorOp.
-    SmallVector<OpFoldResult> srcSizes;
-    for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
-      if (sourceType.isDynamicDim(dim)) {
-        srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
-            padOp.getLoc(), padOp.source(), dim));
-      } else {
-        srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
-      }
-    }
-    // Strides of InsertSliceOp are all 1.
-    SmallVector<OpFoldResult> strides(sourceType.getRank(),
-                                      rewriter.getIndexAttr(1));
-    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
-
-    return success();
-  }
-
-  /// Vectorize the filling of `dest`. This is possible if the padOp is padding
-  /// with a constant value. Otherwise, generate a tensor::GenerateOp.
-  Value tryVectorizeFill(PatternRewriter &rewriter, PadTensorOp padOp,
-                         Value dest, const SmallVector<Value> &dynSizes) const {
-    // Fill can be vectorized if padValue is a constant. (If there is enough
-    // static type information, the FillOp will be vectorized by another
-    // pattern.)
-    auto padValue = padOp.getConstantPaddingValue();
-    if (padValue)
-      return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
-
-    // Fill could not be vectorized: Lower to tensor::GenerateOp with region.
-    auto generateOp = rewriter.create<tensor::GenerateOp>(
-        padOp.getLoc(), padOp.getResultType(), dynSizes);
-    // Copy region to new op.
-    BlockAndValueMapping bvm;
-    padOp.region().cloneInto(&generateOp.getRegion(), bvm);
-    // Rewrite linalg::YieldOp to tensor::YieldOp.
-    OpBuilder::InsertionGuard guard(rewriter);
-    auto yieldOp = dyn_cast<linalg::YieldOp>(
-        generateOp.getRegion().front().getTerminator());
-    assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
-    assert(yieldOp.values().size() == 1);
-    rewriter.setInsertionPoint(yieldOp);
-    rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp, yieldOp.values()[0]);
-    return generateOp;
-  }
-
+    : public GeneralizePadTensorOpPattern {
+  GenericPadTensorOpVectorizationPattern(MLIRContext *context,
+                                         PatternBenefit benefit = 1)
+      : GeneralizePadTensorOpPattern(context, tryVectorizeCopy, benefit) {}
   /// Vectorize the copying of a PadTensorOp's source. This is possible if each
   /// dimension size is statically know in the source type or the result type
   /// (or both).
-  LogicalResult tryVectorizeCopy(PatternRewriter &rewriter, PadTensorOp padOp,
-                                 Value dest) const {
+  static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter,
+                                        PadTensorOp padOp, Value dest) {
     auto sourceType = padOp.getSourceType();
     auto resultType = padOp.getResultType();
 

diff  --git a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
new file mode 100644
index 000000000000..9572ee6d1b35
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-pad-tensor"  %s | FileCheck --check-prefix=CHECK %s
+
+// CHECK-LABEL:   func @generalize_pad_tensor_static_shape(
+// CHECK-SAME:                                             %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
+// CHECK:           %[[C0:.*]] = constant 0.000000e+00 : f32
+// CHECK:           %[[INIT:.*]] = linalg.init_tensor [1, 32, 32, 1] : tensor<1x32x32x1xf32>
+// CHECK:           %[[FILL:.*]] = linalg.fill(%[[C0]], %[[INIT]]) : f32, tensor<1x32x32x1xf32> -> tensor<1x32x32x1xf32>
+// CHECK:           %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]][0, 2, 2, 0] [1, 28, 28, 1] [1, 1, 1, 1] : tensor<1x28x28x1xf32> into tensor<1x32x32x1xf32>
+// CHECK:           return %[[PADDED]] : tensor<1x32x32x1xf32>
+func @generalize_pad_tensor_static_shape(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]  {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x28x28x1xf32> to tensor<1x32x32x1xf32>
+  return %0 : tensor<1x32x32x1xf32>
+}
+
+// CHECK-LABEL:   func @generalize_pad_tensor_dynamic_shape(
+// CHECK-SAME:                                              %[[IN:.*]]: tensor<4x?x2x?xf32>,
+// CHECK-SAME:                                              %[[OFFSET:.*]]: index) -> tensor<4x?x?x?xf32> {
+// CHECK-DAG:       %[[C0:.*]] = constant 0 : index
+// CHECK-DAG:       %[[CST:.*]] = constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[C2:.*]] = constant 2 : index
+// CHECK-DAG:       %[[C1:.*]] = constant 1 : index
+// CHECK-DAG:       %[[C3:.*]] = constant 3 : index
+// CHECK:           %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
+// CHECK:           %[[OUT_DIM2:.*]] = addi %[[OFFSET]], %[[C2]] : index
+// CHECK:           %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
+// CHECK:           %[[OUT_DIM3:.*]] = addi %[[DIM3]], %[[OFFSET]] : index
+// CHECK:           %[[INIT:.*]] = linalg.init_tensor [4, %[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]] : tensor<4x?x?x?xf32>
+// CHECK:           %[[FILL:.*]] = linalg.fill(%[[CST]], %[[INIT]]) : f32, tensor<4x?x?x?xf32> -> tensor<4x?x?x?xf32>
+// CHECK:           %[[DIM1_1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
+// CHECK:           %[[DIM3_1:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
+// CHECK:           %[[PADDED:.*]] = tensor.insert_slice %[[IN]] into %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[OFFSET]], %[[C0]]] [4, %[[DIM1_1]], 2, %[[DIM3_1]]] [1, 1, 1, 1] : tensor<4x?x2x?xf32> into tensor<4x?x?x?xf32>
+// CHECK:           return %[[PADDED]] : tensor<4x?x?x?xf32>
+// CHECK:         }
+func @generalize_pad_tensor_dynamic_shape(%arg0: tensor<4x?x2x?xf32>, %arg1: index) -> tensor<4x?x?x?xf32> {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %out = linalg.pad_tensor %arg0 low[%c0, %c0, %arg1, %c0] high[%c0, %c0, %c0, %arg1]  {
+  ^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index, %gen_arg4: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<4x?x2x?xf32> to tensor<4x?x?x?xf32>
+  return %out : tensor<4x?x?x?xf32>
+}

diff  --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir
new file mode 100644
index 000000000000..162f5282f8e3
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-padtensor.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize \
+// RUN: -tensor-constant-bufferize -tensor-bufferize -func-bufferize \
+// RUN: -finalizing-bufferize \
+// RUN: -convert-linalg-to-loops -convert-scf-to-std -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// Pad a dynamically shaped 1x?x3 tensor with 2.3 and print the 1x4x5 result.
+func @main() {
+  %const = constant dense<[[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]]]> : tensor<1x2x3xf32>
+  %dynamic = tensor.cast %const: tensor<1x2x3xf32> to tensor<1x?x3xf32>
+  %offset = constant 2 : index
+  %cst = constant 2.3 : f32
+  %c0 = constant 0 : index
+  %out = linalg.pad_tensor %dynamic low[%c0, %offset, %c0] high[%c0, %c0, %offset]  {
+  ^bb0(%gen_arg1: index, %gen_arg2: index, %gen_arg3: index):  // no predecessors
+    linalg.yield %cst : f32
+  } : tensor<1x?x3xf32> to tensor<1x?x?xf32>
+  %unranked = tensor.cast %out: tensor<1x?x?xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+  // Two padding rows come first; each source row then gets two trailing 2.3s.
+  //      CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 3 offset = 0 sizes = [1, 4, 5] strides = [20, 5, 1] data =
+  // CHECK-NEXT{LITERAL}: [[[2.3,    2.3,    2.3,    2.3,    2.3],
+  // CHECK-NEXT: [2.3,    2.3,    2.3,    2.3,    2.3],
+  // CHECK-NEXT: [1,    2,    3,    2.3,    2.3],
+  // CHECK-NEXT: [2,    3,    4,    2.3,    2.3]]]
+
+  return
+}
+
+func private @print_memref_f32(%ptr : tensor<*xf32>)

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 45e7c8235067..8a8ce298cc6a 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -97,6 +97,10 @@ struct TestLinalgTransforms
       *this, "test-transform-pad-tensor",
       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
       llvm::cl::init(false)};
+  Option<bool> testGeneralizePadTensor{
+      *this, "test-generalize-pad-tensor",
+      llvm::cl::desc("Test generalizing pad tensor into fill and insert_slice"),
+      llvm::cl::init(false)};
   Option<bool> testSwapSubTensorPadTensor{
       *this, "test-swap-subtensor-padtensor",
       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
@@ -530,6 +534,12 @@ static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
+  RewritePatternSet padTensorPatterns(funcOp.getContext());
+  padTensorPatterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(padTensorPatterns));
+}
+
 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
@@ -614,6 +624,8 @@ void TestLinalgTransforms::runOnFunction() {
     return applyLinalgToVectorPatterns(getFunction());
   if (testTransformPadTensor)
     return applyPadTensorToGenericPatterns(getFunction());
+  if (testGeneralizePadTensor)
+    return applyGeneralizePadTensorPatterns(getFunction());
   if (testSwapSubTensorPadTensor)
     return applyExtractSliceOfPadTensorSwapPattern(getFunction());
   if (testAffineMinSCFCanonicalizationPatterns)


        


More information about the Mlir-commits mailing list