[Mlir-commits] [mlir] 2cb130f - [mlir][StandardToSPIRV] Add support for lowering uitofp to SPIR-V

Hanhan Wang llvmlistbot at llvm.org
Thu Jan 21 22:20:51 PST 2021


Author: Hanhan Wang
Date: 2021-01-21T22:20:32-08:00
New Revision: 2cb130f7661176f2c2eaa7554f2a55863cfc0ed3

URL: https://github.com/llvm/llvm-project/commit/2cb130f7661176f2c2eaa7554f2a55863cfc0ed3
DIFF: https://github.com/llvm/llvm-project/commit/2cb130f7661176f2c2eaa7554f2a55863cfc0ed3.diff

LOG: [mlir][StandardToSPIRV] Add support for lowering uitofp to SPIR-V

- Extend spirv::ConstantOp::getZero/One to handle float, vector of int, and vector of float.
- Refactor ZeroExtendI1Pattern to use getZero/One methods.
- Add one more test for lowering std.zexti which extends vector<4xi1> to vector<4xi64>.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D95120

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 72b8c5811695..95bb0eca4496 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -481,16 +481,32 @@ class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
     auto dstType =
         this->getTypeConverter()->convertType(op.getResult().getType());
     Location loc = op.getLoc();
-    Attribute zeroAttr, oneAttr;
-    if (auto vectorType = dstType.dyn_cast<VectorType>()) {
-      zeroAttr = DenseElementsAttr::get(vectorType, 0);
-      oneAttr = DenseElementsAttr::get(vectorType, 1);
-    } else {
-      zeroAttr = IntegerAttr::get(dstType, 0);
-      oneAttr = IntegerAttr::get(dstType, 1);
-    }
-    Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
-    Value one = rewriter.create<ConstantOp>(loc, oneAttr);
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
+        op, dstType, operands.front(), one, zero);
+    return success();
+  }
+};
+
+/// Converts std.uitofp to spv.Select if the type of source is i1 or vector of
+/// i1.
+class UIToFPI1Pattern final : public OpConversionPattern<UIToFPOp> {
+public:
+  using OpConversionPattern<UIToFPOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(UIToFPOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = operands.front().getType();
+    if (!isBoolScalarOrVector(srcType))
+      return failure();
+
+    auto dstType =
+        this->getTypeConverter()->convertType(op.getResult().getType());
+    Location loc = op.getLoc();
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
     rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
         op, dstType, operands.front(), one, zero);
     return success();
@@ -1098,8 +1114,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       ReturnOpPattern, SelectOpPattern,
 
       // Type cast patterns
-      ZeroExtendI1Pattern, TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
+      UIToFPI1Pattern, ZeroExtendI1Pattern,
+      TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
       TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
+      TypeCastingOpPattern<UIToFPOp, spirv::ConvertUToFOp>,
       TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
       TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
       TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c90895197f43..3d99696d6882 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -25,6 +25,8 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/CallInterfaces.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/bit.h"
 
@@ -1581,6 +1583,25 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
     return builder.create<spirv::ConstantOp>(
         loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
   }
+  if (auto floatType = type.dyn_cast<FloatType>()) {
+    return builder.create<spirv::ConstantOp>(
+        loc, type, builder.getFloatAttr(floatType, 0.0));
+  }
+  if (auto vectorType = type.dyn_cast<VectorType>()) {
+    Type elemType = vectorType.getElementType();
+    if (elemType.isa<IntegerType>()) {
+      return builder.create<spirv::ConstantOp>(
+          loc, type,
+          DenseElementsAttr::get(vectorType,
+                                 IntegerAttr::get(elemType, 0.0).getValue()));
+    }
+    if (elemType.isa<FloatType>()) {
+      return builder.create<spirv::ConstantOp>(
+          loc, type,
+          DenseFPElementsAttr::get(vectorType,
+                                   FloatAttr::get(elemType, 0.0).getValue()));
+    }
+  }
 
   llvm_unreachable("unimplemented types for ConstantOp::getZero()");
 }
@@ -1595,6 +1616,25 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
     return builder.create<spirv::ConstantOp>(
         loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
   }
+  if (auto floatType = type.dyn_cast<FloatType>()) {
+    return builder.create<spirv::ConstantOp>(
+        loc, type, builder.getFloatAttr(floatType, 1.0));
+  }
+  if (auto vectorType = type.dyn_cast<VectorType>()) {
+    Type elemType = vectorType.getElementType();
+    if (elemType.isa<IntegerType>()) {
+      return builder.create<spirv::ConstantOp>(
+          loc, type,
+          DenseElementsAttr::get(vectorType,
+                                 IntegerAttr::get(elemType, 1.0).getValue()));
+    }
+    if (elemType.isa<FloatType>()) {
+      return builder.create<spirv::ConstantOp>(
+          loc, type,
+          DenseFPElementsAttr::get(vectorType,
+                                   FloatAttr::get(elemType, 1.0).getValue()));
+    }
+  }
 
   llvm_unreachable("unimplemented types for ConstantOp::getOne()");
 }

diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 633fdbc03550..252bc3eb5095 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -568,6 +568,58 @@ func @sitofp2(%arg0 : i64) -> f64 {
   return %0 : f64
 }
 
+// CHECK-LABEL: @uitofp_i16_f32
+func @uitofp_i16_f32(%arg0: i16) -> f32 {
+  // CHECK: spv.ConvertUToF %{{.*}} : i16 to f32
+  %0 = std.uitofp %arg0 : i16 to f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @uitofp_i32_f32
+func @uitofp_i32_f32(%arg0 : i32) -> f32 {
+  // CHECK: spv.ConvertUToF %{{.*}} : i32 to f32
+  %0 = std.uitofp %arg0 : i32 to f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @uitofp_i1_f32
+func @uitofp_i1_f32(%arg0 : i1) -> f32 {
+  // CHECK: %[[ZERO:.+]] = spv.constant 0.000000e+00 : f32
+  // CHECK: %[[ONE:.+]] = spv.constant 1.000000e+00 : f32
+  // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f32
+  %0 = std.uitofp %arg0 : i1 to f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @uitofp_i1_f64
+func @uitofp_i1_f64(%arg0 : i1) -> f64 {
+  // CHECK: %[[ZERO:.+]] = spv.constant 0.000000e+00 : f64
+  // CHECK: %[[ONE:.+]] = spv.constant 1.000000e+00 : f64
+  // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f64
+  %0 = std.uitofp %arg0 : i1 to f64
+  return %0 : f64
+}
+
+// CHECK-LABEL: @uitofp_vec_i1_f32
+func @uitofp_vec_i1_f32(%arg0 : vector<4xi1>) -> vector<4xf32> {
+  // CHECK: %[[ZERO:.+]] = spv.constant dense<0.000000e+00> : vector<4xf32>
+  // CHECK: %[[ONE:.+]] = spv.constant dense<1.000000e+00> : vector<4xf32>
+  // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf32>
+  %0 = std.uitofp %arg0 : vector<4xi1> to vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: @uitofp_vec_i1_f64
+spv.func @uitofp_vec_i1_f64(%arg0: vector<4xi1>) -> vector<4xf64> "None" {
+  // CHECK: %[[ZERO:.+]] = spv.constant dense<0.000000e+00> : vector<4xf64>
+  // CHECK: %[[ONE:.+]] = spv.constant dense<1.000000e+00> : vector<4xf64>
+  // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf64>
+  %0 = spv.constant dense<0.000000e+00> : vector<4xf64>
+  %1 = spv.constant dense<1.000000e+00> : vector<4xf64>
+  %2 = spv.Select %arg0, %1, %0 : vector<4xi1>, vector<4xf64>
+  spv.ReturnValue %2 : vector<4xf64>
+}
+
 // CHECK-LABEL: @zexti1
 func @zexti1(%arg0: i16) -> i64 {
   // CHECK: spv.UConvert %{{.*}} : i16 to i64
@@ -600,6 +652,15 @@ func @zexti4(%arg0 : vector<4xi1>) -> vector<4xi32> {
   return %0 : vector<4xi32>
 }
 
+// CHECK-LABEL: @zexti5
+func @zexti5(%arg0 : vector<4xi1>) -> vector<4xi64> {
+  // CHECK: %[[ZERO:.+]] = spv.constant dense<0> : vector<4xi64>
+  // CHECK: %[[ONE:.+]] = spv.constant dense<1> : vector<4xi64>
+  // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xi64>
+  %0 = std.zexti %arg0 : vector<4xi1> to vector<4xi64>
+  return %0 : vector<4xi64>
+}
+
 // CHECK-LABEL: @trunci1
 func @trunci1(%arg0 : i64) -> i16 {
   // CHECK: spv.SConvert %{{.*}} : i64 to i16


        


More information about the Mlir-commits mailing list