[PATCH] D136483: [mlir][MemRefToLLVM] Reuse existing lowering for collapse/expand_shape

Quentin Colombet via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 4 19:08:28 PDT 2022


qcolombet added a comment.
Herald added a subscriber: Moerafaat.

Reporting on this:

> Do we expect LLVM's CSE to simplify this? Is there any indication of this actually happening or not?

Yes, I confirmed that CSE is happening just fine.

Here is what I did:
Old `mlir-opt`:

  mlir-opt -convert-memref-to-llvm -lower-affine -convert-arith-to-llvm  -convert-func-to-llvm -reconcile-unrealized-casts <input>.mlir -o  <output>.mlir
  mlir-translate -mlir-to-llvmir  <output>.mlir -o - | opt -S -early-cse -o old-static.ll

New `mlir-opt`, i.e., with this patch:

  # Run the expand pass first (right now it is called simplify-extract-strided-metadata)
  mlir-opt -simplify-extract-strided-metadata -convert-memref-to-llvm -lower-affine -convert-arith-to-llvm  -convert-func-to-llvm -reconcile-unrealized-casts <input>.mlir -o  <output>.mlir
  mlir-translate -mlir-to-llvmir  <output>.mlir -o - | opt -S -early-cse -o new-static.ll

Result: the IR is semantically equivalent and equally performant in both cases. The only difference is the extract_strided_metadata descriptor, which stays around if we don't run DCE.
E.g., with the function `collapse_shape_static` from `memref-to-llvm.mlir`:

  --- old-static-cse.ll   2022-11-05 01:35:30.604898681 +0000
  +++ new-static-cse.ll   2022-11-05 01:35:24.384293356 +0000
  @@ -1,36 +1,38 @@
   ; ModuleID = '<stdin>'
   source_filename = "LLVMDialectModule"
   
   declare ptr @malloc(i64)
   
   declare void @free(ptr)
   
   define { ptr, ptr, i64, [3 x i64], [3 x i64] } @collapse_shape_static(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, i64 %7, i64 %8, i64 %9, i64 %10, i64 %11, i64 %12) {
     %14 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } undef, ptr %0, 0
     %15 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %14, ptr %1, 1
     %16 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %15, i64 %2, 2
     %17 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %16, i64 %3, 3, 0
     %18 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %17, i64 %8, 4, 0
     %19 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %18, i64 %4, 3, 1
     %20 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %19, i64 %9, 4, 1
     %21 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %20, i64 %5, 3, 2
     %22 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %21, i64 %10, 4, 2
     %23 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %22, i64 %6, 3, 3
     %24 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %23, i64 %11, 4, 3
     %25 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %24, i64 %7, 3, 4
     %26 = insertvalue { ptr, ptr, i64, [5 x i64], [5 x i64] } %25, i64 %12, 4, 4
  -  %27 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } undef, ptr %0, 0
  -  %28 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %27, ptr %1, 1
  -  %29 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %28, i64 %2, 2
  -  %30 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %29, i64 3, 3, 0
  -  %31 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %30, i64 4, 3, 1
  -  %32 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %31, i64 5, 3, 2
  +  %27 = insertvalue { ptr, ptr, i64 } undef, ptr %0, 0
  +  %28 = insertvalue { ptr, ptr, i64 } %27, ptr %1, 1
  +  %29 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } undef, ptr %0, 0
  +  %30 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %29, ptr %1, 1
  +  %31 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %30, i64 0, 2
  +  %32 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %31, i64 3, 3, 0
     %33 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %32, i64 20, 4, 0
  -  %34 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %33, i64 5, 4, 1
  -  %35 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %34, i64 1, 4, 2
  -  ret { ptr, ptr, i64, [3 x i64], [3 x i64] } %35
  +  %34 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %33, i64 4, 3, 1
  +  %35 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %34, i64 5, 4, 1
  +  %36 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %35, i64 5, 3, 2
  +  %37 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %36, i64 1, 4, 2
  +  ret { ptr, ptr, i64, [3 x i64], [3 x i64] } %37
   }
   
   !llvm.module.flags = !{!0}
   
   !0 = !{i32 2, !"Debug Info Version", i32 3}

And it gets more evident that things are the same if you run `instcombine` instead of `cse`:

  --- old-static-instcombine.ll   2022-11-05 02:07:09.937594426 +0000
  +++ new-static-instcombine.ll   2022-11-05 02:07:17.966374824 +0000
  @@ -1,23 +1,23 @@
   ; ModuleID = '<stdin>'
   source_filename = "LLVMDialectModule"
   
   declare ptr @malloc(i64)
   
   declare void @free(ptr)
   
   define { ptr, ptr, i64, [3 x i64], [3 x i64] } @collapse_shape_static(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, i64 %5, i64 %6, i64 %7, i64 %8, i64 %9, i64 %10, i64 %11, i64 %12) {
     %14 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } undef, ptr %0, 0
     %15 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %14, ptr %1, 1
  -  %16 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %15, i64 %2, 2
  +  %16 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %15, i64 0, 2
     %17 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %16, i64 3, 3, 0
  -  %18 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %17, i64 4, 3, 1
  -  %19 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %18, i64 5, 3, 2
  -  %20 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %19, i64 20, 4, 0
  -  %21 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %20, i64 5, 4, 1
  +  %18 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %17, i64 20, 4, 0
  +  %19 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %18, i64 4, 3, 1
  +  %20 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %19, i64 5, 4, 1
  +  %21 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %20, i64 5, 3, 2
     %22 = insertvalue { ptr, ptr, i64, [3 x i64], [3 x i64] } %21, i64 1, 4, 2
     ret { ptr, ptr, i64, [3 x i64], [3 x i64] } %22
   }
   
   !llvm.module.flags = !{!0}
   
   !0 = !{i32 2, !"Debug Info Version", i32 3}


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D136483/new/

https://reviews.llvm.org/D136483



More information about the llvm-commits mailing list