[Mlir-commits] [mlir] 99069ab - [mlir][linalg] fix crash when promoting rank-reducing memref.subviews

Christopher Bate llvmlistbot at llvm.org
Mon Jun 6 11:08:42 PDT 2022


Author: Christopher Bate
Date: 2022-06-06T12:06:36-06:00
New Revision: 99069ab212f547088b39aa0b73e8c77f59d89b0c

URL: https://github.com/llvm/llvm-project/commit/99069ab212f547088b39aa0b73e8c77f59d89b0c
DIFF: https://github.com/llvm/llvm-project/commit/99069ab212f547088b39aa0b73e8c77f59d89b0c.diff

LOG: [mlir][linalg] fix crash when promoting rank-reducing memref.subviews

This change adds support for promoting `linalg` operation operands that
are produced by rank-reducing `memref.subview` ops.

Differential Revision: https://reviews.llvm.org/D127086

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
    mlir/test/Dialect/Linalg/promote.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 61241a3026720..ed7d32bd6f8fc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -219,7 +220,11 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
   SmallVector<OpFoldResult> partialSizes;
   fullSizes.reserve(rank);
   partialSizes.reserve(rank);
+  llvm::SmallBitVector droppedDims = subView.getDroppedDims();
+  int64_t resultDimIdx = 0;
   for (const auto &en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
+    if (droppedDims[en.index()])
+      continue;
     auto rangeValue = en.value();
     // Try to extract a tight constant.
     LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
@@ -232,7 +237,7 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
     fullSizes.push_back(size);
     partialSizes.push_back(
-        b.createOrFold<memref::DimOp>(loc, subView, en.index()));
+        b.createOrFold<memref::DimOp>(loc, subView, resultDimIdx++));
   }
   SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
   // If a callback is not specified, then use the default implementation for

diff  --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index 1f883be0c6e91..7245650d911ff 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -linalg-promote-subviews | FileCheck %s
-// RUN: mlir-opt %s -linalg-promote-subviews="test-promote-dynamic" | FileCheck %s --check-prefix=DYNAMIC
-// RUN: mlir-opt %s -linalg-promote-subviews="test-use-alloca" | FileCheck %s --check-prefix=ALLOCA
+// RUN: mlir-opt %s -linalg-promote-subviews -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-promote-subviews="test-promote-dynamic" -split-input-file | FileCheck %s --check-prefix=DYNAMIC
+// RUN: mlir-opt %s -linalg-promote-subviews="test-use-alloca" -split-input-file | FileCheck %s --check-prefix=ALLOCA
 
 #map1 = affine_map<(d0) -> (d0 + 2)>
 #map2 = affine_map<(d0) -> (d0 + 4)>
@@ -145,3 +145,46 @@ func.func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //       CHECK:         memref.dealloc %[[tmpA_f64]] : memref<64xi8>
 //       CHECK:         memref.dealloc %[[tmpB_f64]] : memref<96xi8>
 //       CHECK:         memref.dealloc %[[tmpC_f64]] : memref<48xi8>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
+#map2 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
+#map5 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#map6 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map7 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map8 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: promote_rank_reducing_subviews([[arg0:%.+]]: memref<{{.*}}>, [[arg1:%.+]]: memref<{{.*}}>, [[arg2:%.+]]: memref<{{.*}}>, [[lb1:%.+]]: index, [[lb2:%.+]]: index, [[lb3:%.+]]: index, [[lb4:%.+]]: index, [[lb5:%.+]]: index, [[lb6:%.+]]: index, [[ub1:%.+]]: index, [[ub2:%.+]]: index
+func.func @promote_rank_reducing_subviews(%arg0:  memref<?x?x?x64xf32, #map0>, %arg1: memref<128x3x3x64xf32, #map0>, %arg2: memref<?x?x?x128xf32>,
+                                          %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %ub1: index, %ub2: index) {
+  %13 = memref.subview %arg0[%arg3, 0, %arg4, %arg8] [1, 1, %ub1, 32] [1, 1, 1, 1] : memref<?x?x?x64xf32, #map0> to memref<?x32xf32, #map5>
+  %14 = memref.subview %arg1[0, %arg6, %arg7, %arg8] [128, 1, 1, 32] [1, 1, 1, 1] : memref<128x3x3x64xf32, #map0> to memref<128x32xf32, #map5>
+  %9 = memref.subview %arg2[%arg3, %arg4, %arg5, 0] [1, 1, %ub2, 128] [1, 1, 1, 1] : memref<?x?x?x128xf32> to memref<?x128xf32, #map2>
+
+  // CHECK: [[a_alloc:%.+]] = memref.alloc
+  // CHECK: [[a_view:%.+]] = memref.view [[a_alloc]]{{.*}}                 
+  // CHECK: [[a_pro_subview:%.+]] = memref.subview [[a_view]][0, 0] [[[ub1]], {{%.+}}] [1, 1]
+
+  // CHECK: memref.alloc
+  // CHECK: [[b_view:%.+]] = memref.view
+  // CHECK: [[b_pro_subview:%.+]] = memref.subview [[b_view]]
+
+  // CHECK: memref.alloc
+  // CHECK: [[c_view:%.+]] = memref.view
+  // CHECK: [[c_pro_subview:%.+]] = memref.subview [[c_view]]
+
+  // CHECK-COUNT-3: memref.copy
+  // CHECK: linalg.generic
+  // CHECK-SAME: ins([[a_pro_subview]], [[b_pro_subview]]
+  // CHECK-SAME: outs([[c_pro_subview]]
+
+  linalg.generic {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : memref<?x32xf32, #map5>, memref<128x32xf32, #map5>) outs(%9 : memref<?x128xf32, #map2>) {
+  ^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
+    %15 = arith.mulf %arg9, %arg10 : f32
+    %16 = arith.addf %arg11, %15 : f32
+    linalg.yield %16 : f32
+  }
+
+  return
+}


        


More information about the Mlir-commits mailing list