[Mlir-commits] [mlir] 9340996 - [mlir][tensor] Add pattern to rewrite tensor.generate as a constant

Matthias Springer llvmlistbot at llvm.org
Fri Jun 9 03:56:16 PDT 2023


Author: Matthias Springer
Date: 2023-06-09T12:56:07+02:00
New Revision: 93409967061366660171b475db0dd84b8b7703c5

URL: https://github.com/llvm/llvm-project/commit/93409967061366660171b475db0dd84b8b7703c5
DIFF: https://github.com/llvm/llvm-project/commit/93409967061366660171b475db0dd84b8b7703c5.diff

LOG: [mlir][tensor] Add pattern to rewrite tensor.generate as a constant

Only ops with a static tensor type and a constant yield value are rewritten.

Differential Revision: https://reviews.llvm.org/D152511

Added: 
    mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
    mlir/test/Dialect/Tensor/rewrite-as-constant.mlir

Modified: 
    mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 9a569beaeb880..f23e310497812 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -87,6 +87,17 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.rewrite_as_constant",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor ops (such as tensor.generate) should be replaced with
+    constants (arith.constant) when possible.
+
+    Currently, only `tensor.generate` ops with a static result shape whose
+    body yields a constant value are rewritten. Such an op is replaced by a
+    splat constant of the result type that is materialized through the tensor
+    dialect (e.g., `arith.constant dense<...>`). All other ops are left
+    unchanged.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def Transform_TensorPadOp : Transform_ConcreteOpType<"tensor.pad">;
 
 def MakeLoopIndependentOp

diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 474c559a75220..c7e157e01d06a 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -67,6 +67,10 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
 /// respectively.
 void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that replace tensor ops (such as
+/// tensor.generate) with constants when possible.
+///
+/// Currently this only rewrites tensor.generate ops that have a static result
+/// shape and whose body yields a constant value. Each such op is replaced by a
+/// splat constant of the result type, materialized through the tensor dialect.
+void populateRewriteAsConstantPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Transform helpers
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 8c85d18ada00d..b4b822976d99a 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -113,6 +113,11 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
 }
 
+/// Adds the tensor "rewrite as constant" patterns to `patterns`.
+void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  // The pattern definitions live in the Tensor transforms library
+  // (RewriteAsConstant.cpp); this transform op only selects them.
+  ::mlir::tensor::populateRewriteAsConstantPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // MakeLoopIndependentOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 083c9c936d4cf..251c129b5383f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   IndependenceTransforms.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
   ReshapePatterns.cpp
+  RewriteAsConstant.cpp
   SwapExtractSliceWithProducerPatterns.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
new file mode 100644
index 0000000000000..11e1de543ac91
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -0,0 +1,53 @@
+//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+/// Returns true if `attr` can be used as the splat value of a
+/// DenseElementsAttr whose element type is `elementType`.
+///
+/// `DenseElementsAttr::get(ShapedType, ArrayRef<Attribute>)` casts each value
+/// based on the element type. It expects an IntegerAttr/FloatAttr of exactly
+/// that type for int/index/float, an ArrayAttr pair for complex, and a
+/// StringAttr otherwise, and it does not fail gracefully on anything else.
+/// `m_Constant` accepts any ConstantLike op, so its folded attribute must be
+/// checked before it is handed over.
+static bool isValidSplatValue(Attribute attr, Type elementType) {
+  if (llvm::isa<ComplexType>(elementType))
+    return llvm::isa<ArrayAttr>(attr);
+  if (elementType.isIntOrIndexOrFloat()) {
+    if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+      return intAttr.getType() == elementType;
+    if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr))
+      return floatAttr.getType() == elementType;
+    return false;
+  }
+  return llvm::isa<StringAttr>(attr);
+}
+
+namespace {
+
+/// Rewrite tensor.generate with arith.constant if the yielded value is a
+/// constant and the tensor type is static.
+///
+/// The yielded value does not depend on the block arguments (indices), so the
+/// result is a splat DenseElementsAttr materialized through the tensor
+/// dialect. Ops are left untouched if they have a dynamic shape, yield a
+/// non-constant value, or yield a constant that cannot be represented as a
+/// dense splat.
+struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
+  using OpRewritePattern<GenerateOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenerateOp generateOp,
+                                PatternRewriter &rewriter) const override {
+    auto tensorType =
+        llvm::cast<RankedTensorType>(generateOp.getResult().getType());
+    // A dense constant needs a statically known shape.
+    if (!tensorType.hasStaticShape())
+      return failure();
+    auto terminatorOp =
+        cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
+    Attribute attr;
+    if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
+      return failure();
+    // Reject constants that DenseElementsAttr::get cannot splat, instead of
+    // asserting/crashing while building the attribute below.
+    if (!isValidSplatValue(attr, tensorType.getElementType()))
+      return failure();
+    Operation *constantOp =
+        rewriter.getContext()
+            ->getLoadedDialect<TensorDialect>()
+            ->materializeConstant(rewriter,
+                                  DenseElementsAttr::get(tensorType, attr),
+                                  tensorType, generateOp->getLoc());
+    // The dialect may be unable to materialize this attribute/type pair.
+    if (!constantOp)
+      return failure();
+    rewriter.replaceOp(generateOp, constantOp->getResults());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers every tensor-to-constant rewrite pattern in `patterns`.
+void mlir::tensor::populateRewriteAsConstantPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<GenerateToConstant>(ctx);
+}

diff  --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
new file mode 100644
index 0000000000000..60930a5637ba7
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  transform.apply_patterns to %module_op {
+    transform.apply_patterns.tensor.rewrite_as_constant
+  } : !transform.any_op
+}
+
+// A statically shaped tensor.generate that yields a constant becomes a splat
+// constant.
+// CHECK-LABEL: func @tensor_generate_constant(
+//       CHECK:   %[[cst:.*]] = arith.constant dense<5.000000e+00> : tensor<2x3x5xf32>
+//       CHECK:   return %[[cst]]
+func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
+  %cst = arith.constant 5.0 : f32
+  %0 = tensor.generate {
+    ^bb0(%arg0: index, %arg1: index, %arg2: index):
+    tensor.yield %cst : f32
+  } : tensor<2x3x5xf32>
+  return %0 : tensor<2x3x5xf32>
+}
+
+// Dynamically shaped results are not rewritten.
+// CHECK-LABEL: func @tensor_generate_dynamic_shape(
+//       CHECK:   tensor.generate
+func.func @tensor_generate_dynamic_shape(%sz: index) -> tensor<?x3xf32> {
+  %cst = arith.constant 5.0 : f32
+  %0 = tensor.generate %sz {
+    ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<?x3xf32>
+  return %0 : tensor<?x3xf32>
+}
+
+// Non-constant yielded values are not rewritten.
+// CHECK-LABEL: func @tensor_generate_non_constant(
+//       CHECK:   tensor.generate
+func.func @tensor_generate_non_constant(%v: f32) -> tensor<2x3xf32> {
+  %0 = tensor.generate {
+    ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %v : f32
+  } : tensor<2x3xf32>
+  return %0 : tensor<2x3xf32>
+}


        


More information about the Mlir-commits mailing list