[Mlir-commits] [mlir] a4360ef - [mlir][spirv] Convert single element vector.splat/fma

Lei Zhang llvmlistbot at llvm.org
Mon Jun 13 09:18:33 PDT 2022


Author: Lei Zhang
Date: 2022-06-13T12:18:16-04:00
New Revision: a4360efb2c59f0b80c5838e8b36e8d0490ba6f65

URL: https://github.com/llvm/llvm-project/commit/a4360efb2c59f0b80c5838e8b36e8d0490ba6f65
DIFF: https://github.com/llvm/llvm-project/commit/a4360efb2c59f0b80c5838e8b36e8d0490ba6f65.diff

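LOG: [mlir][spirv] Convert single-element vector.splat/fma

The vector.fma and vector.splat patterns now query the type converter instead
of checking spirv::CompositeType::isValid. Single-element vectors (e.g.
vector<1xf32>) convert to scalars, so they are now lowered too: fma becomes a
scalar spv.GLSL.Fma, and splat is replaced by its scalar input.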

Reviewed By: ThomasRaoux, hanchung

Differential Revision: https://reviews.llvm.org/D127572

Added:
    (none)

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Removed:
    (none)


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index e404d75f7d243..cb74987024138 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -145,11 +145,11 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
   LogicalResult
   matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
+    Type dstType = getTypeConverter()->convertType(fmaOp.getType());
+    if (!dstType)
       return failure();
     rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
-        fmaOp, fmaOp.getType(), adaptor.getLhs(), adaptor.getRhs(),
-        adaptor.getAcc());
+        fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
     return success();
   }
 };
@@ -321,13 +321,18 @@ class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
   LogicalResult
   matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType dstVecType = op.getType();
-    if (!spirv::CompositeType::isValid(dstVecType))
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
       return failure();
-    SmallVector<Value, 4> source(dstVecType.getNumElements(),
-                                 adaptor.getInput());
-    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
-                                                             source);
+    if (dstType.isa<spirv::ScalarType>()) {
+      rewriter.replaceOp(op, adaptor.getInput());
+    } else {
+      auto dstVecType = dstType.cast<VectorType>();
+      SmallVector<Value, 4> source(dstVecType.getNumElements(),
+                                   adaptor.getInput());
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
+                                                               source);
+    }
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8a3f71b380c4c..76ce449c51bde 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -171,6 +171,15 @@ func.func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vecto
 
 // -----
 
+// CHECK-LABEL: @fma_size1_vector
+//       CHECK:   spv.GLSL.Fma %{{.+}} : f32
+func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf32>) -> vector<1xf32> {
+  %0 = vector.fma %a, %b, %c: vector<1xf32>
+  return %0 : vector<1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @splat
 //  CHECK-SAME: (%[[A:.+]]: f32)
 //       CHECK:   %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
@@ -182,6 +191,17 @@ func.func @splat(%f : f32) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @splat_size1_vector
+//  CHECK-SAME: (%[[A:.+]]: f32)
+//       CHECK:   %[[VAL:.+]] = builtin.unrealized_conversion_cast %[[A]]
+//       CHECK:   return %[[VAL]]
+func.func @splat_size1_vector(%f : f32) -> vector<1xf32> {
+  %splat = vector.splat %f : vector<1xf32>
+  return %splat : vector<1xf32>
+}
+
+// -----
+
 // CHECK-LABEL:  func @shuffle
 //  CHECK-SAME:  %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>
 //       CHECK:    %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]


        


More information about the Mlir-commits mailing list