[Mlir-commits] [mlir] a65a505 - [mlir][Linalg] Canonicalize tensor_reshape(splat-constant) -> splat-constant.
llvmlistbot at llvm.org
Tue Aug 18 08:17:36 PDT 2020
Author: MaheshRavishankar
Date: 2020-08-18T08:17:09-07:00
New Revision: a65a50540e3b5dd1938a1d14f31b912a311537fb
URL: https://github.com/llvm/llvm-project/commit/a65a50540e3b5dd1938a1d14f31b912a311537fb
DIFF: https://github.com/llvm/llvm-project/commit/a65a50540e3b5dd1938a1d14f31b912a311537fb.diff
LOG: [mlir][Linalg] Canonicalize tensor_reshape(splat-constant) -> splat-constant.
When the operand to the linalg.tensor_reshape op is a splat constant,
the result can be replaced with a splat constant of the same value but
with the reshape's result type.
Differential Revision: https://reviews.llvm.org/D86117
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 009699be5263..308272d66d56 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
@@ -734,9 +735,28 @@ static LogicalResult verify(TensorReshapeOp op) {
return success();
}
+/// Canonicalization pattern: a linalg.tensor_reshape whose source is a splat
+/// constant is replaced by a ConstantOp of the reshape's result type that
+/// holds the same splat value. Reshapes of non-splat constants, and of
+/// non-constant sources, are left unchanged (the pattern returns failure()).
+struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
+ using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ // Only fire when the reshape source is produced by a constant whose value
+ // binds to a DenseElementsAttr.
+ DenseElementsAttr attr;
+ if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
+ return failure();
+ // Only splats are handled: the raw data below is re-interpreted as a splat
+ // buffer for the (differently shaped) result type.
+ // NOTE(review): `!attr` is likely redundant once matchPattern succeeded;
+ // it is kept as a defensive check.
+ if (!attr || !attr.isSplat())
+ return failure();
+ // Reuse the source's raw splat data with the result type. This assumes the
+ // result element type equals the source element type (presumably enforced
+ // by the TensorReshapeOp verifier -- confirm).
+ DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
+ reshapeOp.getResultType(), attr.getRawData(), /*isSplatBuffer=*/true);
+ rewriter.replaceOpWithNewOp<ConstantOp>(reshapeOp, newAttr);
+ return success();
+ }
+};
+
void TensorReshapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<CollapseReshapeOps<TensorReshapeOp>>(context);
+ // Besides CollapseReshapeOps, register FoldReshapeWithConstant so that a
+ // reshape of a splat constant canonicalizes to a constant of the result type.
+ results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 005bd1c87445..85321084cd0c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -203,3 +203,60 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf
// CHECK-NOT: linalg.copy
// CHECK-NEXT: linalg.generic
+// -----
+
+func @reshape_splat_constant_int32() -> tensor<2x4x2xi32>
+{
+ %c0 = constant dense<42> : tensor<2x8xi32>
+ %0 = linalg.tensor_reshape %c0
+ [affine_map<(d0, d1, d2) -> (d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>]
+ : tensor<2x8xi32> into tensor<2x4x2xi32>
+ return %0 : tensor<2x4x2xi32>
+}
+// Check both the new type and that the splat value (42) is preserved.
+// CHECK-LABEL: @reshape_splat_constant_int32
+// CHECK: %[[CST:.*]] = constant dense<42> : tensor<2x4x2xi32>
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: return %[[CST]]
+
+func @reshape_splat_constant_int16() -> tensor<2x4x2xi16>
+{
+ %c0 = constant dense<42> : tensor<2x8xi16>
+ %0 = linalg.tensor_reshape %c0
+ [affine_map<(d0, d1, d2) -> (d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>]
+ : tensor<2x8xi16> into tensor<2x4x2xi16>
+ return %0 : tensor<2x4x2xi16>
+}
+// Check both the new type and that the splat value (42) is preserved.
+// CHECK-LABEL: @reshape_splat_constant_int16
+// CHECK: %[[CST:.*]] = constant dense<42> : tensor<2x4x2xi16>
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: return %[[CST]]
+
+func @reshape_splat_constant_float32() -> tensor<2x4x2xf32>
+{
+ %c0 = constant dense<42.0> : tensor<2x8xf32>
+ %0 = linalg.tensor_reshape %c0
+ [affine_map<(d0, d1, d2) -> (d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>]
+ : tensor<2x8xf32> into tensor<2x4x2xf32>
+ return %0 : tensor<2x4x2xf32>
+}
+// Check both the new type and that the splat value (42.0, printed in
+// exponential form) is preserved.
+// CHECK-LABEL: @reshape_splat_constant_float32
+// CHECK: %[[CST:.*]] = constant dense<4.200000e+01> : tensor<2x4x2xf32>
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: return %[[CST]]
+
+func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
+{
+ %c0 = constant dense<42.0> : tensor<2x8xf64>
+ %0 = linalg.tensor_reshape %c0
+ [affine_map<(d0, d1, d2) -> (d0)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>]
+ : tensor<2x8xf64> into tensor<2x4x2xf64>
+ return %0 : tensor<2x4x2xf64>
+}
+// Check both the new type and that the splat value (42.0, printed in
+// exponential form) is preserved.
+// CHECK-LABEL: @reshape_splat_constant_float64
+// CHECK: %[[CST:.*]] = constant dense<4.200000e+01> : tensor<2x4x2xf64>
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: return %[[CST]]
More information about the Mlir-commits mailing list