[Mlir-commits] [mlir] 1cddcfd - Fix CollapsedLayoutMap for dim size 1 case

Yi Zhang llvmlistbot at llvm.org
Fri Apr 22 14:49:45 PDT 2022


Author: Yi Zhang
Date: 2022-04-22T17:48:24-04:00
New Revision: 1cddcfdc3c683b393df1a5c9063252eb60e52818

URL: https://github.com/llvm/llvm-project/commit/1cddcfdc3c683b393df1a5c9063252eb60e52818
DIFF: https://github.com/llvm/llvm-project/commit/1cddcfdc3c683b393df1a5c9063252eb60e52818.diff

LOG: Fix CollapsedLayoutMap for dim size 1 case

This change fixes `CollapsedLayoutMap` for cases where some of the
collapsed dims have size 1. Cases where the innermost dims have size 1
and are noncontiguous can be represented in strided form and can
therefore be allowed. For such cases, the new stride should be the
stride of the innermost entry in the reassociation group whose dimension
is not size 1. If that entry is dynamically sized, it is not possible to
decide which stride to use at compile time, so the stride is set to
dynamic.

Differential Revision: https://reviews.llvm.org/D124137

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/Tensor/bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 5d16e3f018d43..e7f2e03434899 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1824,12 +1824,27 @@ computeCollapsedLayoutMap(MemRefType srcType,
   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
     return failure();
 
-  // The result strides are exactly the strides of the last entry of each
-  // reassociation.
+  // The result stride of a reassociation group is the stride of the last entry
+  // of the reassociation. (TODO: Should be the minimum stride in the
+  // reassociation because strides are not necessarily sorted. E.g., when using
+  // memref.transpose.) Dimensions of size 1 should be skipped, because their
+  // strides are meaningless and could have any arbitrary value.
   SmallVector<int64_t> resultStrides;
   resultStrides.reserve(reassociation.size());
-  for (ReassociationIndices reassoc : reassociation)
-    resultStrides.push_back(srcStrides[reassoc.back()]);
+  for (const ReassociationIndices &reassoc : reassociation) {
+    ArrayRef<int64_t> ref = llvm::makeArrayRef(reassoc);
+    while (srcShape[ref.back()] == 1 && ref.size() > 1)
+      ref = ref.drop_back();
+    if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
+      resultStrides.push_back(srcStrides[ref.back()]);
+    } else {
+      // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
+      // the corresponding stride may have to be skipped. (See above comment.)
+      // Therefore, the result stride cannot be statically determined and must
+      // be dynamic.
+      resultStrides.push_back(ShapedType::kDynamicStrideOrOffset);
+    }
+  }
 
   // Validate that each reassociation group is contiguous.
   unsigned resultStrideIndex = resultStrides.size() - 1;

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index b3a9b7e31a309..8ebd9f39ccbd6 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -331,14 +331,14 @@ func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
 
 func.func @do_not_compose_collapse_of_expand_non_identity_layout(
     %arg0: memref<?x?xf32, offset : 0, strides : [?, 1]>)
-    -> memref<?xf32> {
+    -> memref<?xf32, offset : 0, strides : [?]> {
   %1 = memref.expand_shape %arg0 [[0, 1], [2]] :
     memref<?x?xf32, offset : 0, strides : [?, 1]> into
     memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]>
   %2 = memref.collapse_shape %1 [[0, 1, 2]] :
     memref<?x4x?xf32, offset : 0, strides : [?, ?, 1]> into
-    memref<?xf32>
-  return %2 : memref<?xf32>
+    memref<?xf32, offset : 0, strides : [?]>
+  return %2 : memref<?xf32, offset : 0, strides : [?]>
 }
 // CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout
 // CHECK: expand

diff  --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 83fb1f0bfd9e3..587508c698e3f 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,11 +1,14 @@
 // RUN: mlir-opt %s -tensor-bufferize -cse | FileCheck %s
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
-// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)>
-// CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>
-// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
+ // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
+ // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)>
+ // CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>
+ // CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
+ // CHECK-DAG: #[[$MAP6:.*]] = affine_map<(d0) -> (d0 * 2)>
+ // CHECK-DAG: #[[$MAP7:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 * 4 + d2)>
+ // CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
 
 // CHECK-LABEL:   func @dim(
 // CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
@@ -330,17 +333,6 @@ func.func @tensor.expand_shape_of_slice(
   return %1 : tensor<?x7x2x5xf32>
 }
 
-// CHECK-LABEL: func @tensor.expand_shape_of_slice2(
-//  CHECK-SAME:     %[[t1:.*]]: tensor<1x2xf32>
-func.func @tensor.expand_shape_of_slice2(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
-  // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP5]]>
-  %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
-  // CHECK: memref.collapse_shape %{{.*}} [
-  // CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP5]]> into memref<1xf32>
-  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
-  return %1 : tensor<1xf32>
-}
-
 // CHECK-LABEL: func @tensor.collapse_shape(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<2x?x?xf32>
 func.func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
@@ -393,3 +385,26 @@ func.func @tensor.collapse_shape_of_slice2(
   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor<87x78x68x12xi64> into tensor<87x63648xi64>
   return %1 : tensor<87x63648xi64>
 }
+
+// CHECK-LABEL: func @tensor.collapse_shape_of_slice3(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<1x2xf32>
+func.func @tensor.collapse_shape_of_slice3(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
+  // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP5]]>
+  %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
+  // CHECK: memref.collapse_shape %{{.*}} [
+  // CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP5]]> into memref<1xf32, #[[$MAP6]]>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
+  return %1 : tensor<1xf32>
+}
+
+// CHECK-LABEL:   func @tensor.collapse_shape_of_slice4(
+//  CHECK-SAME:     %[[t1:.*]]: tensor<?x2x4xf32>,
+// CHECK-SAME:      %[[OFFSET:.*]]: index) -> tensor<8xf32> {
+func.func @tensor.collapse_shape_of_slice4(%arg0: tensor<?x2x4xf32>, %offset: index, %size: index) -> tensor<8xf32> {
+  // CHECK: memref.subview %{{.*}} : memref<?x2x4xf32> to memref<4x2x1xf32, #[[$MAP7]]>
+  %0 = tensor.extract_slice %arg0[0, 0, %offset] [4, 2, 1] [1, 1, 1] : tensor<?x2x4xf32> to tensor<4x2x1xf32>
+  // CHECK: memref.collapse_shape %{{.*}} [
+  // CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, #[[$MAP7]]> into memref<8xf32, #[[$MAP8]]>
+  %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
+  return %ret: tensor<8xf32>
+}


        


More information about the Mlir-commits mailing list