[Mlir-commits] [mlir] 3cb1f35 - [mlir][Linalg] Use subview instead of linalg.slice in Promotion.cpp

Nicolas Vasilache llvmlistbot at llvm.org
Tue Apr 7 20:56:05 PDT 2020


Author: Nicolas Vasilache
Date: 2020-04-07T23:52:31-04:00
New Revision: 3cb1f35df2a560d5a8ce2047b87cba8e3c904170

URL: https://github.com/llvm/llvm-project/commit/3cb1f35df2a560d5a8ce2047b87cba8e3c904170
DIFF: https://github.com/llvm/llvm-project/commit/3cb1f35df2a560d5a8ce2047b87cba8e3c904170.diff

LOG: [mlir][Linalg] Use subview instead of linalg.slice in Promotion.cpp

This revision removes Promotion's reliance on `linalg.slice`, which is meant
for the rank-reducing case.

Differential Revision: https://reviews.llvm.org/D77676

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
    mlir/test/Dialect/Linalg/promote.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index f393ca2f12f9..78ce78108722 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -39,17 +39,42 @@ using llvm::SetVector;
 
 using folded_affine_min = folded::ValueBuilder<AffineMinOp>;
 using folded_linalg_range = folded::ValueBuilder<linalg::RangeOp>;
+using folded_std_dim = folded::ValueBuilder<DimOp>;
+using folded_std_subview = folded::ValueBuilder<SubViewOp>;
+using folded_std_view = folded::ValueBuilder<ViewOp>;
 
 #define DEBUG_TYPE "linalg-promotion"
 
-static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) {
+/// If `size` is produced by an AffineMinOp whose map has at least one constant
+/// result, return a new ConstantIndexOp holding the smallest such constant (an
+/// upper bound on `size`). Otherwise return `size` unchanged.
+static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
+                                                 Value size) {
+  auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
+  if (!affineMinOp)
+    return size;
+  if (!llvm::any_of(affineMinOp.getAffineMap().getResults(), [](AffineExpr e) {
+        return e.dyn_cast<AffineConstantExpr>();
+      }))
+    return size;
+  int64_t minConst = std::numeric_limits<int64_t>::max();
+  for (auto e : affineMinOp.getAffineMap().getResults())
+    if (auto cst = e.dyn_cast<AffineConstantExpr>())
+      minConst = std::min(minConst, cst.getValue());
+  assert(minConst != std::numeric_limits<int64_t>::max());
+  return b.create<ConstantIndexOp>(loc, minConst);
+}
+
+static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
+                         OperationFolder *folder) {
   auto *ctx = size.getContext();
   auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
   if (!dynamicBuffers)
     if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
       return std_alloc(
           MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
-  Value mul = std_muli(std_constant_index(width), size);
+  Value mul =
+      folded_std_muli(folder, folded_std_constant_index(folder, width), size);
   return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
 }
 
@@ -80,24 +105,28 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
   auto viewType = subView.getType();
   auto rank = viewType.getRank();
   Value allocSize = one;
-  SmallVector<Value, 8> fullRanges, partialRanges;
-  fullRanges.reserve(rank);
-  partialRanges.reserve(rank);
+  SmallVector<Value, 8> fullSizes, partialSizes;
+  fullSizes.reserve(rank);
+  partialSizes.reserve(rank);
   for (auto en : llvm::enumerate(subView.getRanges())) {
     auto rank = en.index();
     auto rangeValue = en.value();
-    Value d = rangeValue.size;
-    allocSize = folded_std_muli(folder, allocSize, d).getValue();
-    fullRanges.push_back(d);
-    partialRanges.push_back(
-        folded_linalg_range(folder, zero, std_dim(subView, rank), one));
+    // Try to extract a tight constant
+    Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
+    allocSize = folded_std_muli(folder, allocSize, size).getValue();
+    fullSizes.push_back(size);
+    partialSizes.push_back(folded_std_dim(folder, subView, rank));
   }
-  SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1);
+  SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
   auto buffer =
-      allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers);
-  auto fullLocalView = std_view(
-      MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges);
-  auto partialLocalView = linalg_slice(fullLocalView, partialRanges);
+      allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers, folder);
+  auto fullLocalView = folded_std_view(
+      folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
+      fullSizes);
+  SmallVector<Value, 4> zeros(fullSizes.size(), zero);
+  SmallVector<Value, 4> ones(fullSizes.size(), one);
+  auto partialLocalView =
+      folded_std_subview(folder, fullLocalView, zeros, partialSizes, ones);
   return PromotionInfo{buffer, fullLocalView, partialLocalView};
 }
 

diff  --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index a9b860d8f28b..9c00f67f6776 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -7,7 +7,6 @@
 #map3 = affine_map<(d0) -> (d0 + 3)>
 
 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[strided2DnoOffset:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
 // CHECK-DAG: #[[strided2D_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
 
 func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
@@ -46,28 +45,28 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         %[[tmpA:.*]] = alloc() : memref<32xi8>
 //       CHECK:         %[[fullA:.*]] = std.view %[[tmpA]][][{{.*}}] : memref<32xi8> to memref<?x?xf32>
 //     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
-//       CHECK:         %[[partialA:.*]] = linalg.slice %[[fullA]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
+//       CHECK:         %[[partialA:.*]] = subview %[[fullA]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
 ///
 //       CHECK:         %[[tmpB:.*]] = alloc() : memref<48xi8>
 //       CHECK:         %[[fullB:.*]] = std.view %[[tmpB]][][{{.*}}] : memref<48xi8> to memref<?x?xf32>
 //     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
-//       CHECK:         %[[partialB:.*]] = linalg.slice %[[fullB]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
+//       CHECK:         %[[partialB:.*]] = subview %[[fullB]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
 ///
 //       CHECK:         %[[tmpC:.*]] = alloc() : memref<24xi8>
 //       CHECK:         %[[fullC:.*]] = std.view %[[tmpC]][][{{.*}}] : memref<24xi8> to memref<?x?xf32>
 //     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
-//       CHECK:         %[[partialC:.*]] = linalg.slice %[[fullC]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
+//       CHECK:         %[[partialC:.*]] = subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
 
 //       CHECK:         linalg.fill(%[[fullA]], {{.*}}) : memref<?x?xf32>, f32
 //       CHECK:         linalg.fill(%[[fullB]], {{.*}}) : memref<?x?xf32>, f32
 //       CHECK:         linalg.fill(%[[fullC]], {{.*}}) : memref<?x?xf32>, f32
-//       CHECK:         linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
-//       CHECK:         linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
-//       CHECK:         linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
+//       CHECK:         linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
+//       CHECK:         linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
+//       CHECK:         linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
 //
 //       CHECK:         linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
 //
-//       CHECK:         linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2DnoOffset]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
+//       CHECK:         linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
 //
 //       CHECK:         dealloc %[[tmpA]] : memref<32xi8>
 //       CHECK:         dealloc %[[tmpB]] : memref<48xi8>
@@ -111,28 +110,28 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         %[[tmpA_f64:.*]] = alloc() : memref<64xi8>
 //       CHECK:         %[[fullA_f64:.*]] = std.view %[[tmpA_f64]][][{{.*}}] : memref<64xi8> to memref<?x?xf64>
 //     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
-//       CHECK:         %[[partialA_f64:.*]] = linalg.slice %[[fullA_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
+//       CHECK:         %[[partialA_f64:.*]] = subview %[[fullA_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
 ///
 //       CHECK:         %[[tmpB_f64:.*]] = alloc() : memref<96xi8>
 //       CHECK:         %[[fullB_f64:.*]] = std.view %[[tmpB_f64]][][{{.*}}] : memref<96xi8> to memref<?x?xf64>
 //     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
-//       CHECK:         %[[partialB_f64:.*]] = linalg.slice %[[fullB_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
+//       CHECK:         %[[partialB_f64:.*]] = subview %[[fullB_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
 ///
 //       CHECK:         %[[tmpC_f64:.*]] = alloc() : memref<48xi8>
 //       CHECK:         %[[fullC_f64:.*]] = std.view %[[tmpC_f64]][][{{.*}}] : memref<48xi8> to memref<?x?xf64>
 //     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
-//       CHECK:         %[[partialC_f64:.*]] = linalg.slice %[[fullC_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
+//       CHECK:         %[[partialC_f64:.*]] = subview %[[fullC_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
 
 //       CHECK:         linalg.fill(%[[fullA_f64]], {{.*}}) : memref<?x?xf64>, f64
 //       CHECK:         linalg.fill(%[[fullB_f64]], {{.*}}) : memref<?x?xf64>, f64
 //       CHECK:         linalg.fill(%[[fullC_f64]], {{.*}}) : memref<?x?xf64>, f64
-//       CHECK:         linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
-//       CHECK:         linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
-//       CHECK:         linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
+//       CHECK:         linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
+//       CHECK:         linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
+//       CHECK:         linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
 //
 //       CHECK:         linalg.matmul(%[[fullA_f64]], %[[fullB_f64]], %[[fullC_f64]]) : memref<?x?xf64>, memref<?x?xf64>, memref<?x?xf64>
 //
-//       CHECK:         linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref<?x?xf64, #[[strided2DnoOffset]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
+//       CHECK:         linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
 //
 //       CHECK:         dealloc %[[tmpA_f64]] : memref<64xi8>
 //       CHECK:         dealloc %[[tmpB_f64]] : memref<96xi8>
@@ -176,28 +175,28 @@ func @matmul_i32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         %[[tmpA_i32:.*]] = alloc() : memref<32xi8>
 //       CHECK:         %[[fullA_i32:.*]] = std.view %[[tmpA_i32]][][{{.*}}] : memref<32xi8> to memref<?x?xi32>
 //     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
-//       CHECK:         %[[partialA_i32:.*]] = linalg.slice %[[fullA_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
+//       CHECK:         %[[partialA_i32:.*]] = subview %[[fullA_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
 ///
 //       CHECK:         %[[tmpB_i32:.*]] = alloc() : memref<48xi8>
 //       CHECK:         %[[fullB_i32:.*]] = std.view %[[tmpB_i32]][][{{.*}}] : memref<48xi8> to memref<?x?xi32>
 //     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
-//       CHECK:         %[[partialB_i32:.*]] = linalg.slice %[[fullB_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
+//       CHECK:         %[[partialB_i32:.*]] = subview %[[fullB_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
 ///
 //       CHECK:         %[[tmpC_i32:.*]] = alloc() : memref<24xi8>
 //       CHECK:         %[[fullC_i32:.*]] = std.view %[[tmpC_i32]][][{{.*}}] : memref<24xi8> to memref<?x?xi32>
 //     DYNAMIC:         std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
-//       CHECK:         %[[partialC_i32:.*]] = linalg.slice %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
+//       CHECK:         %[[partialC_i32:.*]] = subview %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
 
 //       CHECK:         linalg.fill(%[[fullA_i32]], {{.*}}) : memref<?x?xi32>, i32
 //       CHECK:         linalg.fill(%[[fullB_i32]], {{.*}}) : memref<?x?xi32>, i32
 //       CHECK:         linalg.fill(%[[fullC_i32]], {{.*}}) : memref<?x?xi32>, i32
-//       CHECK:         linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
-//       CHECK:         linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
-//       CHECK:         linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
+//       CHECK:         linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
+//       CHECK:         linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
+//       CHECK:         linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
 //
 //       CHECK:         linalg.matmul(%[[fullA_i32]], %[[fullB_i32]], %[[fullC_i32]]) : memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>
 //
-//       CHECK:         linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref<?x?xi32, #[[strided2DnoOffset]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
+//       CHECK:         linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
 //
 //       CHECK:         dealloc %[[tmpA_i32]] : memref<32xi8>
 //       CHECK:         dealloc %[[tmpB_i32]] : memref<48xi8>


        


More information about the Mlir-commits mailing list