[Mlir-commits] [mlir] [mlir][vector] Add pattern to drop unit dim from elementwise(a, b)) (PR #74817)

Cullen Rhodes llvmlistbot at llvm.org
Wed Dec 13 02:54:55 PST 2023


================
@@ -1446,6 +1446,92 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
+/// Replace:
+///   elementwise(a, b)
+/// with:
+///   sc_a = shape_cast(a)
+///   sc_b = shape_cast(b)
+///   res = elementwise(sc_a, sc_b)
+///   return shape_cast(res)
+/// for which `a` and `b` are vectors of rank > 2 and have unit leading and/or
+/// trailing dimension.
+///
+/// Ex:
+/// ```
+///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
+///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
----------------
c-rhodes wrote:

The following example, which has no `vector.shape_cast` in the input, currently crashes:
```
// foo.mlir
func.func @fold_unit_dim_add(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
  %add = arith.addi %arg0, %arg0 : vector<1x8xi32>
  return %add : vector<1x8xi32>
}
```
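For this input, the pattern would wrap the operands and the result in `vector.shape_cast` ops between `vector<1x8xi32>` and `vector<8xi32>`, and the stack trace below shows the crash happens when the first `vector.shape_cast` is created. The shape computation involved can be sketched in plain C++ as follows. `ToyVectorShape` and `dropUnitDim` are hypothetical stand-ins for `VectorType` and the pattern's logic, not MLIR APIs. The sketch assumes the pattern drops a single fixed-size (non-scalable) unit dimension, preferring the leading one, and that rank-2 inputs qualify (both examples here are rank 2).

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a VectorType: per-dimension sizes plus
// per-dimension "scalable" flags (vector<1x[4]xf32> is {1, 4} / {false, true}).
struct ToyVectorShape {
  std::vector<int64_t> sizes;
  std::vector<bool> scalable;
};

// Assumed behaviour: drop one fixed-size unit dimension, preferring the
// leading one and otherwise the trailing one. Returns std::nullopt when the
// rewrite would not apply.
std::optional<ToyVectorShape> dropUnitDim(const ToyVectorShape &type) {
  const std::size_t rank = type.sizes.size();
  if (rank < 2)
    return std::nullopt; // rank-1 vectors are left alone
  const bool leadingUnit = type.sizes.front() == 1 && !type.scalable.front();
  const bool trailingUnit = type.sizes.back() == 1 && !type.scalable.back();
  if (!leadingUnit && !trailingUnit)
    return std::nullopt; // no fixed-size unit dim to drop
  const auto dim = static_cast<std::ptrdiff_t>(leadingUnit ? 0 : rank - 1);
  ToyVectorShape result = type;
  result.sizes.erase(result.sizes.begin() + dim);
  result.scalable.erase(result.scalable.begin() + dim);
  return result;
}
```

With this sketch, `{1, 8}` becomes `{8}` (the `vector<1x8xi32>` case above) and `{1, [4]}` becomes `{[4]}`, matching the `vector<1x[4]xf32>` to `vector<[4]xf32>` example in the doc comment.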

Reproducer:
```
build/bin/mlir-opt foo.mlir -test-vector-transfer-flatten-patterns
LLVM ERROR: Building op `vector.shape_cast` but it isn't known in this MLIRContext: the dialect may not be loaded or this operation hasn't been added by the dialect. See also https://mlir.llvm.org/getting_started/Faq/#registered-loaded-dependent-whats-up-with-dialects-management
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: build/bin/mlir-opt /home/culrho01/llvm-project/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir -test-vector-transfer-flatten-patterns
 #0 0x0000ffff9b27f848 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /home/culrho01/llvm-project/llvm/lib/Support/Unix/Signals.inc:723:11
 #1 0x0000ffff9b27fd50 PrintStackTraceSignalHandler(void*) /home/culrho01/llvm-project/llvm/lib/Support/Unix/Signals.inc:798:1
 #2 0x0000ffff9b27dee4 llvm::sys::RunSignalHandlers() /home/culrho01/llvm-project/llvm/lib/Support/Signals.cpp:105:5
 #3 0x0000ffff9b2805a8 SignalHandler(int) /home/culrho01/llvm-project/llvm/lib/Support/Unix/Signals.inc:413:1
 #4 0x0000ffffb0c8678c (linux-vdso.so.1+0x78c)
 #5 0x0000ffff9aa7ed78 raise (/lib/aarch64-linux-gnu/libc.so.6+0x33d78)
 #6 0x0000ffff9aa6baac abort (/lib/aarch64-linux-gnu/libc.so.6+0x20aac)
 #7 0x0000ffff9b14f6ac llvm::report_fatal_error(llvm::Twine const&, bool) /home/culrho01/llvm-project/llvm/lib/Support/ErrorHandling.cpp:125:5
 #8 0x0000ffffa29206dc mlir::RegisteredOperationName mlir::OpBuilder::getCheckRegisteredInfo<mlir::vector::ShapeCastOp>(mlir::MLIRContext*) /home/culrho01/llvm-project/mlir/include/mlir/IR/Builders.h:485:12
 #9 0x0000ffffa294f65c mlir::vector::ShapeCastOp mlir::OpBuilder::create<mlir::vector::ShapeCastOp, mlir::VectorType&, mlir::Value&>(mlir::Location, mlir::VectorType&, mlir::Value&) /home/culrho01/llvm-project/mlir/include/mlir/IR/Builders.h:493:26
#10 0x0000ffffa2a12490 (anonymous namespace)::DropUnitDimFromElementwiseOps::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const /home/culrho01/llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp:1512:28
...
```

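The Vector dialect isn't loaded when the `-test-vector-transfer-flatten-patterns` pass runs on this input (the pass doesn't declare it as a dependent dialect), yet this pattern can introduce Vector ops, here `vector.shape_cast`, where there weren't any before.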
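In MLIR, a pass normally declares the dialects its patterns may create ops in by overriding `getDependentDialects` (for example with `registry.insert<vector::VectorDialect>()`), so that the pass manager loads them before the pass runs. Assuming the fix belongs in the test pass, the mechanism can be modelled with a toy, MLIR-free C++ sketch. `ToyContext`, `ToyPass` and `createOp` are hypothetical names, not MLIR APIs.

```cpp
#include <set>
#include <stdexcept>
#include <string>

// Toy stand-in for an MLIRContext: records which dialects are loaded.
struct ToyContext {
  std::set<std::string> loaded;
};

// Toy stand-in for a pass: the dialects its patterns may create ops in,
// analogous to what getDependentDialects declares.
struct ToyPass {
  std::set<std::string> dependentDialects;
  // Mirrors the pass manager loading dependent dialects before the pass runs.
  void loadDependentDialects(ToyContext &ctx) const {
    ctx.loaded.insert(dependentDialects.begin(), dependentDialects.end());
  }
};

// Creating an op whose dialect isn't loaded is fatal, like the
// "isn't known in this MLIRContext" error in the log above.
void createOp(const ToyContext &ctx, const std::string &opName) {
  const std::string dialect = opName.substr(0, opName.find('.'));
  if (ctx.loaded.count(dialect) == 0)
    throw std::runtime_error("Building op `" + opName +
                             "` but it isn't known in this MLIRContext");
}
```

In this model, a context that only has `func` and `arith` loaded (as when parsing `foo.mlir`) rejects `vector.shape_cast` until a pass that lists `vector` as a dependency has loaded it.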


https://github.com/llvm/llvm-project/pull/74817

