[Mlir-commits] [mlir] a65a505 - [mlir][Linalg] Canonicalize tensor_reshape(splat-constant) -> splat-constant.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 18 08:17:36 PDT 2020


Author: MaheshRavishankar
Date: 2020-08-18T08:17:09-07:00
New Revision: a65a50540e3b5dd1938a1d14f31b912a311537fb

URL: https://github.com/llvm/llvm-project/commit/a65a50540e3b5dd1938a1d14f31b912a311537fb
DIFF: https://github.com/llvm/llvm-project/commit/a65a50540e3b5dd1938a1d14f31b912a311537fb.diff

LOG: [mlir][Linalg] Canonicalize tensor_reshape(splat-constant) -> splat-constant.

When the operand to the linalg.tensor_reshape op is a splat constant,
the result can be replaced with a splat constant of the same value but
different type.

Differential Revision: https://reviews.llvm.org/D86117

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 009699be5263..308272d66d56 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -734,9 +735,28 @@ static LogicalResult verify(TensorReshapeOp op) {
   return success();
 }
 
+/// Folds a linalg.tensor_reshape of a splat constant into a splat constant of
+/// the reshape's result type; non-constant or non-splat sources do not match.
+struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    DenseElementsAttr attr;
+    if (!matchPattern(reshapeOp.src(), m_Constant(&attr)))
+      return failure();
+    if (!attr || !attr.isSplat())
+      return failure();
+    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
+        reshapeOp.getResultType(), attr.getRawData(), true);
+    rewriter.replaceOpWithNewOp<ConstantOp>(reshapeOp, newAttr);
+    return success();
+  }
+};
+
 void TensorReshapeOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<CollapseReshapeOps<TensorReshapeOp>>(context);
+  results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 005bd1c87445..85321084cd0c 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -203,3 +203,60 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf
 //   CHECK-NOT:   linalg.copy
 //  CHECK-NEXT:   linalg.generic
 
+// -----
+
+func @reshape_splat_constant_int32() -> tensor<2x4x2xi32>
+{
+  %c0 = constant dense<42> : tensor<2x8xi32>
+  %0 = linalg.tensor_reshape %c0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>]
+       : tensor<2x8xi32> into tensor<2x4x2xi32>
+  return %0 : tensor<2x4x2xi32>
+}
+// CHECK-LABEL: @reshape_splat_constant_int32
+//       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi32>
+//   CHECK-NOT:   linalg.tensor_reshape
+//       CHECK:   return %[[CST]]
+
+func @reshape_splat_constant_int16() -> tensor<2x4x2xi16>
+{
+  %c0 = constant dense<42> : tensor<2x8xi16>
+  %0 = linalg.tensor_reshape %c0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>]
+       : tensor<2x8xi16> into tensor<2x4x2xi16>
+  return %0 : tensor<2x4x2xi16>
+}
+// CHECK-LABEL: @reshape_splat_constant_int16
+//       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xi16>
+//   CHECK-NOT:   linalg.tensor_reshape
+//       CHECK:   return %[[CST]]
+
+func @reshape_splat_constant_float32() -> tensor<2x4x2xf32>
+{
+  %c0 = constant dense<42.0> : tensor<2x8xf32>
+  %0 = linalg.tensor_reshape %c0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>]
+       : tensor<2x8xf32> into tensor<2x4x2xf32>
+  return %0 : tensor<2x4x2xf32>
+}
+// CHECK-LABEL: @reshape_splat_constant_float32
+//       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf32>
+//   CHECK-NOT:   linalg.tensor_reshape
+//       CHECK:   return %[[CST]]
+
+func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
+{
+  %c0 = constant dense<42.0> : tensor<2x8xf64>
+  %0 = linalg.tensor_reshape %c0
+         [affine_map<(d0, d1, d2) -> (d0)>,
+          affine_map<(d0, d1, d2) -> (d1, d2)>]
+       : tensor<2x8xf64> into tensor<2x4x2xf64>
+  return %0 : tensor<2x4x2xf64>
+}
+// CHECK-LABEL: @reshape_splat_constant_float64
+//       CHECK:   %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64>
+//   CHECK-NOT:   linalg.tensor_reshape
+//       CHECK:   return %[[CST]]


        


More information about the Mlir-commits mailing list