[Mlir-commits] [mlir] 23f0d78 - [mlir][spirv] Add vector.fma lowering to CL.fma

Jakub Kuderski llvmlistbot at llvm.org
Mon Aug 22 20:36:46 PDT 2022


Author: Stanley Winata
Date: 2022-08-22T23:36:07-04:00
New Revision: 23f0d7828443b1317a124a36529161a8e288fe9c

URL: https://github.com/llvm/llvm-project/commit/23f0d7828443b1317a124a36529161a8e288fe9c
DIFF: https://github.com/llvm/llvm-project/commit/23f0d7828443b1317a124a36529161a8e288fe9c.diff

LOG: [mlir][spirv] Add vector.fma lowering to CL.fma

Reviewed By: antiagainst
Patch By: raikonenfnu

Differential Revision: https://reviews.llvm.org/D132424

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 12ce868c15744..04bc41ebb0aa0 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -139,6 +139,7 @@ struct VectorExtractStridedSliceOpConvert final
   }
 };
 
+template <class SPVFMAOp>
 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -148,8 +149,8 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
     Type dstType = getTypeConverter()->convertType(fmaOp.getType());
     if (!dstType)
       return failure();
-    rewriter.replaceOpWithNewOp<spirv::GLFmaOp>(
-        fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
+    rewriter.replaceOpWithNewOp<SPVFMAOp>(fmaOp, dstType, adaptor.getLhs(),
+                                          adaptor.getRhs(), adaptor.getAcc());
     return success();
   }
 };
@@ -380,9 +381,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
   patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
                VectorExtractElementOpConvert, VectorExtractOpConvert,
-               VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
-               VectorInsertElementOpConvert, VectorInsertOpConvert,
-               VectorReductionPattern, VectorInsertStridedSliceOpConvert,
-               VectorShuffleOpConvert, VectorSplatPattern>(
-      typeConverter, patterns.getContext());
+               VectorExtractStridedSliceOpConvert,
+               VectorFmaOpConvert<spirv::GLFmaOp>,
+               VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+               VectorInsertOpConvert, VectorReductionPattern,
+               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+               VectorSplatPattern>(typeConverter, patterns.getContext());
 }

diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index bde9b2f1f02d7..a5af59e41453e 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -16,6 +16,27 @@ func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16
 
 // -----
 
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Kernel], []>, #spv.resource_limits<>> } {
+
+// CHECK-LABEL: @cl_fma
+//  CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
+//       CHECK:   spv.CL.fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
+func.func @cl_fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.fma %a, %b, %c: vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @cl_fma_size1_vector
+//       CHECK:   spv.CL.fma %{{.+}} : f32
+func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf32>) -> vector<1xf32> {
+  %0 = vector.fma %a, %b, %c: vector<1xf32>
+  return %0 : vector<1xf32>
+}
+
+} // end module
+
+// -----
+
 // CHECK-LABEL: @broadcast
 //  CHECK-SAME: %[[A:.*]]: f32
 //       CHECK:   spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]


        


More information about the Mlir-commits mailing list