[Mlir-commits] [mlir] 9340996 - [mlir][tensor] Add pattern to rewrite tensor.generate as a constant
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 9 03:56:16 PDT 2023
Author: Matthias Springer
Date: 2023-06-09T12:56:07+02:00
New Revision: 93409967061366660171b475db0dd84b8b7703c5
URL: https://github.com/llvm/llvm-project/commit/93409967061366660171b475db0dd84b8b7703c5
DIFF: https://github.com/llvm/llvm-project/commit/93409967061366660171b475db0dd84b8b7703c5.diff
LOG: [mlir][tensor] Add pattern to rewrite tensor.generate as a constant
Only ops with a static tensor type and a constant yield value are rewritten.
Differential Revision: https://reviews.llvm.org/D152511
Added:
mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
Modified:
mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 9a569beaeb880..f23e310497812 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -87,6 +87,17 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.tensor.rewrite_as_constant",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that tensor ops (such as tensor.generate) should be replaced with
+ constants (arith.constant) when possible.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def Transform_TensorPadOp : Transform_ConcreteOpType<"tensor.pad">;
def MakeLoopIndependentOp
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 474c559a75220..c7e157e01d06a 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -67,6 +67,10 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
/// respectively.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns that replace tensor ops (such as
+/// tensor.generate) with constants when possible.
+void populateRewriteAsConstantPatterns(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
// Transform helpers
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 8c85d18ada00d..b4b822976d99a 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -113,6 +113,11 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
tensor::populateReassociativeReshapeFoldingPatterns(patterns);
}
+// Populates `patterns` by forwarding to
+// tensor::populateRewriteAsConstantPatterns.
+void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  tensor::populateRewriteAsConstantPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// MakeLoopIndependentOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 083c9c936d4cf..251c129b5383f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
IndependenceTransforms.cpp
MergeConsecutiveInsertExtractSlicePatterns.cpp
ReshapePatterns.cpp
+ RewriteAsConstant.cpp
SwapExtractSliceWithProducerPatterns.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
new file mode 100644
index 0000000000000..11e1de543ac91
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -0,0 +1,53 @@
+//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Replaces a tensor.generate op whose body yields a constant with a single
+/// constant op holding that value for the whole tensor. The pattern only
+/// applies when the result tensor type has a static shape.
+struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
+  using OpRewritePattern<GenerateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenerateOp generateOp,
+                                PatternRewriter &rewriter) const override {
+    // Only statically shaped results are rewritten.
+    auto resultType =
+        llvm::cast<RankedTensorType>(generateOp.getResult().getType());
+    if (!resultType.hasStaticShape())
+      return failure();
+
+    // The value passed to the body's tensor.yield must match a constant.
+    Block &body = generateOp.getBody().front();
+    auto yieldOp = cast<tensor::YieldOp>(body.getTerminator());
+    Attribute yieldedAttr;
+    if (!matchPattern(yieldOp.getValue(), m_Constant(&yieldedAttr)))
+      return failure();
+
+    // Build the dense attribute from the yielded constant and ask the tensor
+    // dialect to materialize a constant op for it.
+    auto denseAttr = DenseElementsAttr::get(resultType, yieldedAttr);
+    auto *tensorDialect =
+        rewriter.getContext()->getLoadedDialect<TensorDialect>();
+    Operation *constantOp = tensorDialect->materializeConstant(
+        rewriter, denseAttr, resultType, generateOp->getLoc());
+    // Materialization may decline; leave the op untouched in that case.
+    if (!constantOp)
+      return failure();
+
+    rewriter.replaceOp(generateOp, constantOp->getResults());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds the rewrite-as-constant patterns defined in this file (currently only
+/// GenerateToConstant) to `patterns`.
+void mlir::tensor::populateRewriteAsConstantPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<GenerateToConstant>(context);
+}
diff --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
new file mode 100644
index 0000000000000..60930a5637ba7
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ transform.apply_patterns to %module_op {
+ transform.apply_patterns.tensor.rewrite_as_constant
+ } : !transform.any_op
+}
+
+// CHECK-LABEL: func @tensor_generate_constant(
+// CHECK: %[[cst:.*]] = arith.constant dense<5.000000e+00> : tensor<2x3x5xf32>
+// CHECK: return %[[cst]]
+func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
+ %cst = arith.constant 5.0 : f32
+ %0 = tensor.generate {
+ ^bb0(%arg0: index, %arg1: index, %arg2: index):
+ tensor.yield %cst : f32
+ } : tensor<2x3x5xf32>
+ return %0 : tensor<2x3x5xf32>
+}
More information about the Mlir-commits
mailing list