[llvm-branch-commits] [mlir] 3a56a96 - [mlir][spirv] Define spv.GLSL.Fma and add lowerings

Lei Zhang via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 19 06:27:09 PST 2021


Author: Lei Zhang
Date: 2021-01-19T09:14:21-05:00
New Revision: 3a56a96664de955888d63c49a33808e3a1a294d9

URL: https://github.com/llvm/llvm-project/commit/3a56a96664de955888d63c49a33808e3a1a294d9
DIFF: https://github.com/llvm/llvm-project/commit/3a56a96664de955888d63c49a33808e3a1a294d9.diff

LOG: [mlir][spirv] Define spv.GLSL.Fma and add lowerings

Also changes some rewriter.create + rewriter.replaceOp calls
into rewriter.replaceOpWithNewOp calls.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D94965

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/simple.mlir
    mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
    mlir/test/Target/SPIRV/glsl-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
index a566b7503a15..c34cd98dbb39 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td
@@ -972,4 +972,44 @@ def SPV_GLSLSClampOp : SPV_GLSLTernaryArithmeticOp<"SClamp", 45, SPV_SignedInt>
   }];
 }
 
+// -----
+
+def SPV_GLSLFmaOp : SPV_GLSLTernaryArithmeticOp<"Fma", 50, SPV_Float> {
+  let summary = "Computes a * b + c.";
+
+  let description = [{
+    In uses where this operation is decorated with NoContraction:
+
+    - fma is considered a single operation, whereas the expression a * b + c
+      is considered two operations.
+    - The precision of fma can differ from the precision of the expression
+      a * b + c.
+    - fma will be computed with the same precision as any other fma decorated
+      with NoContraction, giving invariant results for the same input values
+      of a, b, and c.
+
+    Otherwise, in the absence of a NoContraction decoration, there are no
+    special constraints on the number of operations or difference in precision
+    between fma and the expression a * b +c.
+
+    The operands must all be a scalar or vector whose component type is
+    floating-point.
+
+    Result Type and the type of all operands must be the same type. Results
+    are computed per component.
+
+    <!-- End of AutoGen section -->
+    ```
+    fma-op ::= ssa-id `=` `spv.GLSL.Fma` ssa-use, ssa-use, ssa-use `:`
+               float-scalar-vector-type
+    ```
+    #### Example:
+
+    ```mlir
+    %0 = spv.GLSL.Fma %a, %b, %c : f32
+    %1 = spv.GLSL.Fma %a, %b, %c : vector<3xf16>
+    ```
+  }];
+}
+
 #endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 1509836ef2e2..52a35a17869f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -36,9 +36,8 @@ struct VectorBroadcastConvert final
     vector::BroadcastOp::Adaptor adaptor(operands);
     SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(),
                                  adaptor.source());
-    Value construct = rewriter.create<spirv::CompositeConstructOp>(
-        broadcastOp.getLoc(), broadcastOp.getVectorType(), source);
-    rewriter.replaceOp(broadcastOp, construct);
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+        broadcastOp, broadcastOp.getVectorType(), source);
     return success();
   }
 };
@@ -55,9 +54,23 @@ struct VectorExtractOpConvert final
       return failure();
     vector::ExtractOp::Adaptor adaptor(operands);
     int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
-    Value newExtract = rewriter.create<spirv::CompositeExtractOp>(
-        extractOp.getLoc(), adaptor.vector(), id);
-    rewriter.replaceOp(extractOp, newExtract);
+    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+        extractOp, adaptor.vector(), id);
+    return success();
+  }
+};
+
+struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::FMAOp fmaOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!spirv::CompositeType::isValid(fmaOp.getVectorType()))
+      return failure();
+    vector::FMAOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>(
+        fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc());
     return success();
   }
 };
@@ -74,9 +87,8 @@ struct VectorInsertOpConvert final
       return failure();
     vector::InsertOp::Adaptor adaptor(operands);
     int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
-    Value newInsert = rewriter.create<spirv::CompositeInsertOp>(
-        insertOp.getLoc(), adaptor.source(), adaptor.dest(), id);
-    rewriter.replaceOp(insertOp, newInsert);
+    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+        insertOp, adaptor.source(), adaptor.dest(), id);
     return success();
   }
 };
@@ -92,10 +104,9 @@ struct VectorExtractElementOpConvert final
     if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
       return failure();
     vector::ExtractElementOp::Adaptor adaptor(operands);
-    Value newExtractElement = rewriter.create<spirv::VectorExtractDynamicOp>(
-        extractElementOp.getLoc(), extractElementOp.getType(), adaptor.vector(),
+    rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+        extractElementOp, extractElementOp.getType(), adaptor.vector(),
         extractElementOp.position());
-    rewriter.replaceOp(extractElementOp, newExtractElement);
     return success();
   }
 };
@@ -111,10 +122,9 @@ struct VectorInsertElementOpConvert final
     if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
       return failure();
     vector::InsertElementOp::Adaptor adaptor(operands);
-    Value newInsertElement = rewriter.create<spirv::VectorInsertDynamicOp>(
-        insertElementOp.getLoc(), insertElementOp.getType(),
-        insertElementOp.dest(), adaptor.source(), insertElementOp.position());
-    rewriter.replaceOp(insertElementOp, newInsertElement);
+    rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+        insertElementOp, insertElementOp.getType(), insertElementOp.dest(),
+        adaptor.source(), insertElementOp.position());
     return success();
   }
 };
@@ -124,7 +134,8 @@ struct VectorInsertElementOpConvert final
 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
                                          SPIRVTypeConverter &typeConverter,
                                          OwningRewritePatternList &patterns) {
-  patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
-                  VectorInsertOpConvert, VectorExtractElementOpConvert,
-                  VectorInsertElementOpConvert>(typeConverter, context);
+  patterns.insert<VectorBroadcastConvert, VectorExtractElementOpConvert,
+                  VectorExtractOpConvert, VectorFmaOpConvert,
+                  VectorInsertOpConvert, VectorInsertElementOpConvert>(
+      typeConverter, context);
 }

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 3594a6db805e..fddfd911fb19 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -57,3 +57,13 @@ func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) {
   %0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32>
   spv.Return
 }
+
+// -----
+
+// CHECK-LABEL: func @fma
+//  CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
+//       CHECK:   spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
+func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) {
+  %0 = vector.fma %a, %b, %c: vector<4xf32>
+  spv.Return
+}

diff  --git a/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
index 42377c2277a7..0533396406f7 100644
--- a/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir
@@ -345,3 +345,23 @@ func @fclamp(%arg0 : i32, %min : i32, %max : i32) -> () {
   %2 = spv.GLSL.SClamp %arg0, %min, %max : i32
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GLSL.Fma
+//===----------------------------------------------------------------------===//
+
+func @fma(%a : f32, %b : f32, %c : f32) -> () {
+  // CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
+  %2 = spv.GLSL.Fma %a, %b, %c : f32
+  return
+}
+
+// -----
+
+func @fma(%a : vector<3xf32>, %b : vector<3xf32>, %c : vector<3xf32>) -> () {
+  // CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : vector<3xf32>
+  %2 = spv.GLSL.Fma %a, %b, %c : vector<3xf32>
+  return
+}

diff  --git a/mlir/test/Target/SPIRV/glsl-ops.mlir b/mlir/test/Target/SPIRV/glsl-ops.mlir
index d635bde9cbf1..4dfd249288b0 100644
--- a/mlir/test/Target/SPIRV/glsl-ops.mlir
+++ b/mlir/test/Target/SPIRV/glsl-ops.mlir
@@ -48,4 +48,10 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %13 = spv.GLSL.SClamp %arg0, %arg1, %arg2 : si32
     spv.Return
   }
+
+  spv.func @fma(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" {
+    // CHECK: spv.GLSL.Fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
+    %13 = spv.GLSL.Fma %arg0, %arg1, %arg2 : f32
+    spv.Return
+  }
 }


        


More information about the llvm-branch-commits mailing list