[Mlir-commits] [mlir] ffd4583 - [mlir][spirv] Change standard op patterns to consider type conversion

Lei Zhang llvmlistbot at llvm.org
Wed Mar 18 17:13:30 PDT 2020


Author: Lei Zhang
Date: 2020-03-18T20:11:05-04:00
New Revision: ffd4583c6aee0d6ca62efe962d6f2cb15de8bdab

URL: https://github.com/llvm/llvm-project/commit/ffd4583c6aee0d6ca62efe962d6f2cb15de8bdab
DIFF: https://github.com/llvm/llvm-project/commit/ffd4583c6aee0d6ca62efe962d6f2cb15de8bdab.diff

LOG: [mlir][spirv] Change standard op patterns to consider type conversion

Previously we have a few patterns that were written with DRR. DRR
at the moment does not work nicely with dialect conversion framework.
It generates normal RewritePatterns, while the dialect conversion
framework requires ConversionPatterns to take into consideration
the type conversion. So this commit starts to change existing DRR
patterns for standard ops to OpConversionPattern to incorporate the
SPIR-V type conversion. All patterns are converted except the one
for constant ops, which will happen in a subsequent commit.

Differential Revision: https://reviews.llvm.org/D76245

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
    mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
    mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 310dcd8a86bd..5b6243fa74e9 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -6,28 +6,93 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements patterns to convert Standard Ops to the SPIR-V dialect.
+// This file implements patterns to convert standard ops to SPIR-V ops.
 //
 //===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/SPIRV/LayoutUtils.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the given `type` is a boolean scalar or vector type.
+static bool isBoolScalarOrVector(Type type) {
+  if (type.isInteger(1))
+    return true;
+  if (auto vecType = type.dyn_cast<VectorType>())
+    return vecType.getElementType().isInteger(1);
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
 
+// Note that DRR cannot be used for the patterns in this file: we may need to
+// convert type along the way, which requires ConversionPattern. DRR generates
+// normal RewritePattern.
+
 namespace {
 
-/// Convert composite constant operation to SPIR-V dialect.
-// TODO(denis0x0D) : move to DRR.
-class ConstantCompositeOpConversion final : public SPIRVOpLowering<ConstantOp> {
+/// Converts binary standard operations to SPIR-V operations.
+template <typename StdOp, typename SPIRVOp>
+class BinaryOpPattern final : public SPIRVOpLowering<StdOp> {
+public:
+  using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(operands.size() == 2);
+    auto dstType = this->typeConverter.convertType(operation.getType());
+    if (!dstType)
+      return failure();
+    rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands,
+                                                  ArrayRef<NamedAttribute>());
+    return success();
+  }
+};
+
+/// Converts bitwise standard operations to SPIR-V operations. This is a special
+/// pattern other than the BinaryOpPatternPattern because if the operands are
+/// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
+/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
+template <typename StdOp, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
+class BitwiseOpPattern final : public SPIRVOpLowering<StdOp> {
+public:
+  using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(operands.size() == 2);
+    auto dstType =
+        this->typeConverter.convertType(operation.getResult().getType());
+    if (!dstType)
+      return failure();
+    if (isBoolScalarOrVector(operands.front().getType())) {
+      rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
+          operation, dstType, operands, ArrayRef<NamedAttribute>());
+    } else {
+      rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
+          operation, dstType, operands, ArrayRef<NamedAttribute>());
+    }
+    return success();
+  }
+};
+
+/// Converts composite std.constant operation to spv.constant.
+class ConstantCompositeOpPattern final : public SPIRVOpLowering<ConstantOp> {
 public:
   using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
 
@@ -36,12 +101,8 @@ class ConstantCompositeOpConversion final : public SPIRVOpLowering<ConstantOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Convert constant operation with IndexType return to SPIR-V constant
-/// operation. Since IndexType is not used within SPIR-V dialect, this needs
-/// special handling to make sure the result type and the type of the value
-/// attribute are consistent.
-// TODO(ravishankarm) : This should be moved into DRR.
-class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
+/// Converts scalar std.constant operation to spv.constant.
+class ConstantScalarOpPattern final : public SPIRVOpLowering<ConstantOp> {
 public:
   using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
 
@@ -50,8 +111,8 @@ class ConstantIndexOpConversion final : public SPIRVOpLowering<ConstantOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Convert floating-point comparison operations to SPIR-V dialect.
-class CmpFOpConversion final : public SPIRVOpLowering<CmpFOp> {
+/// Converts floating-point comparison operations to SPIR-V ops.
+class CmpFOpPattern final : public SPIRVOpLowering<CmpFOp> {
 public:
   using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering;
 
@@ -60,8 +121,8 @@ class CmpFOpConversion final : public SPIRVOpLowering<CmpFOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Convert compare operation to SPIR-V dialect.
-class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
+/// Converts integer compare operation to SPIR-V ops.
+class CmpIOpPattern final : public SPIRVOpLowering<CmpIOp> {
 public:
   using SPIRVOpLowering<CmpIOp>::SPIRVOpLowering;
 
@@ -70,33 +131,8 @@ class CmpIOpConversion final : public SPIRVOpLowering<CmpIOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Convert integer binary operations to SPIR-V operations. Cannot use
-/// tablegen for this. If the integer operation is on variables of IndexType,
-/// the type of the return value of the replacement operation differs from
-/// that of the replaced operation. This is not handled in tablegen-based
-/// pattern specification.
-// TODO(ravishankarm) : This should be moved into DRR.
-template <typename StdOp, typename SPIRVOp>
-class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
-public:
-  using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
-
-  LogicalResult
-  matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto resultType =
-        this->typeConverter.convertType(operation.getResult().getType());
-    rewriter.template replaceOpWithNewOp<SPIRVOp>(
-        operation, resultType, operands, ArrayRef<NamedAttribute>());
-    return success();
-  }
-};
-
-/// Convert load -> spv.LoadOp. The operands of the replaced operation are of
-/// IndexType while that of the replacement operation are of type i32. This is
-/// not supported in tablegen based pattern specification.
-// TODO(ravishankarm) : This should be moved into DRR.
-class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
+/// Converts std.load to spv.Load.
+class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
 public:
   using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
 
@@ -105,9 +141,8 @@ class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Convert return -> spv.Return.
-// TODO(ravishankarm) : This should be moved into DRR.
-class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> {
+/// Converts std.return to spv.Return.
+class ReturnOpPattern final : public SPIRVOpLowering<ReturnOp> {
 public:
   using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
 
@@ -116,9 +151,8 @@ class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Convert select -> spv.Select
-// TODO(ravishankarm) : This should be moved into DRR.
-class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
+/// Converts std.select to spv.Select.
+class SelectOpPattern final : public SPIRVOpLowering<SelectOp> {
 public:
   using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
   LogicalResult
@@ -126,11 +160,8 @@ class SelectOpConversion final : public SPIRVOpLowering<SelectOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-/// Convert store -> spv.StoreOp. The operands of the replaced operation are
-/// of IndexType while that of the replacement operation are of type i32. This
-/// is not supported in tablegen based pattern specification.
-// TODO(ravishankarm) : This should be moved into DRR.
-class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
+/// Converts std.store to spv.Store.
+class StoreOpPattern final : public SPIRVOpLowering<StoreOp> {
 public:
   using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
 
@@ -139,13 +170,47 @@ class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts type-casting standard operations to SPIR-V operations.
+template <typename StdOp, typename SPIRVOp>
+class TypeCastingOpPattern final : public SPIRVOpLowering<StdOp> {
+public:
+  using SPIRVOpLowering<StdOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(operands.size() == 1);
+    auto dstType =
+        this->typeConverter.convertType(operation.getResult().getType());
+    if (dstType == operands.front().getType()) {
+      // Due to type conversion, we are seeing the same source and target type.
+      // Then we can just erase this operation by forwarding its operand.
+      rewriter.replaceOp(operation, operands.front());
+    } else {
+      rewriter.template replaceOpWithNewOp<SPIRVOp>(
+          operation, dstType, operands, ArrayRef<NamedAttribute>());
+    }
+    return success();
+  }
+};
+
+/// Converts std.xor to SPIR-V operations.
+class XOrOpPattern final : public SPIRVOpLowering<XOrOp> {
+public:
+  using SPIRVOpLowering<XOrOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
 // ConstantOp with composite type.
 //===----------------------------------------------------------------------===//
 
-LogicalResult ConstantCompositeOpConversion::matchAndRewrite(
+LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
     ConstantOp constCompositeOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
   auto compositeType =
@@ -175,10 +240,10 @@ LogicalResult ConstantCompositeOpConversion::matchAndRewrite(
 }
 
 //===----------------------------------------------------------------------===//
-// ConstantOp with index type.
+// ConstantOp with scalar type.
 //===----------------------------------------------------------------------===//
 
-LogicalResult ConstantIndexOpConversion::matchAndRewrite(
+LogicalResult ConstantScalarOpPattern::matchAndRewrite(
     ConstantOp constIndexOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
   if (!constIndexOp.getResult().getType().isa<IndexType>()) {
@@ -213,8 +278,8 @@ LogicalResult ConstantIndexOpConversion::matchAndRewrite(
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
-                                  ConversionPatternRewriter &rewriter) const {
+CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
+                               ConversionPatternRewriter &rewriter) const {
   CmpFOpOperandAdaptor cmpFOpOperands(operands);
 
   switch (cmpFOp.getPredicate()) {
@@ -253,8 +318,8 @@ CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
-                                  ConversionPatternRewriter &rewriter) const {
+CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
+                               ConversionPatternRewriter &rewriter) const {
   CmpIOpOperandAdaptor cmpIOpOperands(operands);
 
   switch (cmpIOp.getPredicate()) {
@@ -286,8 +351,8 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
-                                  ConversionPatternRewriter &rewriter) const {
+LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+                               ConversionPatternRewriter &rewriter) const {
   LoadOpOperandAdaptor loadOperands(operands);
   auto loadPtr = spirv::getElementPtr(
       typeConverter, loadOp.memref().getType().cast<MemRefType>(),
@@ -301,8 +366,8 @@ LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
-                                    ConversionPatternRewriter &rewriter) const {
+ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
+                                 ConversionPatternRewriter &rewriter) const {
   if (returnOp.getNumOperands()) {
     return failure();
   }
@@ -315,8 +380,8 @@ ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
-                                    ConversionPatternRewriter &rewriter) const {
+SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+                                 ConversionPatternRewriter &rewriter) const {
   SelectOpOperandAdaptor selectOperands(operands);
   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
                                                selectOperands.true_value(),
@@ -329,8 +394,8 @@ SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
-                                   ConversionPatternRewriter &rewriter) const {
+StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
+                                ConversionPatternRewriter &rewriter) const {
   StoreOpOperandAdaptor storeOperands(operands);
   auto storePtr = spirv::getElementPtr(
       typeConverter, storeOp.memref().getType().cast<MemRefType>(),
@@ -341,6 +406,31 @@ StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// XorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+                              ConversionPatternRewriter &rewriter) const {
+  assert(operands.size() == 2);
+
+  if (isBoolScalarOrVector(operands.front().getType()))
+    return failure();
+
+  auto dstType = typeConverter.convertType(xorOp.getType());
+  if (!dstType)
+    return failure();
+  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands,
+                                                   ArrayRef<NamedAttribute>());
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
 namespace {
 /// Import the Standard Ops to SPIR-V Patterns.
 #include "StandardToSPIRV.cpp.inc"
@@ -352,14 +442,29 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns) {
   // Add patterns that lower operations into SPIR-V dialect.
   populateWithGenerated(context, &patterns);
-  patterns.insert<ConstantCompositeOpConversion, ConstantIndexOpConversion,
-                  CmpFOpConversion, CmpIOpConversion,
-                  IntegerOpConversion<AddIOp, spirv::IAddOp>,
-                  IntegerOpConversion<MulIOp, spirv::IMulOp>,
-                  IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,
-                  IntegerOpConversion<SignedRemIOp, spirv::SModOp>,
-                  IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
-                  ReturnOpConversion, SelectOpConversion, StoreOpConversion>(
+  patterns.insert<
+      BinaryOpPattern<AddFOp, spirv::FAddOp>,
+      BinaryOpPattern<AddIOp, spirv::IAddOp>,
+      BinaryOpPattern<DivFOp, spirv::FDivOp>,
+      BinaryOpPattern<MulFOp, spirv::FMulOp>,
+      BinaryOpPattern<MulIOp, spirv::IMulOp>,
+      BinaryOpPattern<RemFOp, spirv::FRemOp>,
+      BinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
+      BinaryOpPattern<SignedShiftRightOp, spirv::ShiftRightArithmeticOp>,
+      BinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
+      BinaryOpPattern<SignedRemIOp, spirv::SRemOp>,
+      BinaryOpPattern<SubFOp, spirv::FSubOp>,
+      BinaryOpPattern<SubIOp, spirv::ISubOp>,
+      BinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
+      BinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
+      BinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
+      BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
+      BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
+      ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern,
+      CmpIOpPattern, LoadOpPattern, ReturnOpPattern, SelectOpPattern,
+      StoreOpPattern, TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
+      TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
+      TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>, XOrOpPattern>(
       context, typeConverter);
 }
 } // namespace mlir

diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
index a23ae5fe81c9..016344e16304 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
@@ -16,34 +16,6 @@
 include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "mlir/Dialect/SPIRV/SPIRVOps.td"
 
-class BinaryOpPattern<Type type, Op src, Op tgt> :
-      Pat<(src SPV_ScalarOrVectorOf<type>:$l, SPV_ScalarOrVectorOf<type>:$r),
-          (tgt $l, $r)>;
-
-class UnaryOpPattern<Type type, Op src, Op tgt> :
-      Pat<(src type:$input),
-          (tgt $input)>;
-
-def : BinaryOpPattern<SPV_Bool, AndOp, SPV_LogicalAndOp>;
-def : BinaryOpPattern<SPV_Bool, OrOp, SPV_LogicalOrOp>;
-def : BinaryOpPattern<SPV_Integer, AndOp, SPV_BitwiseAndOp>;
-def : BinaryOpPattern<SPV_Integer, OrOp, SPV_BitwiseOrOp>;
-def : BinaryOpPattern<SPV_Integer, ShiftLeftOp, SPV_ShiftLeftLogicalOp>;
-def : BinaryOpPattern<SPV_Integer, SignedShiftRightOp,
-                      SPV_ShiftRightArithmeticOp>;
-def : BinaryOpPattern<SPV_Integer, UnsignedShiftRightOp,
-                      SPV_ShiftRightLogicalOp>;
-def : BinaryOpPattern<SPV_Integer, XOrOp, SPV_BitwiseXorOp>;
-def : BinaryOpPattern<SPV_Float, AddFOp, SPV_FAddOp>;
-def : BinaryOpPattern<SPV_Float, DivFOp, SPV_FDivOp>;
-def : BinaryOpPattern<SPV_Float, MulFOp, SPV_FMulOp>;
-def : BinaryOpPattern<SPV_Float, RemFOp, SPV_FRemOp>;
-def : BinaryOpPattern<SPV_Float, SubFOp, SPV_FSubOp>;
-
-def : UnaryOpPattern<SPV_Integer, SIToFPOp, SPV_ConvertSToFOp>;
-def : UnaryOpPattern<SPV_Float, FPExtOp, SPV_FConvertOp>;
-def : UnaryOpPattern<SPV_Float, FPTruncOp, SPV_FConvertOp>;
-
 // Constant Op
 // TODO(ravishankarm): Handle lowering other constant types.
 def : Pat<(ConstantOp:$result $valueAttr),

diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
index 26e2ea42d3a2..8a53488d33f4 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
@@ -1,112 +1,142 @@
-// RUN: mlir-opt -convert-std-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-std-to-spirv %s -o - | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// std arithmetic ops
+//===----------------------------------------------------------------------===//
 
 module attributes {
   spv.target_env = #spv.target_env<
-    #spv.vce<v1.0, [Shader, Int64, Float64], [SPV_KHR_storage_buffer_storage_class]>,
+    #spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>,
     {max_compute_workgroup_invocations = 128 : i32,
      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
 } {
 
-//===----------------------------------------------------------------------===//
-// std binary arithmetic ops
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: @add_sub
-func @add_sub(%arg0 : i32, %arg1 : i32) {
-  // CHECK: spv.IAdd
-  %0 = addi %arg0, %arg1 : i32
-  // CHECK: spv.ISub
-  %1 = subi %arg0, %arg1 : i32
+// Check integer operation conversions.
+// CHECK-LABEL: @int32_scalar
+func @int32_scalar(%lhs: i32, %rhs: i32) {
+  // CHECK: spv.IAdd %{{.*}}, %{{.*}}: i32
+  %0 = addi %lhs, %rhs: i32
+  // CHECK: spv.ISub %{{.*}}, %{{.*}}: i32
+  %1 = subi %lhs, %rhs: i32
+  // CHECK: spv.IMul %{{.*}}, %{{.*}}: i32
+  %2 = muli %lhs, %rhs: i32
+  // CHECK: spv.SDiv %{{.*}}, %{{.*}}: i32
+  %3 = divi_signed %lhs, %rhs: i32
+  // CHECK: spv.SRem %{{.*}}, %{{.*}}: i32
+  %4 = remi_signed %lhs, %rhs: i32
+  // CHECK: spv.UDiv %{{.*}}, %{{.*}}: i32
+  %5 = divi_unsigned %lhs, %rhs: i32
+  // CHECK: spv.UMod %{{.*}}, %{{.*}}: i32
+  %6 = remi_unsigned %lhs, %rhs: i32
   return
 }
 
-// CHECK-LABEL: @fadd_scalar
-func @fadd_scalar(%arg: f32) {
-  // CHECK: spv.FAdd
-  %0 = addf %arg, %arg : f32
+// Check float operation conversions.
+// CHECK-LABEL: @float32_scalar
+func @float32_scalar(%lhs: f32, %rhs: f32) {
+  // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32
+  %0 = addf %lhs, %rhs: f32
+  // CHECK: spv.FSub %{{.*}}, %{{.*}}: f32
+  %1 = subf %lhs, %rhs: f32
+  // CHECK: spv.FMul %{{.*}}, %{{.*}}: f32
+  %2 = mulf %lhs, %rhs: f32
+  // CHECK: spv.FDiv %{{.*}}, %{{.*}}: f32
+  %3 = divf %lhs, %rhs: f32
+  // CHECK: spv.FRem %{{.*}}, %{{.*}}: f32
+  %4 = remf %lhs, %rhs: f32
   return
 }
 
-// CHECK-LABEL: @fdiv_scalar
-func @fdiv_scalar(%arg: f32) {
-  // CHECK: spv.FDiv
-  %0 = divf %arg, %arg : f32
+// Check int vector types.
+// CHECK-LABEL: @int_vector234
+func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) {
+  // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi8>
+  %0 = divi_signed %arg0, %arg0: vector<2xi8>
+  // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi16>
+  %1 = remi_signed %arg1, %arg1: vector<3xi16>
+  // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi64>
+  %2 = divi_unsigned %arg2, %arg2: vector<4xi64>
   return
 }
 
-// CHECK-LABEL: @fmul_scalar
-func @fmul_scalar(%arg: f32) {
-  // CHECK: spv.FMul
-  %0 = mulf %arg, %arg : f32
+// Check float vector types.
+// CHECK-LABEL: @float_vector234
+func @float_vector234(%arg0: vector<2xf16>, %arg1: vector<3xf64>) {
+  // CHECK: spv.FAdd %{{.*}}, %{{.*}}: vector<2xf16>
+  %0 = addf %arg0, %arg0: vector<2xf16>
+  // CHECK: spv.FMul %{{.*}}, %{{.*}}: vector<3xf64>
+  %1 = mulf %arg1, %arg1: vector<3xf64>
   return
 }
 
-// CHECK-LABEL: @fmul_vector2
-func @fmul_vector2(%arg: vector<2xf32>) {
-  // CHECK: spv.FMul
-  %0 = mulf %arg, %arg : vector<2xf32>
+// CHECK-LABEL: @unsupported_1elem_vector
+func @unsupported_1elem_vector(%arg0: vector<1xi32>) {
+  // CHECK: addi
+  %0 = addi %arg0, %arg0: vector<1xi32>
   return
 }
 
-// CHECK-LABEL: @fmul_vector3
-func @fmul_vector3(%arg: vector<3xf32>) {
-  // CHECK: spv.FMul
-  %0 = mulf %arg, %arg : vector<3xf32>
+// CHECK-LABEL: @unsupported_5elem_vector
+func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
+  // CHECK: subi
+  %1 = subi %arg0, %arg0: vector<5xi32>
   return
 }
 
-// CHECK-LABEL: @fmul_vector4
-func @fmul_vector4(%arg: vector<4xf32>) {
-  // CHECK: spv.FMul
-  %0 = mulf %arg, %arg : vector<4xf32>
+// CHECK-LABEL: @unsupported_2x2elem_vector
+func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
+  // CHECK: muli
+  %2 = muli %arg0, %arg0: vector<2x2xi32>
   return
 }
 
-// CHECK-LABEL: @fmul_vector5
-func @fmul_vector5(%arg: vector<5xf32>) {
-  // Vector length of only 2, 3, and 4 is valid for SPIR-V.
-  // CHECK: mulf
-  %0 = mulf %arg, %arg : vector<5xf32>
-  return
-}
+} // end module
 
-// TODO(antiagainst): enable this once we support converting binary ops
-// needing type conversion.
-// XXXXX-LABEL: @fmul_tensor
-//func @fmul_tensor(%arg: tensor<4xf32>) {
-  // For tensors mulf cannot be lowered directly to spv.FMul.
-  // XXXXX: mulf
-  //%0 = mulf %arg, %arg : tensor<4xf32>
-  //return
-//}
-
-// CHECK-LABEL: @frem_scalar
-func @frem_scalar(%arg: f32) {
-  // CHECK: spv.FRem
-  %0 = remf %arg, %arg : f32
-  return
-}
+// -----
+
+// Check that types are converted to 32-bit when no special capabilities.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
 
-// CHECK-LABEL: @fsub_scalar
-func @fsub_scalar(%arg: f32) {
-  // CHECK: spv.FSub
-  %0 = subf %arg, %arg : f32
+// CHECK-LABEL: @int_vector234
+func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) {
+  // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32>
+  %0 = divi_signed %arg0, %arg0: vector<2xi8>
+  // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi32>
+  %1 = remi_signed %arg1, %arg1: vector<3xi16>
+  // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi32>
+  %2 = divi_unsigned %arg2, %arg2: vector<4xi64>
   return
 }
 
-// CHECK-LABEL: @div_rem
-func @div_rem(%arg0 : i32, %arg1 : i32) {
-  // CHECK: spv.SDiv
-  %0 = divi_signed %arg0, %arg1 : i32
-  // CHECK: spv.SMod
-  %1 = remi_signed %arg0, %arg1 : i32
+// CHECK-LABEL: @float_scalar
+func @float_scalar(%arg0: f16, %arg1: f64) {
+  // CHECK: spv.FAdd %{{.*}}, %{{.*}}: f32
+  %0 = addf %arg0, %arg0: f16
+  // CHECK: spv.FMul %{{.*}}, %{{.*}}: f32
+  %1 = mulf %arg1, %arg1: f64
   return
 }
 
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // std bit ops
 //===----------------------------------------------------------------------===//
 
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
 // CHECK-LABEL: @bitwise_scalar
 func @bitwise_scalar(%arg0 : i32, %arg1 : i32) {
   // CHECK: spv.BitwiseAnd
@@ -129,6 +159,24 @@ func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
   return
 }
 
+// CHECK-LABEL: @logical_scalar
+func @logical_scalar(%arg0 : i1, %arg1 : i1) {
+  // CHECK: spv.LogicalAnd
+  %0 = and %arg0, %arg1 : i1
+  // CHECK: spv.LogicalOr
+  %1 = or %arg0, %arg1 : i1
+  return
+}
+
+// CHECK-LABEL: @logical_vector
+func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+  // CHECK: spv.LogicalAnd
+  %0 = and %arg0, %arg1 : vector<4xi1>
+  // CHECK: spv.LogicalOr
+  %1 = or %arg0, %arg1 : vector<4xi1>
+  return
+}
+
 // CHECK-LABEL: @shift_scalar
 func @shift_scalar(%arg0 : i32, %arg1 : i32) {
   // CHECK: spv.ShiftLeftLogical
@@ -213,10 +261,21 @@ func @cmpi(%arg0 : i32, %arg1 : i32) {
   return
 }
 
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // std.constant
 //===----------------------------------------------------------------------===//
 
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
 // CHECK-LABEL: @constant
 func @constant() {
   // CHECK: spv.constant true
@@ -244,50 +303,126 @@ func @constant() {
   return
 }
 
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
-// std logical binary operations
+// std cast ops
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: @logical_scalar
-func @logical_scalar(%arg0 : i1, %arg1 : i1) {
-  // CHECK: spv.LogicalAnd
-  %0 = and %arg0, %arg1 : i1
-  // CHECK: spv.LogicalOr
-  %1 = or %arg0, %arg1 : i1
-  return
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: @fpext1
+func @fpext1(%arg0: f16) -> f64 {
+  // CHECK: spv.FConvert %{{.*}} : f16 to f64
+  %0 = std.fpext %arg0 : f16 to f64
+  return %0 : f64
 }
 
-// CHECK-LABEL: @logical_vector
-func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
-  // CHECK: spv.LogicalAnd
-  %0 = and %arg0, %arg1 : vector<4xi1>
-  // CHECK: spv.LogicalOr
-  %1 = or %arg0, %arg1 : vector<4xi1>
-  return
+// CHECK-LABEL: @fpext2
+func @fpext2(%arg0 : f32) -> f64 {
+  // CHECK: spv.FConvert %{{.*}} : f32 to f64
+  %0 = std.fpext %arg0 : f32 to f64
+  return %0 : f64
 }
 
-//===----------------------------------------------------------------------===//
-// std.fpext
-//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @fptrunc1
+func @fptrunc1(%arg0 : f64) -> f16 {
+  // CHECK: spv.FConvert %{{.*}} : f64 to f16
+  %0 = std.fptrunc %arg0 : f64 to f16
+  return %0 : f16
+}
+
+// CHECK-LABEL: @fptrunc2
+func @fptrunc2(%arg0: f32) -> f16 {
+  // CHECK: spv.FConvert %{{.*}} : f32 to f16
+  %0 = std.fptrunc %arg0 : f32 to f16
+  return %0 : f16
+}
+
+// CHECK-LABEL: @sitofp1
+func @sitofp1(%arg0 : i32) -> f32 {
+  // CHECK: spv.ConvertSToF %{{.*}} : i32 to f32
+  %0 = std.sitofp %arg0 : i32 to f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @sitofp2
+func @sitofp2(%arg0 : i64) -> f64 {
+  // CHECK: spv.ConvertSToF %{{.*}} : i64 to f64
+  %0 = std.sitofp %arg0 : i64 to f64
+  return %0 : f64
+}
+
+} // end module
+
+// -----
+
+// Checks that cast types will be adjusted when no special capabilities for
+// non-32-bit scalar types.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
 
-// CHECK-LABEL: @fpext
-func @fpext(%arg0 : f32) {
-  // CHECK: spv.FConvert
+// CHECK-LABEL: @fpext1
+// CHECK-SAME: %[[ARG:.*]]: f32
+func @fpext1(%arg0: f16) {
+  // CHECK-NEXT: "use"(%[[ARG]])
+  %0 = std.fpext %arg0 : f16 to f64
+  "use"(%0) : (f64) -> ()
+}
+
+// CHECK-LABEL: @fpext2
+// CHECK-SAME: %[[ARG:.*]]: f32
+func @fpext2(%arg0 : f32) {
+  // CHECK-NEXT: "use"(%[[ARG]])
   %0 = std.fpext %arg0 : f32 to f64
-  return
+  "use"(%0) : (f64) -> ()
 }
 
-//===----------------------------------------------------------------------===//
-// std.fptrunc
-//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @fptrunc1
+// CHECK-SAME: %[[ARG:.*]]: f32
+func @fptrunc1(%arg0 : f64) {
+  // CHECK-NEXT: "use"(%[[ARG]])
+  %0 = std.fptrunc %arg0 : f64 to f16
+  "use"(%0) : (f16) -> ()
+}
 
-// CHECK-LABEL: @fptrunc
-func @fptrunc(%arg0 : f64) {
-  // CHECK: spv.FConvert
-  %0 = std.fptrunc %arg0 : f64 to f32
-  return
+// CHECK-LABEL: @fptrunc2
+// CHECK-SAME: %[[ARG:.*]]: f32
+func @fptrunc2(%arg0: f32) {
+  // CHECK-NEXT: "use"(%[[ARG]])
+  %0 = std.fptrunc %arg0 : f32 to f16
+  "use"(%0) : (f16) -> ()
+}
+
+// CHECK-LABEL: @sitofp
+func @sitofp(%arg0 : i64) {
+  // CHECK: spv.ConvertSToF %{{.*}} : i32 to f32
+  %0 = std.sitofp %arg0 : i64 to f64
+  "use"(%0) : (f64) -> ()
 }
 
+} // end module
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader, Int8, Int16, Int64, Float16, Float64], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
 //===----------------------------------------------------------------------===//
 // std.select
 //===----------------------------------------------------------------------===//
@@ -301,41 +436,9 @@ func @select(%arg0 : i32, %arg1 : i32) {
 }
 
 //===----------------------------------------------------------------------===//
-// std.sitofp
+// std load/store ops
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: @sitofp
-func @sitofp(%arg0 : i32) {
-  // CHECK: spv.ConvertSToF
-  %0 = std.sitofp %arg0 : i32 to f32
-  return
-}
-
-//===----------------------------------------------------------------------===//
-// memref type
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>)
-func @memref_type(%arg0: memref<3xi1>) {
-  return
-}
-
-// CHECK-LABEL: func @memref_mem_space
-// CHECK-SAME: StorageBuffer
-// CHECK-SAME: Uniform
-// CHECK-SAME: Workgroup
-// CHECK-SAME: PushConstant
-// CHECK-SAME: Private
-// CHECK-SAME: Function
-func @memref_mem_space(
-    %arg0: memref<4xf32, 0>,
-    %arg1: memref<4xf32, 4>,
-    %arg2: memref<4xf32, 3>,
-    %arg3: memref<4xf32, 7>,
-    %arg4: memref<4xf32, 5>,
-    %arg5: memref<4xf32, 6>
-) { return }
-
 // CHECK-LABEL: @load_store_zero_rank_float
 // CHECK: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>,
 // CHECK: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<1 x f32 [4]> [0]>, StorageBuffer>)

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
index a88678fd34ac..81911bd1a633 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -255,6 +255,51 @@ func @large_vector(%arg0: vector<1024xi32>) { return }
 // MemRef types
 //===----------------------------------------------------------------------===//
 
+// Check memory spaces.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: func @memref_mem_space
+// CHECK-SAME: StorageBuffer
+// CHECK-SAME: Uniform
+// CHECK-SAME: Workgroup
+// CHECK-SAME: PushConstant
+// CHECK-SAME: Private
+// CHECK-SAME: Function
+func @memref_mem_space(
+    %arg0: memref<4xf32, 0>,
+    %arg1: memref<4xf32, 4>,
+    %arg2: memref<4xf32, 3>,
+    %arg3: memref<4xf32, 7>,
+    %arg4: memref<4xf32, 5>,
+    %arg5: memref<4xf32, 6>
+) { return }
+
+} // end module
+
+// -----
+
+// Check that boolean memref is not supported at the moment.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>)
+func @memref_type(%arg0: memref<3xi1>) {
+  return
+}
+
+} // end module
+
+// -----
+
 // Check that using non-32-bit scalar types in interface storage classes
 // requires special capability and extension: convert them to 32-bit if not
 // satisfied.


        


More information about the Mlir-commits mailing list