[Mlir-commits] [mlir] 23f0d78 - [mlir][spirv] Add vector.fma lowering to CL.fma
Jakub Kuderski
llvmlistbot at llvm.org
Mon Aug 22 20:36:46 PDT 2022
Author: Stanley Winata
Date: 2022-08-22T23:36:07-04:00
New Revision: 23f0d7828443b1317a124a36529161a8e288fe9c
URL: https://github.com/llvm/llvm-project/commit/23f0d7828443b1317a124a36529161a8e288fe9c
DIFF: https://github.com/llvm/llvm-project/commit/23f0d7828443b1317a124a36529161a8e288fe9c.diff
LOG: [mlir][spirv] Add vector.fma lowering to CL.fma
Reviewed By: antiagainst
Patch By: raikonenfnu
Differential Revision: https://reviews.llvm.org/D132424
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 12ce868c15744..04bc41ebb0aa0 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -139,6 +139,7 @@ struct VectorExtractStridedSliceOpConvert final
}
};
+template <class SPVFMAOp>
struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
using OpConversionPattern::OpConversionPattern;
@@ -148,8 +149,8 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
Type dstType = getTypeConverter()->convertType(fmaOp.getType());
if (!dstType)
return failure();
- rewriter.replaceOpWithNewOp<spirv::GLFmaOp>(
- fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
+ rewriter.replaceOpWithNewOp<SPVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
+ adaptor.getRhs(), adaptor.getAcc());
return success();
}
};
@@ -380,9 +381,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,
- VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
- VectorInsertElementOpConvert, VectorInsertOpConvert,
- VectorReductionPattern, VectorInsertStridedSliceOpConvert,
- VectorShuffleOpConvert, VectorSplatPattern>(
- typeConverter, patterns.getContext());
+ VectorExtractStridedSliceOpConvert,
+ VectorFmaOpConvert<spirv::GLFmaOp>,
+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+ VectorInsertOpConvert, VectorReductionPattern,
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorSplatPattern>(typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index bde9b2f1f02d7..a5af59e41453e 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -16,6 +16,27 @@ func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16
// -----
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, #spv.resource_limits<>> } {
+
+// CHECK-LABEL: @cl_fma
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
+// CHECK: spv.CL.fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
+func.func @cl_fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.fma %a, %b, %c: vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @cl_fma_size1_vector
+// CHECK: spv.CL.fma %{{.+}} : f32
+func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf32>) -> vector<1xf32> {
+ %0 = vector.fma %a, %b, %c: vector<1xf32>
+ return %0 : vector<1xf32>
+}
+
+} // end module
+
+// -----
+
// CHECK-LABEL: @broadcast
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
More information about the Mlir-commits mailing list