[Mlir-commits] [mlir] 6121117 - [mlir][Linalg] Fix TensorConstantOp bufferization in Linalg.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Oct 13 09:37:29 PDT 2020


Author: Nicolas Vasilache
Date: 2020-10-13T16:36:56Z
New Revision: 6121117484ddd7c5a03b40004a8bba58506ce9d0

URL: https://github.com/llvm/llvm-project/commit/6121117484ddd7c5a03b40004a8bba58506ce9d0
DIFF: https://github.com/llvm/llvm-project/commit/6121117484ddd7c5a03b40004a8bba58506ce9d0.diff

LOG: [mlir][Linalg] Fix TensorConstantOp bufferization in Linalg.

TensorConstantOp bufferization currently uses the vector dialect to store constant data into memory.
Due to natural vector size and alignment properties, this is problematic for n-D vectors (n > 1) whose most minor dimension is not naturally aligned.

Instead, this revision linearizes the constant and introduces a linalg.reshape to go back to the desired shape.

This should still be considered a workaround; a better longer-term solution will probably involve `llvm.global`.

Differential Revision: https://reviews.llvm.org/D89311

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
    mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
    mlir/test/Dialect/Linalg/tensors-to-buffers.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a2dee8c3ae65..948883e89e1e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -816,6 +816,9 @@ class LinalgOpConverter : public BufferAssignmentConversionPattern {
                   ConversionPatternRewriter &rewriter) const final;
 };
 
+/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
+/// stored in memory. A linalg.reshape is introduced to convert to the desired
+/// n-D buffer form.
 class TensorConstantOpConverter
     : public BufferAssignmentOpConversionPattern<ConstantOp> {
 public:
@@ -827,6 +830,7 @@ class TensorConstantOpConverter
                   ConversionPatternRewriter &rewriter) const final;
 };
 
+/// TensorCastOp converts 1-1 to MemRefCastOp.
 class TensorCastOpConverter
     : public BufferAssignmentOpConversionPattern<TensorCastOp> {
 public:

diff  --git a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
index b706c9cef976..a06874550f56 100644
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
@@ -4,12 +4,13 @@
 // RUN: | FileCheck %s
 
 func @main() {
-  %A = constant dense<[[1.0, 2.0], [4.0, 5.0]]> : tensor<2x2xf32>
+  %A = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
   %B = constant dense<[[1.0, 2.0, 3.0, 4.0],
-                       [5.0, 6.0, 7.0, 8.0]]> : tensor<2x4xf32>
+                       [5.0, 6.0, 7.0, 8.0],
+                       [9.0, 10.0, 11.0, 12.0]]> : tensor<3x4xf32>
   %C = constant dense<1000.0> : tensor<2x4xf32>
 
-  %D = linalg.matmul ins(%A, %B: tensor<2x2xf32>, tensor<2x4xf32>)
+  %D = linalg.matmul ins(%A, %B: tensor<2x3xf32>, tensor<3x4xf32>)
                      init(%C: tensor<2x4xf32>) -> tensor<2x4xf32>
 
   %unranked = tensor_cast %D : tensor<2x4xf32> to tensor<*xf32>
@@ -17,10 +18,11 @@ func @main() {
 
   //      CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
   // CHECK-SAME: rank = 2 offset = 0 sizes = [2, 4] strides = [4, 1] data =
-  // CHECK-NEXT: [1011, 1014, 1017, 1020]
-  // CHECK-NEXT: [1029, 1038, 1047, 1056]
+  // CHECK-NEXT: [1038,   1044,   1050,   1056]
+  // CHECK-NEXT: [1083,   1098,   1113,   1128]
 
   return
 }
 
 func @print_memref_f32(%ptr : tensor<*xf32>)
+

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index 650c44f9e922..303624f2898d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -239,22 +239,43 @@ LogicalResult mlir::linalg::LinalgOpConverter::matchAndRewrite(
 LogicalResult mlir::linalg::TensorConstantOpConverter::matchAndRewrite(
     ConstantOp op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (!op.getType().isa<RankedTensorType>())
+  RankedTensorType rankedTensorType = op.getType().dyn_cast<RankedTensorType>();
+  if (!rankedTensorType)
+    return failure();
+  if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) {
+        return s == 0 || ShapedType::isDynamic(s);
+      }))
     return failure();
-  auto attr = op.getValue().cast<DenseElementsAttr>();
 
-  Location loc = op.getLoc();
+  int64_t nElements = 1;
+  for (int64_t s : rankedTensorType.getShape())
+    nElements *= s;
+  Type elementType = rankedTensorType.getElementType();
   MemRefType memrefType =
       converter.convertType(op.getType()).cast<MemRefType>();
-  VectorType vectorType =
-      VectorType::get(memrefType.getShape(), memrefType.getElementType());
-  Value cstVec =
-      rewriter.create<ConstantOp>(loc, vectorType, attr.reshape(vectorType));
+  VectorType flatVectorType = VectorType::get({nElements}, elementType);
+  MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
+  MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);
 
-  MemRefType memrefOfVectorType = MemRefType::get({}, vectorType);
-  Value alloc = rewriter.create<AllocOp>(loc, memrefOfVectorType, ValueRange{});
+  Location loc = op.getLoc();
+  auto attr = op.getValue().cast<DenseElementsAttr>();
+  Value alloc =
+      rewriter.create<AllocOp>(loc, memrefOfFlatVectorType, ValueRange{});
+  Value cstVec = rewriter.create<ConstantOp>(loc, flatVectorType,
+                                             attr.reshape(flatVectorType));
   rewriter.create<StoreOp>(loc, cstVec, alloc);
-  rewriter.replaceOpWithNewOp<vector::TypeCastOp>(op, memrefType, alloc);
+  // NOTE(review): rank-0 tensors yield memref<1xT> (no reshape); confirm.
+  Value memref =
+      rewriter.create<vector::TypeCastOp>(loc, flatMemrefType, alloc);
+  if (rankedTensorType.getRank() > 1) {
+    // Expand the flat 1-D memref back to the n-D shape with linalg.reshape.
+    AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap(
+        /*numDims=*/rankedTensorType.getRank(), op.getContext());
+    memref = rewriter.create<linalg::ReshapeOp>(
+        loc, memrefType, memref,
+        rewriter.getAffineMapArrayAttr(collapseAllDims));
+  }
+  rewriter.replaceOp(op, memref);
 
   return success();
 }

diff  --git a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
index 4d23b7e10dae..093732c7f47d 100644
--- a/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
+++ b/mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
@@ -126,28 +126,29 @@ func @dynamic_results(%arg0: tensor<?x?xf32>)
 
 // -----
 
-func @foo() -> tensor<4xf32> {
+func @foo() -> tensor<2x3xf32> {
 // CHECK-LABEL: func @foo(
-//  CHECK-SAME:   %[[A:[0-9a-z]*]]: memref<4xf32>) {
-
-  %0 = constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32>
-//  CHECK-NEXT:   %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : vector<4xf32>
-//  CHECK-NEXT:   %[[ALLOC:.*]] = alloc() : memref<vector<4xf32>>
-//  CHECK-NEXT:   store %[[CST]], %[[ALLOC]][] : memref<vector<4xf32>>
-//  CHECK-NEXT:   %[[RES:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<4xf32>> to memref<4xf32>
-
-  return %0 : tensor<4xf32>
-//  CHECK-NEXT:   linalg.copy(%[[RES]], %[[A]]) : memref<4xf32>, memref<4xf32>
-//  CHECK-NEXT:   dealloc %[[ALLOC]] : memref<vector<4xf32>>
+//  CHECK-SAME:   %[[A:[0-9a-z]*]]: memref<2x3xf32>) {
+// The 2x3 constant should be lowered via a linearized vector<6xf32>.
+  %0 = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+//  CHECK-NEXT:   %[[ALLOC:.*]] = alloc() : memref<vector<6xf32>>
+//  CHECK-NEXT:   %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
+//  CHECK-NEXT:   store %[[CST]], %[[ALLOC]][] : memref<vector<6xf32>>
+//  CHECK-NEXT:   %[[FLAT:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<6xf32>> to memref<6xf32>
+//  CHECK-NEXT:   %[[RES:.*]] = linalg.reshape %[[FLAT]] {{.*}} : memref<6xf32> into memref<2x3xf32>
+// The flat buffer should then be reshaped back to memref<2x3xf32>.
+  return %0 : tensor<2x3xf32>
+//  CHECK-NEXT:   linalg.copy(%[[RES]], %[[A]]) : memref<2x3xf32>, memref<2x3xf32>
+//  CHECK-NEXT:   dealloc %[[ALLOC]] : memref<vector<6xf32>>
 //  CHECK-NEXT:   return
 }
 
 func @bar() {
 // CHECK-LABEL: func @bar() {
 
-  %0 = call @foo() : () -> tensor<4xf32>
-//  CHECK-NEXT:   %[[ALLOC:.*]] = alloc() : memref<4xf32>
-//  CHECK-NEXT:   call @foo(%[[ALLOC]]) : (memref<4xf32>) -> ()
+  %0 = call @foo() : () -> tensor<2x3xf32>
+//  CHECK-NEXT:   %[[ALLOC:.*]] = alloc() : memref<2x3xf32>
+//  CHECK-NEXT:   call @foo(%[[ALLOC]]) : (memref<2x3xf32>) -> ()
 
   // Instead of relying on tensor_store which introduces aliasing, we rely on
   // the conversion of print_memref_f32(tensor<*xf32>) to
@@ -155,15 +156,15 @@ func @bar() {
   // Note that this is skipping a step and we would need at least some function
   // attribute to declare that this conversion is valid (e.g. when we statically
   // know that things will play nicely at the C ABI boundary).
-  %unranked = tensor_cast %0 : tensor<4xf32> to tensor<*xf32>
+  %unranked = tensor_cast %0 : tensor<2x3xf32> to tensor<*xf32>
 //  CHECK-NEXT:   %[[UNRANKED:.*]] = memref_cast %[[ALLOC]] :
-//  CHECK-SAME:     memref<4xf32> to memref<*xf32>
+//  CHECK-SAME:     memref<2x3xf32> to memref<*xf32>
 
   call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
 //  CHECK-NEXT:   call @print_memref_f32(%[[UNRANKED]]) : (memref<*xf32>) -> ()
 
   return
-//  CHECK-NEXT:   dealloc %[[ALLOC]] : memref<4xf32>
+//  CHECK-NEXT:   dealloc %[[ALLOC]] : memref<2x3xf32>
 //  CHECK-NEXT:   return
 }
 


        


More information about the Mlir-commits mailing list