[Mlir-commits] [mlir] 51df623 - [mlir][tensor] Fix bufferization of CollapseShapeOp / ExpandShapeOp

Matthias Springer llvmlistbot at llvm.org
Thu Mar 31 01:11:55 PDT 2022


Author: Matthias Springer
Date: 2022-03-31T17:11:45+09:00
New Revision: 51df62388e83a406e4a946ff8aae1f7299a2d92b

URL: https://github.com/llvm/llvm-project/commit/51df62388e83a406e4a946ff8aae1f7299a2d92b
DIFF: https://github.com/llvm/llvm-project/commit/51df62388e83a406e4a946ff8aae1f7299a2d92b.diff

LOG: [mlir][tensor] Fix bufferization of CollapseShapeOp / ExpandShapeOp

Infer a tighter MemRef type instead of always falling back to the most dynamic MemRef type. The fallback was inefficient and caused op verification errors.
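
As a minimal sketch of the effect (assembled from the CHECK lines of the new
expand_shape_of_slice test below, so the layout maps shown are the ones the
builder actually infers), bufferizing an expand_shape of a slice now produces
a tight strided result type instead of the fully dynamic fallback:

  %subview = memref.subview %m[%o1, 5] [%s1, 10] [1, 1]
      : memref<?x20xf32>
        to memref<?x10xf32, affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>>
  // The fully dynamic fallback type could fail memref.expand_shape
  // verification; the builder now infers the strided layout instead:
  %expanded = memref.expand_shape %subview [[0, 1], [2, 3]]
      : memref<?x10xf32, affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>>
        into memref<?x7x2x5xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>>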

Differential Revision: https://reviews.llvm.org/D122649

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Tensor/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index a9519b98803cc..6f86270e5adda 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -108,12 +108,27 @@ struct CollapseShapeOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
+    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
     Value buffer =
         *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
-    Type resultType =
-        getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
+
+    if (tensorResultType.getRank() == 0) {
+      // 0-d collapses must go through a different op builder.
+      auto bufferType = buffer.getType().cast<MemRefType>();
+      // Assume identity layout: No offset.
+      assert(bufferType.getLayout().isIdentity() &&
+             "non-zero offset for 0-d collapse not supported");
+      MemRefLayoutAttrInterface layout;
+      auto resultType = MemRefType::get({}, tensorResultType.getElementType(),
+                                        layout, bufferType.getMemorySpace());
+      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
+          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
+      return success();
+    }
+
+    // Result type is inferred by the builder.
     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
-        rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
+        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
     return success();
   }
 };
@@ -175,12 +190,15 @@ struct ExpandShapeOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+    auto tensorResultType = expandShapeOp.getResultType();
     Value buffer =
         *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
-    Type resultType =
-        getMemRefType(expandShapeOp.getResultType(), state.getOptions());
+
+    // Memref result type is inferred by the builder based on reassociation
+    // indices and result shape.
     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
-        rewriter, op, resultType, buffer, expandShapeOp.reassociation());
+        rewriter, op, tensorResultType.getShape(), buffer,
+        expandShapeOp.getReassociationIndices());
     return success();
   }
 };

diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index cbb05473807b0..5fa8e3f8a2a46 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,6 +1,8 @@
 // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
 
 // CHECK-LABEL:   func @dim(
 // CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
@@ -242,7 +244,7 @@ func @tensor.generate_unknown_ops_in_body(%arg0: index) -> tensor<?xindex> {
 func @tensor.extract_slice(
     %t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
   // CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
-  // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP]]>
+  // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP0]]>
   %0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1]
       : tensor<?x?xf32> to tensor<?x10xf32>
   // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
@@ -256,7 +258,7 @@ func @tensor.extract_slice(
 func @tensor.extract_slice_rank_reducing(
     %t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10x?xf32>
-  // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP]]>
+  // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP0]]>
   %0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1]
       : tensor<?x10x?xf32> to tensor<?x15xf32>
   // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
@@ -316,6 +318,23 @@ func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
   return %0 : tensor<2x?x10xf32>
 }
 
+// CHECK-LABEL: func @tensor.expand_shape_of_slice(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x20xf32>
+func @tensor.expand_shape_of_slice(
+    %t1: tensor<?x20xf32>, %o1: index, %s1: index) -> tensor<?x7x2x5xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x20xf32>
+  // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, #[[$MAP1]]>
+  %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
+      tensor<?x20xf32> to tensor<?x10xf32>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
+  // CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, #[[$MAP1]]> into memref<?x7x2x5xf32, #[[$MAP2]]>
+  %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] :
+      tensor<?x10xf32> into tensor<?x7x2x5xf32>
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+  // CHECK: return %[[r]]
+  return %1 : tensor<?x7x2x5xf32>
+}
+
 // CHECK-LABEL: func @tensor.collapse_shape(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<2x?x?xf32>
 func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
@@ -329,3 +348,16 @@ func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: return %[[r]]
   return %0 : tensor<?x?xf32>
 }
+
+// CHECK-LABEL: func @tensor.collapse_shape_to_scalar(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<1x1x1xf32>
+func @tensor.collapse_shape_to_scalar(%t1: tensor<1x1x1xf32>) -> tensor<f32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<1x1x1xf32>
+  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [] : memref<1x1x1xf32> into memref<f32>
+  %0 = tensor.collapse_shape %t1 []
+      : tensor<1x1x1xf32> into tensor<f32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<f32>
+}