[Mlir-commits] [mlir] 91acb5b - Add rsqrt op to Standard dialect and lower it to LLVM dialect.

Adrian Kuegel llvmlistbot at llvm.org
Wed Mar 4 04:14:01 PST 2020


Author: Adrian Kuegel
Date: 2020-03-04T13:13:31+01:00
New Revision: 91acb5b3e1c372895f7f6fa9f5cf95bf80c2ae0b

URL: https://github.com/llvm/llvm-project/commit/91acb5b3e1c372895f7f6fa9f5cf95bf80c2ae0b
DIFF: https://github.com/llvm/llvm-project/commit/91acb5b3e1c372895f7f6fa9f5cf95bf80c2ae0b.diff

LOG: Add rsqrt op to Standard dialect and lower it to LLVM dialect.

Summary:
This adds an rsqrt op to the standard dialect and lowers
it to the LLVM dialect as 1 / sqrt.

Differential Revision: https://reviews.llvm.org/D75353

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
    mlir/test/IR/core-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 851a6434a9a1..08a81b385bf3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1122,6 +1122,16 @@ def RemFOp : FloatArithmeticOp<"remf"> {
   let summary = "floating point division remainder operation";
 }
 
+// Reciprocal square root, rsqrt(x) = 1 / sqrt(x), on float-like operands.
+def RsqrtOp : FloatUnaryOp<"rsqrt"> {
+  let summary = "reciprocal of sqrt (1 / sqrt of the specified value)";
+  let description = [{
+    The `rsqrt` operation computes the reciprocal of the square root. It takes
+    one operand and returns one result of the same type. This type may be a
+    float scalar type, a vector whose element type is float, or a tensor of
+    floats. It has no standard attributes.
+  }];
+}
+
 def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
   let summary = "signed integer division remainder operation";
   let hasFolder = 1;

diff  --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 96d7e82d6a32..9f728678911e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -16,10 +16,12 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/Functional.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -1662,6 +1664,74 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
   bool useAlloca;
 };
 
+// A `rsqrt` is converted into `1 / sqrt`, i.e. llvm.fdiv(1.0, llvm.intr.sqrt(x)).
+struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
+  using LLVMLegalizationPattern<RsqrtOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    OperandAdaptor<RsqrtOp> transformed(operands);
+    auto operandType =
+        transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+    // Bail out unless the converted operand has an LLVM dialect type.
+    if (!operandType)
+      return matchFailure();
+
+    auto loc = op->getLoc();
+    auto resultType = *op->result_type_begin();
+    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+    // Scalar or 1-D vector operand: emit a single constant / sqrt / fdiv.
+    if (!operandType.isArrayTy()) {
+      LLVM::ConstantOp one;
+      if (operandType.isVectorTy()) {
+        one = rewriter.create<LLVM::ConstantOp>(
+            loc, operandType,
+            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+      } else {
+        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+      }
+      auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
+      rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
+      return matchSuccess();
+    }
+    // Array operand: only an n-D vector lowered to an array of 1-D vectors.
+    auto vectorType = resultType.dyn_cast<VectorType>();
+    if (!vectorType)
+      return matchFailure();
+    // The operand must have exactly the array type this vector type lowers to.
+    auto vectorTypeInfo =
+        extractNDVectorTypeInfo(vectorType, this->typeConverter);
+    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+    if (!llvmVectorTy || operandType != vectorTypeInfo.llvmArrayTy)
+      return matchFailure();
+    // Unroll: rebuild the array one innermost 1-D vector at a time.
+    Value desc = rewriter.create<LLVM::UndefOp>(loc, operandType);
+    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+      // Extract the 1-D vector at `position`; the splat 1.0 constant is
+      // re-materialized for every position.
+      auto extractedOperand = rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmVectorTy, transformed.operand(), position);
+      auto splatAttr = SplatElementsAttr::get(
+          mlir::VectorType::get(
+              {llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
+              floatType),
+          floatOne);
+      auto one =
+          rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
+      auto sqrt =
+          rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, extractedOperand);
+      auto div = rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
+      desc = rewriter.create<LLVM::InsertValueOp>(loc, operandType, desc, div,
+                                                  position);
+    });
+    rewriter.replaceOp(op, desc);
+
+    return matchSuccess();
+  }
+};
+
 // A `tanh` is converted into a call to the `tanh` function.
 struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
   using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
@@ -2806,6 +2876,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
       PrefetchOpLowering,
       RemFOpLowering,
       ReturnOpLowering,
+      RsqrtOpLowering,
       SIToFPLowering,
       SelectOpLowering,
       ShiftLeftOpLowering,

diff  --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index fd574d7e67d9..6aec95931484 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -18,6 +18,56 @@ func @strided_memref(%ind: index) {
 
 // -----
 
+// CHECK-LABEL: func @rsqrt(
+// CHECK-SAME: !llvm.float
+func @rsqrt(%arg0 : f32) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.float
+  %0 = rsqrt %arg0 : f32
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_double(
+// CHECK-SAME: !llvm.double
+func @rsqrt_double(%arg0 : f64) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm.double) -> !llvm.double
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.double
+  %0 = rsqrt %arg0 : f64
+  std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_vector(
+// CHECK-SAME: !llvm<"<4 x float>">
+func @rsqrt_vector(%arg0 : vector<4xf32>) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : !llvm<"<4 x float>">
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm<"<4 x float>">
+  %0 = rsqrt %arg0 : vector<4xf32>
+  std.return
+}
+
+// -----
+// n-D vectors are unrolled per innermost vector; only position [0] is checked.
+// CHECK-LABEL: func @rsqrt_multidim_vector(
+// CHECK-SAME: !llvm<"[4 x <3 x float>]">
+func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %arg0[0] : !llvm<"[4 x <3 x float>]">
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : !llvm<"<3 x float>">
+  // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%[[EXTRACT]]) : (!llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm<"<3 x float>">
+  // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %{{.*}}[0] : !llvm<"[4 x <3 x float>]">
+  %0 = rsqrt %arg0 : vector<4x3xf32>
+  std.return
+}
+
+// -----
+
 // This should not crash. The first operation cannot be converted, so the
 // second should not match. This attempts to convert `return` to `llvm.return`
 // and complains about non-LLVM types.

diff  --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 28ef33edba06..e2f6f57ba0a1 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -512,6 +512,9 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
   // CHECK: = fptrunc {{.*}} : vector<4xf32> to vector<4xf16>
   %144 = fptrunc %vcf32 : vector<4xf32> to vector<4xf16>
 
+  // CHECK: %{{[0-9]+}} = rsqrt %arg1 : f32
+  %145 = rsqrt %f : f32
+
   return
 }
 


        


More information about the Mlir-commits mailing list