[all-commits] [llvm/llvm-project] 9a795f: [mlir][Vector] Adds a pattern to fold `arith.extf`...

Manish Gupta via All-commits all-commits at lists.llvm.org
Mon Jun 5 16:31:28 PDT 2023


  Branch: refs/heads/main
  Home:   https://github.com/llvm/llvm-project
  Commit: 9a795f0c59b1707d1f4bdb352e8805133d72d9e2
      https://github.com/llvm/llvm-project/commit/9a795f0c59b1707d1f4bdb352e8805133d72d9e2
  Author: Manish Gupta <manigupta at google.com>
  Date:   2023-06-05 (Mon, 05 Jun 2023)

  Changed paths:
    M mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    M mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    A mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
    A mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
    M mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

  Log Message:
  -----------
  [mlir][Vector] Adds a pattern to fold `arith.extf` into `vector.contract`

Consider a mixed-precision data type, i.e., F16 lhs input, F16 rhs input, F32 accumulation, and F32 output. This is typically written as F32 <= F16*F16 + F32.

During vectorization from linalg to vector for a mixed-precision data type (F32 <= F16*F16 + F32), linalg.matmul introduces arith.extf on the lhs and rhs input operands.

"linalg.matmul"(%lhs, %rhs, %acc) ({
  ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
    %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
    %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
    %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
    %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
    "linalg.yield"(%acc) : (f32) -> ()
})
Some backends natively support mixed-precision data types and do not need the arith.extf. For example, the NVIDIA A100 GPU has mma.sync.aligned.*.f32.f16.f16.f32, which supports the mixed-precision data type. However, the presence of arith.extf in the IR introduces unnecessary casts and causes the NVIDIA backend to target F32 Tensor Cores instead of F16 Tensor Cores. This patch adds a pattern to fold arith.extf into vector.contract.
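
As an illustration of the fold, the sketch below shows a contraction before and after the pattern applies. The function names, the 16x16 shapes, and the #map_* aliases are assumptions chosen for this example and are not taken from the patch; the tests listed under "Changed paths" remain the authoritative reference.

```mlir
// Hypothetical indexing maps for a plain (m, n, k) matmul.
#map_lhs = affine_map<(m, n, k) -> (m, k)>
#map_rhs = affine_map<(m, n, k) -> (k, n)>
#map_acc = affine_map<(m, n, k) -> (m, n)>

// Before: both f16 operands are widened to f32 ahead of the contraction.
func.func @before(%lhs: vector<16x16xf16>, %rhs: vector<16x16xf16>,
                  %acc: vector<16x16xf32>) -> vector<16x16xf32> {
  %lhs_f32 = arith.extf %lhs : vector<16x16xf16> to vector<16x16xf32>
  %rhs_f32 = arith.extf %rhs : vector<16x16xf16> to vector<16x16xf32>
  %res = vector.contract {
           indexing_maps = [#map_lhs, #map_rhs, #map_acc],
           iterator_types = ["parallel", "parallel", "reduction"],
           kind = #vector.kind<add>}
         %lhs_f32, %rhs_f32, %acc
         : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
  return %res : vector<16x16xf32>
}

// After: the arith.extf ops are folded away and vector.contract consumes
// the f16 operands directly while still accumulating in f32.
func.func @after(%lhs: vector<16x16xf16>, %rhs: vector<16x16xf16>,
                 %acc: vector<16x16xf32>) -> vector<16x16xf32> {
  %res = vector.contract {
           indexing_maps = [#map_lhs, #map_rhs, #map_acc],
           iterator_types = ["parallel", "parallel", "reduction"],
           kind = #vector.kind<add>}
         %lhs, %rhs, %acc
         : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
  return %res : vector<16x16xf32>
}
```

In the second function the operand types of vector.contract already match the F32 <= F16*F16 + F32 form, so a later lowering is free to pick an instruction that takes F16 inputs and accumulates in F32.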

Differential Revision: https://reviews.llvm.org/D151918



