[Mlir-commits] [mlir] 9a795f0 - [mlir][Vector] Adds a pattern to fold `arith.extf` into `vector.contract`

Manish Gupta llvmlistbot at llvm.org
Mon Jun 5 16:31:21 PDT 2023


Author: Manish Gupta
Date: 2023-06-05T23:22:20Z
New Revision: 9a795f0c59b1707d1f4bdb352e8805133d72d9e2

URL: https://github.com/llvm/llvm-project/commit/9a795f0c59b1707d1f4bdb352e8805133d72d9e2
DIFF: https://github.com/llvm/llvm-project/commit/9a795f0c59b1707d1f4bdb352e8805133d72d9e2.diff

LOG: [mlir][Vector] Adds a pattern to fold `arith.extf` into `vector.contract`

Consider a mixed-precision data type, i.e., F16 input lhs, F16 input rhs, F32 accumulation, and F32 output. This is typically written as F32 <= F16 * F16 + F32.

During vectorization from linalg to vector for this mixed-precision case (F32 <= F16 * F16 + F32), linalg.matmul introduces arith.extf ops on the lhs and rhs input operands.

"linalg.matmul"(%lhs, %rhs, %acc) ({
      ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
        %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
        %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
       %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
        %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
      "linalg.yield"(%acc) : (f32) -> ()
    })
There are backends that natively support mixed-precision data types and do not need the arith.extf. For example, the NVIDIA A100 GPU provides mma.sync.aligned.*.f32.f16.f16.f32, which handles mixed-precision operands directly. However, the presence of arith.extf in the IR introduces unnecessary casting and targets F32 Tensor Cores instead of F16 Tensor Cores on the NVIDIA backend. This patch adds a pattern that folds arith.extf into vector.contract.
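
At the vector level, the fold can be sketched as the before/after below (based on the added fold-arith-extf-into-vector-contract.mlir test; indexing maps, iterator types, and the combining kind are elided as {...} for brevity). Before the rewrite:

    %lhs_f32 = arith.extf %lhs : vector<64x64xf16> to vector<64x64xf32>
    %rhs_f32 = arith.extf %rhs : vector<64x64xf16> to vector<64x64xf32>
    %d = vector.contract {...} %lhs_f32, %rhs_f32, %acc
           : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>

After the rewrite, the extensions are absorbed and the contraction consumes the f16 operands directly:

    %d = vector.contract {...} %lhs, %rhs, %acc
           : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>

With the f16 operands visible to vector.contract, the convert-vector-to-gpu path (use-nvgpu=true) can select the f16 mma.sync variants, as exercised by the new fold-arith-vector-to-mma-ops-mma-sync.mlir test.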

Differential Revision: https://reviews.llvm.org/D151918

Added: 
    mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
    mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index abfacd81e88da..49a235186ecd3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -74,6 +74,10 @@ isBroadcastableTo(Type srcType, VectorType dstVectorType,
 void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit = 1);
 
+/// Collect a set of patterns that fold arithmetic extension on floating point
+/// into vector.contract ops for backends with native support.
+void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d634d6a19030d..f9e778d331ef0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1212,8 +1212,54 @@ struct CanonicalizeContractMatmulToMMT final
   FilterConstraintType filter;
 };
 
+/// Pattern to fold arithmetic extensions on floating point data types into
+/// vector contraction operations. linalg.matmul introduces arithmetic
+/// extensions on its operands. See the MLIR snippet below for more details.
+/// ```mlir
+///   "linalg.matmul"(%lhs, %rhs, %acc) ({
+///      ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
+///        %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
+///        %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
+///        %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
+///        %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
+///        "linalg.yield"(%acc) : (f32) -> ()
+///     })
+/// ```
+/// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
+/// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
+/// This pattern folds the arithmetic extensions into the vector contraction and
+/// enables the usage of native mixed precision Tensor Core instructions.
+struct FoldArithExtIntoContractionOp
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+
+    auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
+    auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
+
+    if (!lhsDefOp || !rhsDefOp) {
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "no defining op on contract operands");
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
+        contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
+        contractOp.getIteratorTypesAttr());
+
+    return success();
+  }
+};
+
 } // namespace
 
+void mlir::vector::populateFoldArithExtensionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+}
+
 void mlir::vector::populateVectorMaskMaterializationPatterns(
     RewritePatternSet &patterns, bool force32BitVectorIndices,
     PatternBenefit benefit) {

diff --git a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
new file mode 100644
index 0000000000000..0afaa19d59d15
--- /dev/null
+++ b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(test-fold-arith-extf-into-vector-contract-patterns,convert-vector-to-gpu{use-nvgpu=true},cse))" | FileCheck %s
+
+//###############################################################################################
+// FP16 input, F32 accumulation row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
+//###############################################################################################
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row
+func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16, #gpu.address_space<workgroup>>, %arg1: memref<32x64xf16, #gpu.address_space<workgroup>>, %arg2: memref<42x64xf32, #gpu.address_space<workgroup>>) {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %cst_f16 = arith.constant 0.000000e+00 : f16
+  %cst_f32 = arith.constant 0.000000e+00 : f32
+  
+  // CHECK-DAG: nvgpu.ldmatrix %arg0[%{{.*}}, %{{.*}}] {numTiles = 4 : i32, transpose = false}
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst_f16 {in_bounds = [true, true]} : memref<42x32xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
+  %A_f32 = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
+  
+
+  // CHECK-DAG: nvgpu.ldmatrix %arg1[%{{.*}}, %{{.*}}] {numTiles = 4 : i32, transpose = true}
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst_f16 {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst_f32 {in_bounds = [true, true]} : memref<42x64xf32, #gpu.address_space<workgroup>>, vector<16x16xf32>
+
+  %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+  %B0_f32 = arith.extf %B0 : vector<8x16xf16> to vector<8x16xf32>
+  %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
+  
+  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B0_f32, %C0 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
+  vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
+
+
+  %B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+  %B1_f32 = arith.extf %B1 : vector<8x16xf16> to vector<8x16xf32>
+  %C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
+
+  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+  %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B1_f32, %C1 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
+  vector.transfer_write %D1, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
+
+  return
+}

diff --git a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
new file mode 100644
index 0000000000000..79429afd8ff29
--- /dev/null
+++ b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -split-input-file -test-fold-arith-extf-into-vector-contract-patterns %s | FileCheck %s
+
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @fold_arith_extf_into_contract
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xf16>, %[[ARG1:.*]]: vector<64x64xf16>, %[[ARG2:.*]]: vector<64x64xf32>)
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
+//  CHECK-NEXT:   return %[[R]] : vector<64x64xf32>
+func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
+    %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32>
+    %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
+    %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
+    return %result : vector<64x64xf32>
+}
\ No newline at end of file

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 3b0cf2f83f198..4fbddcee574a1 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -709,6 +710,32 @@ struct TestVectorTransferTensorSlicePatterns
   }
 };
 
+struct TestFoldArithExtensionIntoVectorContractPatterns
+    : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestFoldArithExtensionIntoVectorContractPatterns)
+
+  StringRef getArgument() const final {
+    return "test-fold-arith-extf-into-vector-contract-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns that fold arithmetic extension ops into vector "
+           "contract ops";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
+                    memref::MemRefDialect, scf::SCFDialect,
+                    tensor::TensorDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateFoldArithExtensionPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -745,6 +772,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorGatherLowering>();
 
   PassRegistration<TestVectorTransferTensorSlicePatterns>();
+
+  PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 }
 } // namespace test
 } // namespace mlir


        


More information about the Mlir-commits mailing list