[Mlir-commits] [mlir] e027c00 - [mlir][tensor] Add a pattern to split tensor.pad ops

Lei Zhang llvmlistbot at llvm.org
Wed Feb 16 10:44:18 PST 2022


Author: Lei Zhang
Date: 2022-02-16T13:43:57-05:00
New Revision: e027c00821dda48d06cf73c6c16176fcb7b8adcb

URL: https://github.com/llvm/llvm-project/commit/e027c00821dda48d06cf73c6c16176fcb7b8adcb
DIFF: https://github.com/llvm/llvm-project/commit/e027c00821dda48d06cf73c6c16176fcb7b8adcb.diff

LOG: [mlir][tensor] Add a pattern to split tensor.pad ops

This commit adds a pattern that wraps a tensor.pad op in
an scf.if op to separate the case where no padding is needed
(all pad sizes are actually zero) from the case where padding
is indeed required.

This pattern is meant to handle padding inside tiled loops.
In such cases the padding sizes typically depend on the
loop induction variables, so splitting on whether they are
zero allows treating perfect tiles and edge tiles separately.
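
For illustration, here is a rough before/after sketch of the rewrite
(value names and shapes are made up for exposition; %input, %low,
%high, and %cst are assumed to be defined by the surrounding function;
the new split-padding.mlir test carries the authoritative FileCheck
expectations). A pad with dynamic low/high sizes such as

  %0 = tensor.pad %input low[%low, 0] high[0, %high] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>

is rewritten, roughly, into

  %c0 = arith.constant 0 : index
  %eq0 = arith.cmpi eq, %low, %c0 : index
  %eq1 = arith.cmpi eq, %high, %c0 : index
  %cond = arith.andi %eq0, %eq1 : i1
  %0 = scf.if %cond -> (tensor<?x?xf32>) {
    // No padding needed: forward the source tensor.
    scf.yield %input : tensor<?x?xf32>
  } else {
    // Padding needed: keep a clone of the original pad op.
    %padded = tensor.pad %input low[%low, 0] high[0, %high] {
    ^bb0(%i: index, %j: index):
      tensor.yield %cst : f32
    } : tensor<?x?xf32> to tensor<?x?xf32>
    scf.yield %padded : tensor<?x?xf32>
  }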

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D117018

Added: 
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp
    mlir/test/Dialect/Tensor/split-padding.mlir
    mlir/test/lib/Dialect/Tensor/CMakeLists.txt
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/CMakeLists.txt
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
new file mode 100644
index 0000000000000..e6267e9cf02e5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -0,0 +1,26 @@
+//===- Transforms.h - Tensor Transformation Patterns ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
+#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace tensor {
+
+/// Populates `patterns` with patterns that wrap a tensor.pad op in an scf.if
+/// op to separate the case where no padding is needed (all pad sizes are
+/// actually zero) from the case where padding is indeed required.
+void populateSplitPaddingPatterns(RewritePatternSet &patterns,
+                                  PatternBenefit baseBenefit = 1);
+
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index d36e556fd7723..2a677195679f0 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  SplitPadding.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp b/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp
new file mode 100644
index 0000000000000..922ad0c82ae8b
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/SplitPadding.cpp
@@ -0,0 +1,94 @@
+//===- SplitPadding.cpp - Splitting tensor.pad Op -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns that wrap a tensor.pad op in an scf.if op
+// to separate the case where no padding is needed (all pad sizes are
+// actually zero) from the case where padding is indeed required.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "mlir-tensor-split-padding"
+
+using namespace mlir;
+
+/// Returns true if the given `attrOrValue` is a constant zero.
+static bool isZero(OpFoldResult attrOrValue) {
+  if (Optional<int64_t> val = getConstantIntValue(attrOrValue))
+    return val.getValue() == 0;
+  return false;
+}
+
+/// Gets the given `attrOrValue` as a Value by creating constant ops for
+/// attributes.
+static Value getAsValue(OpFoldResult attrOrValue, OpBuilder &builder,
+                        Location loc) {
+  if (Value val = attrOrValue.dyn_cast<Value>())
+    return val;
+  auto attr = attrOrValue.get<Attribute>().cast<IntegerAttr>();
+  return builder.create<arith::ConstantIndexOp>(loc, attr.getInt());
+}
+
+namespace {
+
+struct SplitPadding final : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // Avoid infinitely applying this pattern.
+    if (padOp->getParentOfType<scf::IfOp>())
+      return failure();
+
+    // If all padding sizes are zero, we don't need to do anything.
+    SmallVector<OpFoldResult> lowPads = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPads = padOp.getMixedHighPad();
+    if (llvm::all_of(lowPads, isZero) && llvm::all_of(highPads, isZero))
+      return failure();
+
+    // Build the condition for the scf.if op: all pad sizes are zero.
+    Location loc = padOp.getLoc();
+    Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> eqZeroCmpVals;
+    for (OpFoldResult pad : llvm::concat<OpFoldResult>(lowPads, highPads)) {
+      if (!isZero(pad))
+        eqZeroCmpVals.push_back(rewriter.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::eq, getAsValue(pad, rewriter, loc),
+            cstZero));
+    }
+    Value ifCond = eqZeroCmpVals.front();
+    for (Value cmp : llvm::makeArrayRef(eqZeroCmpVals).drop_front())
+      ifCond = rewriter.create<arith::AndIOp>(loc, ifCond, cmp);
+
+    // Build the scf.if op itself. For the "then" branch, we can elide the
+    // padding. For the "else" branch, we clone the original pad op.
+    auto thenBuilder = [&padOp](OpBuilder &builder, Location loc) {
+      builder.create<scf::YieldOp>(loc, padOp.source());
+    };
+    auto elseBuilder = [&padOp](OpBuilder &builder, Location loc) {
+      Operation *newOp = builder.clone(*padOp);
+      builder.create<scf::YieldOp>(loc, newOp->getResults());
+    };
+    rewriter.replaceOpWithNewOp<scf::IfOp>(padOp, padOp.getType(), ifCond,
+                                           thenBuilder, elseBuilder);
+    return success();
+  }
+};
+
+} // namespace
+
+void tensor::populateSplitPaddingPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit baseBenefit) {
+  patterns.add<SplitPadding>(patterns.getContext(), baseBenefit);
+}

diff  --git a/mlir/test/Dialect/Tensor/split-padding.mlir b/mlir/test/Dialect/Tensor/split-padding.mlir
new file mode 100644
index 0000000000000..40d186c678c4d
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/split-padding.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-split-padding-patterns %s | FileCheck %s
+
+// CHECK-LABEL: func @pad_all_zero_sizes
+func @pad_all_zero_sizes(%input: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = tensor.pad %input low[0, %c0, 0] high[%c0, 0, 0] {
+  ^bb0(%dim0: index, %dim1: index, %dim2: index):
+    tensor.yield %f0 : f32
+  } : tensor<?x?x?xf32> to tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+// CHECK-NOT: scf.if
+//     CHECK: tensor.pad
+
+// -----
+
+// CHECK-LABEL: func @pad_non_zero_sizes
+//  CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x8xf32>, %[[LOW0:.+]]: index, %[[HIGH1:.+]]: index)
+func @pad_non_zero_sizes(%input: tensor<?x?x8xf32>, %low0: index, %high1: index) -> tensor<?x?x8xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %0 = tensor.pad %input low[%low0, 0, 0] high[0, %high1, 0] {
+  ^bb0(%dim0: index, %dim1: index, %dim2: index):
+    tensor.yield %f0 : f32
+  } : tensor<?x?x8xf32> to tensor<?x?x8xf32>
+  return %0 : tensor<?x?x8xf32>
+}
+
+// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[EQ0:.+]] = arith.cmpi eq, %[[LOW0]], %[[C0]] : index
+// CHECK: %[[EQ1:.+]] = arith.cmpi eq, %[[HIGH1]], %[[C0]] : index
+// CHECK: %[[AND:.+]] = arith.andi %[[EQ0]], %[[EQ1]] : i1
+// CHECK: %[[IF:.+]] = scf.if %[[AND]] -> (tensor<?x?x8xf32>) {
+// CHECK:   scf.yield %[[INPUT]] : tensor<?x?x8xf32>
+// CHECK: } else {
+// CHECK:   %[[PAD:.+]] = tensor.pad %[[INPUT]] low[%[[LOW0]], 0, 0] high[0, %[[HIGH1]], 0]  {
+// CHECK:   ^bb0(%{{.+}}: index, %{{.+}}: index, %{{.+}}: index):
+// CHECK:     tensor.yield %[[F0]] : f32
+// CHECK:   } : tensor<?x?x8xf32> to tensor<?x?x8xf32>
+// CHECK:   scf.yield %[[PAD]] : tensor<?x?x8xf32>
+// CHECK: }
+// CHECK: return %[[IF]] : tensor<?x?x8xf32>

diff  --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index d02dc61045fae..87078f8cee684 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -8,6 +8,7 @@ add_subdirectory(SCF)
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(StandardOps)
+add_subdirectory(Tensor)
 add_subdirectory(Test)
 add_subdirectory(Tosa)
 add_subdirectory(Vector)

diff  --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
new file mode 100644
index 0000000000000..fbe7485257bbc
--- /dev/null
+++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
@@ -0,0 +1,14 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRTensorTestPasses
+  TestTensorTransforms.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRArithmetic
+  MLIRPass
+  MLIRSCF
+  MLIRTensor
+  MLIRTensorTransforms
+  MLIRTransforms
+  )

diff  --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
new file mode 100644
index 0000000000000..c720ca1e3a235
--- /dev/null
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -0,0 +1,65 @@
+//===- TestTensorTransforms.cpp - Test Tensor transformation patterns -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing Tensor transformations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestTensorTransforms
+    : public PassWrapper<TestTensorTransforms, OperationPass<FuncOp>> {
+  TestTensorTransforms() = default;
+  TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
+  }
+
+  StringRef getArgument() const final {
+    return "test-tensor-transform-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test Tensor transformation patterns by applying them greedily.";
+  }
+
+  void runOnOperation() override;
+
+  Option<bool> testSplitPaddingPatterns{
+      *this, "test-split-padding-patterns",
+      llvm::cl::desc("Test patterns to split tensor.pad ops"),
+      llvm::cl::init(false)};
+};
+} // namespace
+
+static void applySplitPaddingPatterns(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  tensor::populateSplitPaddingPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
+void TestTensorTransforms::runOnOperation() {
+  FuncOp func = getOperation();
+  if (testSplitPaddingPatterns)
+    applySplitPaddingPatterns(func);
+}
+
+namespace mlir {
+namespace test {
+void registerTestTensorTransforms() {
+  PassRegistration<TestTensorTransforms>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index c03d6403a74eb..b83548e2c19df 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -23,7 +23,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRShapeTestPasses
     MLIRSPIRVTestPasses
     MLIRStandardOpsTestPasses
-    MLIRVectorTestPasses
+    MLIRTensorTestPasses
     MLIRTestAnalysis
     MLIRTestDialect
     MLIRTestIR
@@ -31,6 +31,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTestReducer
     MLIRTestRewrite
     MLIRTestTransforms
+    MLIRVectorTestPasses
     )
 endif()
 

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 73d1b54bbf4fd..647fffaf240d3 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@ void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
 void registerTestSliceAnalysisPass();
+void registerTestTensorTransforms();
 void registerTestVectorLowerings();
 } // namespace test
 } // namespace mlir
@@ -194,6 +195,7 @@ void registerTestPasses() {
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSliceAnalysisPass();
+  mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestVectorLowerings();
 }
 #endif

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index da029a721b3ae..511c22132ad67 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4645,6 +4645,7 @@ cc_library(
     hdrs = [
         "include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h",
         "include/mlir/Dialect/Tensor/Transforms/Passes.h",
+        "include/mlir/Dialect/Tensor/Transforms/Transforms.h",
     ],
     includes = ["include"],
     deps = [
@@ -4652,6 +4653,7 @@ cc_library(
         ":Async",
         ":BufferizationDialect",
         ":BufferizationTransforms",
+        ":DialectUtils",
         ":IR",
         ":MemRefDialect",
         ":ParallelLoopMapperAttrGen",
@@ -5985,6 +5987,7 @@ cc_binary(
         "//mlir/test:TestShapeDialect",
         "//mlir/test:TestStandardOps",
         "//mlir/test:TestStandardToLLVM",
+        "//mlir/test:TestTensor",
         "//mlir/test:TestTosaDialect",
         "//mlir/test:TestTransforms",
         "//mlir/test:TestTypeDialect",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 505508ba89763..86aaabedd14f2 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -510,6 +510,21 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "TestTensor",
+    srcs = glob(["lib/Dialect/Tensor/*.cpp"]),
+    defines = ["MLIR_CUDA_CONVERSIONS_ENABLED"],
+    includes = ["lib/Dialect/Test"],
+    deps = [
+        "//mlir:ArithmeticDialect",
+        "//mlir:Pass",
+        "//mlir:SCFDialect",
+        "//mlir:TensorDialect",
+        "//mlir:TensorTransforms",
+        "//mlir:Transforms",
+    ],
+)
+
 cc_library(
     name = "TestVector",
     srcs = glob(["lib/Dialect/Vector/*.cpp"]),

