[Mlir-commits] [mlir] 2ae37be - Allow empty dimension arrays in `linalg::inferContractionDims` (#69496)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 19 10:13:32 PDT 2023


Author: bjacob
Date: 2023-10-19T13:13:27-04:00
New Revision: 2ae37be4b433632e46209aa04fcc857675783f81

URL: https://github.com/llvm/llvm-project/commit/2ae37be4b433632e46209aa04fcc857675783f81
DIFF: https://github.com/llvm/llvm-project/commit/2ae37be4b433632e46209aa04fcc857675783f81.diff

LOG: Allow empty dimension arrays in `linalg::inferContractionDims` (#69496)

This function was returning failure when any of the intersection sets
(`ac`, `bc` or `ra`) was empty, but this is actually legitimate in
"matrix times vector" cases, where some of the operands have lower
dimensionality, implying unit-dimension semantics for the "missing"
dimensions.

Example:

```mlir
func.func @transpose_extend_batch_matmul(
    %vec: tensor<32x128xi16>,
    %mat: tensor<11008x32x128xi4>) -> tensor<11008x32xi32> {
  %c0_i32 = arith.constant 0 : i32
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<11008x32xi32>
  %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
  %2 = tensor.empty() : tensor<11008xf32>
  %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<11008xf32>) -> tensor<11008xf32>
  %batch_matmul_result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, 
                                                          affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 
                                                          affine_map<(d0, d1, d2) -> (d0, d1)>], 
                                         iterator_types = ["parallel", "parallel", "reduction"]} 
                                         ins(%vec, %mat : tensor<32x128xi16>, tensor<11008x32x128xi4>) 
                                         outs(%1 : tensor<11008x32xi32>) {
  ^bb0(%in: i16, %in_3: i4, %out: i32):
      %19 = arith.extsi %in : i16 to i32
      %20 = arith.extui %in_3 : i4 to i32
      %21 = arith.muli %19, %20 : i32
      %22 = arith.addi %21, %out : i32
      linalg.yield %22 : i32
  } -> tensor<11008x32xi32>
  return %batch_matmul_result : tensor<11008x32xi32>
}
```

Here, we were returning failure because `ac` (the set that yields the `m`
dimensions) is empty. With this PR, we return this useful information
instead:

```
batch: [ 1 ]
m: [ ]
n: [ 0 ]
k: [ 2 ]
```

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/test/Dialect/Linalg/match-ops-interpreter.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ea50e1232a4c74a..5fde8d71cac3e75 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -227,9 +227,6 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
       linalgOp, linalgOp.getDpsInputOperand(1), red);
   llvm::set_intersect(ra, rb);
 
-  if (ac.empty() || bc.empty() || ra.empty())
-    return failure();
-
   // Return each set in sorted order.
   ContractionDimensions dimensions{
       SmallVector<unsigned, 2>(batches.begin(), batches.end()),

diff  --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index bad6893eaa99e1e..1da092ab42ad779 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -910,6 +910,19 @@ module attributes { transform.target_tag = "start_here" } {
     return %result : tensor<10x15xf64>
   }
 
+  func.func @vecmat_simple(%lhs: tensor<20xf32>, %rhs: tensor<20x15xf32>) -> tensor<15xf64> {
+    %cst = arith.constant 0.0 : f64
+    %empty = tensor.empty() : tensor<15xf64>
+    %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<15xf64>) -> tensor<15xf64>
+    // expected-remark @below {{contraction}}
+    // expected-remark @below {{batch dims}}
+    // expected-remark @below {{m dims}}
+    // expected-remark @below {{n dims 0}}
+    // expected-remark @below {{k dims 1}}
+    %result = linalg.vecmat ins(%lhs, %rhs: tensor<20xf32>, tensor<20x15xf32>) outs(%fill: tensor<15xf64>) -> tensor<15xf64>
+    return %result : tensor<15xf64>
+  }
+
   func.func @double_batch(%lhs: tensor<40x10x50x20xf32>, %rhs: tensor<40x20x50x15xf32>) -> tensor<40x10x50x15xf32> {
     %cst = arith.constant 0.0 : f32
     %empty = tensor.empty() : tensor<40x10x50x15xf32>


        


More information about the Mlir-commits mailing list