[Mlir-commits] [mlir] a4360ef - [mlir][spirv] Convert single element vector.splat/fma
Lei Zhang
llvmlistbot at llvm.org
Mon Jun 13 09:18:33 PDT 2022
Author: Lei Zhang
Date: 2022-06-13T12:18:16-04:00
New Revision: a4360efb2c59f0b80c5838e8b36e8d0490ba6f65
URL: https://github.com/llvm/llvm-project/commit/a4360efb2c59f0b80c5838e8b36e8d0490ba6f65
DIFF: https://github.com/llvm/llvm-project/commit/a4360efb2c59f0b80c5838e8b36e8d0490ba6f65.diff
LOG: [mlir][spirv] Convert single element vector.splat/fma
Reviewed By: ThomasRaoux, hanchung
Differential Revision: https://reviews.llvm.org/D127572
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index e404d75f7d243..cb74987024138 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -145,11 +145,11 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
LogicalResult
matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
+ Type dstType = getTypeConverter()->convertType(fmaOp.getType());
+ if (!dstType)
return failure();
rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
- fmaOp, fmaOp.getType(), adaptor.getLhs(), adaptor.getRhs(),
- adaptor.getAcc());
+ fmaOp, dstType, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
return success();
}
};
@@ -321,13 +321,18 @@ class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
LogicalResult
matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType dstVecType = op.getType();
- if (!spirv::CompositeType::isValid(dstVecType))
+ Type dstType = getTypeConverter()->convertType(op.getType());
+ if (!dstType)
return failure();
- SmallVector<Value, 4> source(dstVecType.getNumElements(),
- adaptor.getInput());
- rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
- source);
+ if (dstType.isa<spirv::ScalarType>()) {
+ rewriter.replaceOp(op, adaptor.getInput());
+ } else {
+ auto dstVecType = dstType.cast<VectorType>();
+ SmallVector<Value, 4> source(dstVecType.getNumElements(),
+ adaptor.getInput());
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
+ source);
+ }
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8a3f71b380c4c..76ce449c51bde 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -171,6 +171,15 @@ func.func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vecto
// -----
+// CHECK-LABEL: @fma_size1_vector
+// CHECK: spv.GLSL.Fma %{{.+}} : f32
+func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf32>) -> vector<1xf32> {
+ %0 = vector.fma %a, %b, %c: vector<1xf32>
+ return %0 : vector<1xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @splat
// CHECK-SAME: (%[[A:.+]]: f32)
// CHECK: %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
@@ -182,6 +191,17 @@ func.func @splat(%f : f32) -> vector<4xf32> {
// -----
+// CHECK-LABEL: func @splat_size1_vector
+// CHECK-SAME: (%[[A:.+]]: f32)
+// CHECK: %[[VAL:.+]] = builtin.unrealized_conversion_cast %[[A]]
+// CHECK: return %[[VAL]]
+func.func @splat_size1_vector(%f : f32) -> vector<1xf32> {
+ %splat = vector.splat %f : vector<1xf32>
+ return %splat : vector<1xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @shuffle
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>
// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
More information about the Mlir-commits mailing list