[Mlir-commits] [mlir] 9a795f0 - [mlir][Vector] Adds a pattern to fold `arith.extf` into `vector.contract`
Manish Gupta
llvmlistbot at llvm.org
Mon Jun 5 16:31:21 PDT 2023
Author: Manish Gupta
Date: 2023-06-05T23:22:20Z
New Revision: 9a795f0c59b1707d1f4bdb352e8805133d72d9e2
URL: https://github.com/llvm/llvm-project/commit/9a795f0c59b1707d1f4bdb352e8805133d72d9e2
DIFF: https://github.com/llvm/llvm-project/commit/9a795f0c59b1707d1f4bdb352e8805133d72d9e2.diff
LOG: [mlir][Vector] Adds a pattern to fold `arith.extf` into `vector.contract`
Consider a mixed-precision data type, i.e., F16 input lhs, F16 input rhs, F32 accumulation, and F32 output. This is typically written as F32 <= F16*F16 + F32.
During vectorization from linalg to vector for this mixed-precision case (F32 <= F16*F16 + F32), linalg.matmul introduces arith.extf ops on the input lhs and rhs operands.
"linalg.matmul"(%lhs, %rhs, %acc) ({
^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
%lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
%rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
%mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
%acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
"linalg.yield"(%acc) : (f32) -> ()
})
There are backends that natively support mixed-precision data types and do not need the arith.extf. For example, the NVIDIA A100 GPU provides mma.sync.aligned.*.f32.f16.f16.f32, which handles mixed precision directly. However, the presence of arith.extf in the IR introduces unnecessary casts and makes the NVIDIA backend target F32 Tensor Cores instead of F16 Tensor Cores. This patch adds a pattern that folds arith.extf into vector.contract.
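For reference, a minimal sketch (not part of this commit) of how the new patterns could be plugged into a pass using the greedy rewrite driver; the pass name FoldIntoContractPass is hypothetical, and the structure mirrors the test pass added below:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
// Hypothetical pass that only runs the arith.extf folding patterns.
struct FoldIntoContractPass
    : public mlir::PassWrapper<FoldIntoContractPass,
                               mlir::OperationPass<mlir::func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FoldIntoContractPass)

  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Collect the FoldArithExtIntoContractionOp pattern added by this commit.
    mlir::vector::populateFoldArithExtensionPatterns(patterns);
    // Apply the patterns greedily, as the test pass below does.
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                        std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace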
Differential Revision: https://reviews.llvm.org/D151918
Added:
mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index abfacd81e88da..49a235186ecd3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -74,6 +74,10 @@ isBroadcastableTo(Type srcType, VectorType dstVectorType,
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Collect a set of patterns that fold arithmetic extension on floating point
+/// into vector contract for the backends with native support.
+void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);
+
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index d634d6a19030d..f9e778d331ef0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1212,8 +1212,54 @@ struct CanonicalizeContractMatmulToMMT final
FilterConstraintType filter;
};
+/// Pattern to fold arithmetic extensions on floating point data types into
+/// vector contraction operations. linalg.matmul introduces arithmetic
+/// extensions on its operands. See the MLIR snippet below for more details.
+/// ```mlir
+/// "linalg.matmul"(%lhs, %rhs, %acc) ({
+/// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
+/// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
+/// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
+/// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
+/// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
+/// "linalg.yield"(%acc) : (f32) -> ()
+/// })
+/// ```
+/// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
+/// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
+/// This pattern folds the arithmetic extensions into the vector contraction and
+/// enables the usage of native mixed precision Tensor Core instructions.
+struct FoldArithExtIntoContractionOp
+ : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+
+ auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
+ auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
+
+ if (!lhsDefOp || !rhsDefOp) {
+ return rewriter.notifyMatchFailure(contractOp,
+ "no defining op on contract operands");
+ }
+
+ rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+ contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
+ contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
+ contractOp.getIteratorTypesAttr());
+
+ return success();
+ }
+};
+
} // namespace
+void mlir::vector::populateFoldArithExtensionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+}
+
void mlir::vector::populateVectorMaskMaterializationPatterns(
RewritePatternSet &patterns, bool force32BitVectorIndices,
PatternBenefit benefit) {
diff --git a/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
new file mode 100644
index 0000000000000..0afaa19d59d15
--- /dev/null
+++ b/mlir/test/Conversion/VectorToGPU/fold-arith-vector-to-mma-ops-mma-sync.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(test-fold-arith-extf-into-vector-contract-patterns,convert-vector-to-gpu{use-nvgpu=true},cse))" | FileCheck %s
+
+//###############################################################################################
+// FP16 input, F32 accumulation row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
+//###############################################################################################
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row
+func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16, #gpu.address_space<workgroup>>, %arg1: memref<32x64xf16, #gpu.address_space<workgroup>>, %arg2: memref<42x64xf32, #gpu.address_space<workgroup>>) {
+ %c0 = arith.constant 0 : index
+ %c8 = arith.constant 8 : index
+ %cst_f16 = arith.constant 0.000000e+00 : f16
+ %cst_f32 = arith.constant 0.000000e+00 : f32
+
+ // CHECK-DAG: nvgpu.ldmatrix %arg0[%{{.*}}, %{{.*}}] {numTiles = 4 : i32, transpose = false}
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst_f16 {in_bounds = [true, true]} : memref<42x32xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
+ %A_f32 = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
+
+
+ // CHECK-DAG: nvgpu.ldmatrix %arg1[%{{.*}}, %{{.*}}] {numTiles = 4 : i32, transpose = true}
+ %B = vector.transfer_read %arg1[%c0, %c0], %cst_f16 {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
+ %C = vector.transfer_read %arg2[%c0, %c0], %cst_f32 {in_bounds = [true, true]} : memref<42x64xf32, #gpu.address_space<workgroup>>, vector<16x16xf32>
+
+ %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+ %B0_f32 = arith.extf %B0 : vector<8x16xf16> to vector<8x16xf32>
+ %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
+
+ // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+ %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B0_f32, %C0 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
+ vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
+
+
+ %B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
+ %B1_f32 = arith.extf %B1 : vector<8x16xf16> to vector<8x16xf32>
+ %C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
+
+ // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
+ %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B1_f32, %C1 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
+ vector.transfer_write %D1, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>
+
+ return
+}
diff --git a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
new file mode 100644
index 0000000000000..79429afd8ff29
--- /dev/null
+++ b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -split-input-file -test-fold-arith-extf-into-vector-contract-patterns %s | FileCheck %s
+
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL: func.func @fold_arith_extf_into_contract
+// CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xf16>, %[[ARG1:.*]]: vector<64x64xf16>, %[[ARG2:.*]]: vector<64x64xf32>)
+// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xf16>, vector<64x64xf16> into vector<64x64xf32>
+// CHECK-NEXT: return %[[R]] : vector<64x64xf32>
+func.func @fold_arith_extf_into_contract(%arg0: vector<64x64xf16>, %arg1: vector<64x64xf16>, %arg2: vector<64x64xf32>) -> vector<64x64xf32> {
+ %lhs_f32 = arith.extf %arg0 : vector<64x64xf16> to vector<64x64xf32>
+ %rhs_f32 = arith.extf %arg1 : vector<64x64xf16> to vector<64x64xf32>
+ %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_f32, %rhs_f32, %arg2 : vector<64x64xf32>, vector<64x64xf32> into vector<64x64xf32>
+ return %result : vector<64x64xf32>
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 3b0cf2f83f198..4fbddcee574a1 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -709,6 +710,32 @@ struct TestVectorTransferTensorSlicePatterns
}
};
+struct TestFoldArithExtensionIntoVectorContractPatterns
+ : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestFoldArithExtensionIntoVectorContractPatterns)
+
+ StringRef getArgument() const final {
+ return "test-fold-arith-extf-into-vector-contract-patterns";
+ }
+ StringRef getDescription() const final {
+ return "Test patterns that fold arithmetic extension ops into vector "
+ "contract ops";
+ }
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, func::FuncDialect, nvgpu::NVGPUDialect,
+ memref::MemRefDialect, scf::SCFDialect,
+ tensor::TensorDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateFoldArithExtensionPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
} // namespace
namespace mlir {
@@ -745,6 +772,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorGatherLowering>();
PassRegistration<TestVectorTransferTensorSlicePatterns>();
+
+ PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
}
} // namespace test
} // namespace mlir