[Mlir-commits] [mlir] faa66b1 - [mlir] Bufferize tensor constant ops

Sean Silva llvmlistbot at llvm.org
Thu Nov 12 14:57:07 PST 2020


Author: Sean Silva
Date: 2020-11-12T14:56:10-08:00
New Revision: faa66b1b2c7a328e747c283dfd0dcf43c365330d

URL: https://github.com/llvm/llvm-project/commit/faa66b1b2c7a328e747c283dfd0dcf43c365330d
DIFF: https://github.com/llvm/llvm-project/commit/faa66b1b2c7a328e747c283dfd0dcf43c365330d.diff

LOG: [mlir] Bufferize tensor constant ops

We lower them to a std.global_memref (uniqued by constant value) + a
std.get_global_memref to produce the corresponding memref value.
This allows removing Linalg's somewhat hacky lowering of tensor
constants, now that std properly supports this.

Differential Revision: https://reviews.llvm.org/D91306

Added: 
    mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir
    mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
    mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
    mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
    mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
    mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
    mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/test/Dialect/Linalg/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 9162543a310c..b9f6c617de50 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -64,7 +64,7 @@ def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
 def LinalgBufferize : Pass<"linalg-bufferize", "ModuleOp"> {
   let summary = "Bufferize the linalg dialect";
   let constructor = "mlir::createLinalgBufferizePass()";
-  let dependentDialects = ["linalg::LinalgDialect", "vector::VectorDialect"];
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def LinalgLowerToParallelLoops

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index 235b462b1024..8ecae01ce486 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -35,6 +35,9 @@ std::unique_ptr<Pass> createStdBufferizePass();
 /// Creates an instance of func bufferization pass.
 std::unique_ptr<Pass> createFuncBufferizePass();
 
+/// Creates an instance of tensor constant bufferization pass.
+std::unique_ptr<Pass> createTensorConstantBufferizePass();
+
 /// Creates an instance of the StdExpand pass that legalizes Std
 /// dialect ops to be convertible to LLVM. For example,
 /// `std.ceildivi_signed` gets transformed to a number of std operations,

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index b3b91a5b02a9..3be398fecb0c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -51,4 +51,17 @@ def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
   let constructor = "mlir::createFuncBufferizePass()";
 }
 
+def TensorConstantBufferize : Pass<"tensor-constant-bufferize", "ModuleOp"> {
+  let summary = "Bufferize tensor constants.";
+  let description = [{
+    This pass bufferizes tensor constants.
+
+    This pass needs to be a module pass because it inserts std.global_memref
+    ops into the module, which cannot be done safely from a function pass due to
+    multi-threading. Most other bufferization passes can run in parallel at
+    function granularity.
+  }];
+  let constructor = "mlir::createTensorConstantBufferizePass()";
+}
+
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
index d26b8f75cf28..08bb81246793 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-elementwise.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -convert-elementwise-to-linalg -std-bufferize -tensor-constant-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
index fbb026c12138..9744844f9e47 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert-multiple-uses.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -func-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
 // RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir
new file mode 100644
index 000000000000..d20892b3cc40
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-subtensor-insert.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
+// RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @main() {
+  %const = constant dense<10.0> : tensor<2xf32>
+  %insert_val = constant dense<20.0> : tensor<1xf32>
+  %inserted = subtensor_insert %insert_val into %const[0][1][1] : tensor<1xf32> into tensor<2xf32>
+
+  %unranked = tensor_cast %inserted : tensor<2xf32> to tensor<*xf32>
+  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
+
+  //      CHECK: Unranked Memref base@ = {{0x[0-9a-f]*}}
+  // CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [1] data =
+  // CHECK-NEXT: [20, 10]
+
+  return
+}
+
+func @print_memref_f32(%ptr : tensor<*xf32>)

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
index 1ac09d6df195..67e490009194 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -tensor-constant-bufferize -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
index 453bf62eb3c1..7a828b64b731 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
@@ -1,11 +1,11 @@
-// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -func-bufferize \
+// RUN: mlir-opt %s -linalg-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize \
 // RUN: -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s
 
 // RUN: mlir-opt %s  -linalg-tile="linalg-tile-sizes=1,2,3" -linalg-bufferize \
-// RUN: -scf-bufferize -std-bufferize -func-bufferize -convert-linalg-to-loops \
+// RUN: -scf-bufferize -std-bufferize -tensor-constant-bufferize -func-bufferize -convert-linalg-to-loops \
 // RUN:  -convert-scf-to-std -convert-linalg-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index b78a26281c66..255a0d4cff90 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -325,60 +325,6 @@ class SubTensorInsertOpConverter
     return success();
   }
 };
-
-/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
-/// stored in memory. A linalg.reshape is introduced to convert to the desired
-/// n-D buffer form.
-class TensorConstantOpConverter : public OpConversionPattern<ConstantOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-
-    RankedTensorType rankedTensorType =
-        op.getType().dyn_cast<RankedTensorType>();
-    if (!rankedTensorType)
-      return failure();
-    if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
-          return s == 0 || ShapedType::isDynamic(s);
-        }))
-      return failure();
-
-    int64_t nElements = 1;
-    for (int64_t s : rankedTensorType.getShape())
-      nElements *= s;
-    Type elementType = rankedTensorType.getElementType();
-    MemRefType memrefType =
-        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
-    VectorType flatVectorType = VectorType::get({nElements}, elementType);
-    MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
-    MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);
-
-    Location loc = op.getLoc();
-    auto attr = op.getValue().cast<DenseElementsAttr>();
-    Value alloc =
-        rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
-    Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
-                                               attr.reshape(flatVectorType));
-    rewriter.create<StoreOp>(loc, cstVec, alloc);
-
-    Value memref =
-        rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
-    if (rankedTensorType.getRank() > 1) {
-      // Introduce a linalg.reshape to flatten the memref.
-      AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
-          /*numDims=*/rankedTensorType.getRank(), op.getContext());
-      memref = rewriter.create<linalg::ReshapeOp>(
-          loc, memrefType, memref,
-          rewriter.getAffineMapArrayAttr(collapseAllDims));
-    }
-    rewriter.replaceOp(op, memref);
-
-    return success();
-  }
-};
 } // namespace
 
 namespace {
@@ -391,7 +337,7 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
     BufferizeTypeConverter typeConverter;
 
     // Mark all Standard operations legal.
-    target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();
+    target.addLegalDialect<StandardOpsDialect>();
     target.addIllegalOp<SubTensorOp, SubTensorInsertOp>();
 
     // Mark all Linalg operations illegal as long as they work on tensors.
@@ -422,8 +368,7 @@ void mlir::linalg::populateLinalgBufferizePatterns(
   patterns.insert<
       // clang-format off
       SubTensorOpConverter,
-      SubTensorInsertOpConverter,
-      TensorConstantOpConverter
+      SubTensorInsertOpConverter
       // clang-format on
       >(typeConverter, context);
 }

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index caea896a6ffa..ce5494cf855b 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
   ExpandTanh.cpp
   FuncBufferize.cpp
   FuncConversions.cpp
+  TensorConstantBufferize.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
new file mode 100644
index 000000000000..ee4b398379f7
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -0,0 +1,138 @@
+//===- TensorConstantBufferize.cpp - Bufferize tensor constants -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of tensor-valued std.constant ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Transforms/Bufferize.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+// This class creates global ops for all tensor-valued constants in the program.
+// It creates them with pretty names and makes sure that duplicate globals
+// aren't created.
+class GlobalCreator {
+public:
+  explicit GlobalCreator(ModuleOp module);
+
+  // Returns the global created for the constant value `attr`. The constructor
+  // visits every tensor-valued constant in the module, so a miss is a
+  // programming error. `lookup` is used instead of `operator[]` so that this
+  // accessor never inserts into (mutates) the map.
+  GlobalMemrefOp getGlobalFor(Attribute attr) const {
+    assert(globals.count(attr) && "unknown constant attr");
+    return globals.lookup(attr);
+  }
+
+private:
+  // Maps each distinct tensor constant value to the global that holds it.
+  DenseMap<Attribute, GlobalMemrefOp> globals;
+};
+
+GlobalCreator::GlobalCreator(ModuleOp module) {
+  BufferizeTypeConverter typeConverter;
+  // Create a builder without an insertion point. We will insert using the
+  // symbol table to guarantee unique names.
+  OpBuilder globalBuilder(module.getContext());
+  SymbolTable symbolTable(module);
+  module.walk([&](ConstantOp op) {
+    // We only want tensor constants for now.
+    auto type = op.getType().dyn_cast<RankedTensorType>();
+    if (!type)
+      return;
+    // If we already have a global for this constant value, no need to do
+    // anything else.
+    if (globals.count(op.getValue()))
+      return;
+
+    // Create a pretty name such as `__constant_3x4xf32`. It need not be
+    // unique: inserting through the symbol table below takes care of that.
+    SmallString<64> buf;
+    llvm::raw_svector_ostream os(buf);
+    interleave(type.getShape(), os, "x");
+    os << "x" << type.getElementType();
+
+    auto global = globalBuilder.create<GlobalMemrefOp>(
+        op.getLoc(), (Twine("__constant_") + os.str()).str(),
+        /*sym_visibility=*/globalBuilder.getStringAttr("private"),
+        /*type=*/
+        TypeAttr::get(typeConverter.convertType(type)), /*initial_value=*/
+        op.getValue().cast<ElementsAttr>(), /*constant=*/true);
+    symbolTable.insert(global);
+    // The symbol table inserts at the end of the module, but globals are a bit
+    // nicer if they are at the beginning.
+    global.getOperation()->moveBefore(&module.front());
+    globals[op.getValue()] = global;
+  });
+}
+} // namespace
+
+namespace {
+// Rewrites a tensor-valued std.constant into a std.get_global_memref of the
+// global that GlobalCreator created for the constant's value.
+class BufferizeTensorConstantOp : public OpConversionPattern<ConstantOp> {
+public:
+  BufferizeTensorConstantOp(GlobalCreator &globals,
+                            TypeConverter &typeConverter, MLIRContext *context)
+      : OpConversionPattern<ConstantOp>(typeConverter, context, /*benefit=*/1),
+        globals(globals) {}
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = op.getType().dyn_cast<RankedTensorType>();
+    if (!type)
+      return failure();
+
+    auto globalMemref = globals.getGlobalFor(op.value());
+    rewriter.replaceOpWithNewOp<GetGlobalMemrefOp>(op, globalMemref.type(),
+                                                   globalMemref.getName());
+    return success();
+  }
+
+private:
+  GlobalCreator &globals;
+};
+} // namespace
+
+namespace {
+struct TensorConstantBufferizePass
+    : public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    // Create all globals up front: the conversion pattern below only looks
+    // them up and never creates them.
+    GlobalCreator globals(module);
+
+    auto *context = &getContext();
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    target.addLegalDialect<StandardOpsDialect>();
+    patterns.insert<BufferizeTensorConstantOp>(globals, typeConverter, context);
+    // A constant is legal iff the type converter considers its result type
+    // legal; all other constants must be rewritten by the pattern above.
+    target.addDynamicallyLegalOp<ConstantOp>(
+        [&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
+    if (failed(applyPartialConversion(module, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createTensorConstantBufferizePass() {
+  return std::make_unique<TensorConstantBufferizePass>();
+}

diff  --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
index ef79f911d3e8..866ded2fbe0e 100644
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -94,24 +94,6 @@ func @dynamic_results(%arg0: tensor<?x?xf32>)
 
 // -----
 
-// Check lowering of tensor-valued std.constant's
-// TODO: Move this to std-bufferize.
-
-// CHECK-LABEL:   func @constant() -> tensor<2x3xf32> {
-// CHECK:           %[[VECTOR_MEMREF:.*]] = alloc() : memref<vector<6xf32>>
-// CHECK:           %[[VECTOR_CONST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
-// CHECK:           store %[[VECTOR_CONST]], %[[VECTOR_MEMREF]][] : memref<vector<6xf32>>
-// CHECK:           %[[MEMREF:.*]] = vector.type_cast %[[VECTOR_MEMREF]] : memref<vector<6xf32>> to memref<6xf32>
-// CHECK:           %[[FINAL_SHAPE:.*]] = linalg.reshape %[[MEMREF]] [#map] : memref<6xf32> into memref<2x3xf32>
-// CHECK:           %[[RESULT:.*]] = tensor_load %[[FINAL_SHAPE]] : memref<2x3xf32>
-// CHECK:           return %[[RESULT]] : tensor<2x3xf32>
-func @constant() -> tensor<2x3xf32> {
-  %0 = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
-  return %0: tensor<2x3xf32>
-}
-
-// -----
-
 #accesses = [
   affine_map<(i, j, k) -> (j, i, k)>,
   affine_map<(i, j, k) -> (i, j)>

diff  --git a/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
new file mode 100644
index 000000000000..fe897da670da
--- /dev/null
+++ b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: module {
+// We check the debug name too since we put some effort into making that readable.
+// The name isn't load-bearing though.
+// CHECK: global_memref "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00>
+// CHECK: @basic
+func @basic() -> tensor<3x4xf32> {
+  // CHECK: %[[MEMREF:.*]] = get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
+  // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]]
+  %0 = constant dense<7.0> : tensor<3x4xf32>
+  // CHECK: return %[[TENSOR]]
+  return %0 : tensor<3x4xf32>
+}
+
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: module {
+
+// Only one global is created. The checks are anchored at the start of the
+// line so that they cannot match the `get_global_memref` uses.
+// CHECK: {{^ *}}global_memref
+// CHECK-NOT: {{^ *}}global_memref
+func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
+  %0 = constant dense<7.0> : tensor<3x4xf32>
+  %1 = constant dense<7.0> : tensor<3x4xf32>
+  return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
+}
+
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: module {
+
+// Two globals are created (anchored as in @duplicate_constants).
+// CHECK: {{^ *}}global_memref
+// CHECK: {{^ *}}global_memref
+// CHECK-NOT: {{^ *}}global_memref
+func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
+  %0 = constant dense<7.0> : tensor<3x4xf32>
+  %1 = constant dense<8.0> : tensor<3x4xf32>
+  return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
+}
+
+// CHECK: }
+
+// -----
+
+// CHECK-LABEL: module {
+// We don't convert non-tensor constants.
+// CHECK-NOT: global_memref
+func @non_tensor() {
+  %0 = constant 7 : i32
+  return
+}
+
+// CHECK: }


        


More information about the Mlir-commits mailing list