[llvm-branch-commits] [mlir] 3a56a96 - [mlir][spirv] Define spv.GLSL.Fma and add lowerings
Lei Zhang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 19 06:27:09 PST 2021
Author: Lei Zhang
Date: 2021-01-19T09:14:21-05:00
New Revision: 3a56a96664de955888d63c49a33808e3a1a294d9
URL: https://github.com/llvm/llvm-project/commit/3a56a96664de955888d63c49a33808e3a1a294d9
DIFF: https://github.com/llvm/llvm-project/commit/3a56a96664de955888d63c49a33808e3a1a294d9.diff
LOG: [mlir][spirv] Define spv.GLSL.Fma and add lowerings
Also changes some rewriter.create + rewriter.replaceOp calls
into rewriter.replaceOpWithNewOp calls.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D94965
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/simple.mlir
mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
mlir/test/Target/SPIRV/glsl-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
index a566b7503a15..c34cd98dbb39 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
@@ -972,4 +972,44 @@ def SPV_GLSLSClampOp : SPV_GLSLTernaryArithmeticOp<"SClamp", 45, SPV_SignedInt>
}];
}
+// -----
+
+def SPV_GLSLFmaOp : SPV_GLSLTernaryArithmeticOp<"Fma", 50, SPV_Float> {
+ let summary = "Computes a * b + c.";
+
+ let description = [{
+ In uses where this operation is decorated with NoContraction:
+
+ - fma is considered a single operation, whereas the expression a * b + c
+ is considered two operations.
+ - The precision of fma can differ from the precision of the expression
+ a * b + c.
+ - fma will be computed with the same precision as any other fma decorated
+ with NoContraction, giving invariant results for the same input values
+ of a, b, and c.
+
+ Otherwise, in the absence of a NoContraction decoration, there are no
+ special constraints on the number of operations or difference in precision
+ between fma and the expression a * b +c.
+
+ The operands must all be a scalar or vector whose component type is
+ floating-point.
+
+ Result Type and the type of all operands must be the same type. Results
+ are computed per component.
+
+ <!-- End of AutoGen section -->
+ ```
+ fma-op ::= ssa-id `=` `spv.GLSL.Fma` ssa-use, ssa-use, ssa-use `:`
+ float-scalar-vector-type
+ ```
+ #### Example:
+
+ ```mlir
+ %0 = spv.GLSL.Fma %a, %b, %c : f32
+ %1 = spv.GLSL.Fma %a, %b, %c : vector<3xf16>
+ ```
+ }];
+}
+
#endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 1509836ef2e2..52a35a17869f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -36,9 +36,8 @@ struct VectorBroadcastConvert final
vector::BroadcastOp::Adaptor adaptor(operands);
SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
adaptor.source());
- Value construct = rewriter.create<spirv::CompositeConstructOp>(
- broadcastOp.getLoc(), broadcastOp.getVectorType(), source);
- rewriter.replaceOp(broadcastOp, construct);
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+ broadcastOp, broadcastOp.getVectorType(), source);
return success();
}
};
@@ -55,9 +54,23 @@ struct VectorExtractOpConvert final
return failure();
vector::ExtractOp::Adaptor adaptor(operands);
int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
- Value newExtract = rewriter.create<spirv::CompositeExtractOp>(
- extractOp.getLoc(), adaptor.vector(), id);
- rewriter.replaceOp(extractOp, newExtract);
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+ extractOp, adaptor.vector(), id);
+ return success();
+ }
+};
+
+struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
+ return failure();
+ vector::FMAOp::Adaptor adaptor(operands);
+ rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
+ fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
return success();
}
};
@@ -74,9 +87,8 @@ struct VectorInsertOpConvert final
return failure();
vector::InsertOp::Adaptor adaptor(operands);
int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
- Value newInsert = rewriter.create<spirv::CompositeInsertOp>(
- insertOp.getLoc(), adaptor.source(), adaptor.dest(), id);
- rewriter.replaceOp(insertOp, newInsert);
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ insertOp, adaptor.source(), adaptor.dest(), id);
return success();
}
};
@@ -92,10 +104,9 @@ struct VectorExtractElementOpConvert final
if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
return failure();
vector::ExtractElementOp::Adaptor adaptor(operands);
- Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>(
- extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(),
+ rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+ extractElementOp, extractElementOp.getType(), adaptor.vector(),
extractElementOp.position());
- rewriter.replaceOp(extractElementOp, newExtractElement);
return success();
}
};
@@ -111,10 +122,9 @@ struct VectorInsertElementOpConvert final
if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
return failure();
vector::InsertElementOp::Adaptor adaptor(operands);
- Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>(
- insertElementOp.getLoc(), insertElementOp.getType(),
- insertElementOp.dest(), adaptor.source(), insertElementOp.position());
- rewriter.replaceOp(insertElementOp, newInsertElement);
+ rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+ insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
+ adaptor.source(), insertElementOp.position());
return success();
}
};
@@ -124,7 +134,8 @@ struct VectorInsertElementOpConvert final
void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
- VectorInsertOpConvert, VectorExtractElementOpConvert,
- VectorInsertElementOpConvert>(typeConverter, context);
+ patterns.insert<VectorBroadcastConvert, VectorExtractElementOpConvert,
+ VectorExtractOpConvert, VectorFmaOpConvert,
+ VectorInsertOpConvert, VectorInsertElementOpConvert>(
+ typeConverter, context);
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 3594a6db805e..fddfd911fb19 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -57,3 +57,13 @@ func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) {
%0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32>
spv.Return
}
+
+// -----
+
+// CHECK-LABEL: func @fma
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
+// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
+func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) {
+ %0 = vector.fma %a, %b, %c: vector<4xf32>
+ spv.Return
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
index 42377c2277a7..0533396406f7 100644
--- a/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
@@ -345,3 +345,23 @@ func @fclamp(%arg0 : i32, %min : i32, %max : i32) -> () {
%2 = spv.GLSL.SClamp %arg0, %min, %max : i32
return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.Fma
+//===----------------------------------------------------------------------===//
+
+func @fma(%a : f32, %b : f32, %c : f32) -> () {
+ // CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
+ %2 = spv.GLSL.Fma %a, %b, %c : f32
+ return
+}
+
+// -----
+
+func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () {
+ // CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32>
+ %2 = spv.GLSL.Fma %a, %b, %c : vector<3xf32>
+ return
+}
diff --git a/mlir/test/Target/SPIRV/glsl-ops.mlir b/mlir/test/Target/SPIRV/glsl-ops.mlir
index d635bde9cbf1..4dfd249288b0 100644
--- a/mlir/test/Target/SPIRV/glsl-ops.mlir
+++ b/mlir/test/Target/SPIRV/glsl-ops.mlir
@@ -48,4 +48,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
%13 = spv.GLSL.SClamp %arg0, %arg1, %arg2 : si32
spv.Return
}
+
+ spv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" {
+ // CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
+ %13 = spv.GLSL.Fma %arg0, %arg1, %arg2 : f32
+ spv.Return
+ }
}
More information about the llvm-branch-commits
mailing list