[Mlir-commits] [mlir] c37ed77 - [tensor][bufferize] Use affine.apply instead of arith.addi in PadOp lowering
Matthias Springer
llvmlistbot at llvm.org
Tue Aug 23 02:46:25 PDT 2022
Author: Matthias Springer
Date: 2022-08-23T11:46:11+02:00
New Revision: c37ed7762e9f473e9497c52c7669a025965651f7
URL: https://github.com/llvm/llvm-project/commit/c37ed7762e9f473e9497c52c7669a025965651f7
DIFF: https://github.com/llvm/llvm-project/commit/c37ed7762e9f473e9497c52c7669a025965651f7.diff
LOG: [tensor][bufferize] Use affine.apply instead of arith.addi in PadOp lowering
Affine expressions compose better than arith ops (e.g., successive affine.apply ops can be folded into a single affine map).
Differential Revision: https://reviews.llvm.org/D132456
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/test/Dialect/Tensor/bufferize.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 881237c499ed8..3600524ce7e22 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -790,9 +791,12 @@ struct PadOpInterface
Value srcDim = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
Value lowPad = toValue(mixedLowPad[i]);
Value highPad = toValue(mixedHighPad[i]);
- Value s1 = rewriter.create<arith::AddIOp>(loc, lowPad, highPad);
- Value s2 = rewriter.create<arith::AddIOp>(loc, s1, srcDim);
- dynamicSizes.push_back(s2);
+ AffineExpr s0, s1, s2;
+ bindSymbols(op->getContext(), s0, s1, s2);
+ AffineExpr sumExpr = s0 + s1 + s2;
+ Value sum = rewriter.create<AffineApplyOp>(
+ loc, sumExpr, ValueRange{srcDim, lowPad, highPad});
+ dynamicSizes.push_back(sum);
}
// Create tensor::GenerateOp.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 8479c43211e83..66e4cc906f238 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
MLIRTensorTransformsIncGen
LINK_LIBS PUBLIC
+ MLIRAffineDialect
MLIRArithmeticDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 8d53585cb1d8c..7cde99d94d590 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -547,6 +547,7 @@ func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
// -----
+// CHECK: #[[$sum_map:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
// CHECK-LABEL: func @tensor.pad(
// CHECK-SAME: %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
@@ -557,10 +558,8 @@ func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
// CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
- // CHECK-DAG: %[[pad0:.*]] = arith.addi %[[c5]], %[[h1]]
- // CHECK-DAG: %[[size0:.*]] = arith.addi %[[pad0]], %[[dim0]]
- // CHECK-DAG: %[[pad1:.*]] = arith.addi %[[l2]], %[[h2]]
- // CHECK-DAG: %[[size1:.*]] = arith.addi %[[pad1]], %[[dim1]]
+ // CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map]]()[%[[dim0]], %[[c5]], %[[h1]]]
+ // CHECK-DAG: %[[size1:.*]] = affine.apply #[[$sum_map]]()[%[[dim1]], %[[l2]], %[[h2]]]
// CHECK: %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref<?x?xindex>
// CHECK: scf.parallel ({{.*}}) = (%[[c0]], %[[c0]]) to (%[[size0]], %[[size1]]) step (%[[c1]], %[[c1]]) {
// CHECK: memref.store
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index f6f9539b40203..7d45b2ef5cc4c 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5048,6 +5048,7 @@ cc_library(
],
includes = ["include"],
deps = [
+ ":AffineDialect",
":ArithmeticDialect",
":BufferizationDialect",
":BufferizationTransforms",
More information about the Mlir-commits mailing list