[Mlir-commits] [mlir] 91acb5b - Add rsqrt op to Standard dialect and lower it to LLVM dialect.
Adrian Kuegel
llvmlistbot at llvm.org
Wed Mar 4 04:14:01 PST 2020
Author: Adrian Kuegel
Date: 2020-03-04T13:13:31+01:00
New Revision: 91acb5b3e1c372895f7f6fa9f5cf95bf80c2ae0b
URL: https://github.com/llvm/llvm-project/commit/91acb5b3e1c372895f7f6fa9f5cf95bf80c2ae0b
DIFF: https://github.com/llvm/llvm-project/commit/91acb5b3e1c372895f7f6fa9f5cf95bf80c2ae0b.diff
LOG: Add rsqrt op to Standard dialect and lower it to LLVM dialect.
Summary:
This adds an rsqrt op to the standard dialect, and lowers
it as 1 / sqrt to the LLVM dialect.
Differential Revision: https://reviews.llvm.org/D75353
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
mlir/test/IR/core-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 851a6434a9a1..08a81b385bf3 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1122,6 +1122,16 @@ def RemFOp : FloatArithmeticOp<"remf"> {
let summary = "floating point division remainder operation";
}
+def RsqrtOp : FloatUnaryOp<"rsqrt"> {
+ let summary = "reciprocal of sqrt (1 / sqrt of the specified value)";
+ let description = [{
+ The `rsqrt` operation computes the reciprocal of the square root. It takes
+ one operand and returns one result of the same type. This type may be a
+ float scalar type, a vector whose element type is float, or a tensor of
+ floats. It has no standard attributes.
+ }];
+}
+
def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
let summary = "signed integer division remainder operation";
let hasFolder = 1;
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 96d7e82d6a32..9f728678911e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -16,10 +16,12 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -1662,6 +1664,74 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
bool useAlloca;
};
+// A `rsqrt` is converted into `1 / sqrt` (constant one, `sqrt`, then `fdiv`).
+struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
+ using LLVMLegalizationPattern<RsqrtOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ OperandAdaptor<RsqrtOp> transformed(operands);
+ auto operandType =
+ transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+
+ if (!operandType) // operand does not have an LLVM dialect type
+ return matchFailure();
+
+ auto loc = op->getLoc();
+ auto resultType = *op->result_type_begin();
+ auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
+ auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+
+ if (!operandType.isArrayTy()) { // non-array operand: lower directly
+ LLVM::ConstantOp one;
+ if (operandType.isVectorTy()) {
+ one = rewriter.create<LLVM::ConstantOp>(
+ loc, operandType,
+ SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
+ } else {
+ one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
+ }
+ auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
+ rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
+ return matchSuccess();
+ }
+
+ auto vectorType = resultType.dyn_cast<VectorType>();
+ if (!vectorType)
+ return this->matchFailure();
+
+ auto vectorTypeInfo =
+ extractNDVectorTypeInfo(vectorType, this->typeConverter);
+ auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+ if (!llvmVectorTy || operandType != vectorTypeInfo.llvmArrayTy)
+ return this->matchFailure();
+
+ Value desc = rewriter.create<LLVM::UndefOp>(loc, operandType);
+ nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+ // For this unrolled `position`, extract the operand vector, compute
+ // `1 / sqrt` on it, and insert the result into `desc` at that position.
+ auto extractedOperand = rewriter.create<LLVM::ExtractValueOp>(
+ loc, llvmVectorTy, operands[0], position);
+ auto splatAttr = SplatElementsAttr::get(
+ mlir::VectorType::get(
+ {llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
+ floatType),
+ floatOne);
+ auto one =
+ rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
+ auto sqrt =
+ rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, extractedOperand);
+ auto div = rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
+ desc = rewriter.create<LLVM::InsertValueOp>(loc, operandType, desc, div,
+ position);
+ });
+ rewriter.replaceOp(op, desc);
+
+ return matchSuccess();
+ }
+};
+
// A `tanh` is converted into a call to the `tanh` function.
struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern;
@@ -2806,6 +2876,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
PrefetchOpLowering,
RemFOpLowering,
ReturnOpLowering,
+ RsqrtOpLowering,
SIToFPLowering,
SelectOpLowering,
ShiftLeftOpLowering,
diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
index fd574d7e67d9..6aec95931484 100644
--- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -18,6 +18,56 @@ func @strided_memref(%ind: index) {
// -----
+// CHECK-LABEL: func @rsqrt(
+// CHECK-SAME: !llvm.float
+func @rsqrt(%arg0 : f32) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+ // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm.float) -> !llvm.float
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.float
+ %0 = rsqrt %arg0 : f32
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_double(
+// CHECK-SAME: !llvm.double
+func @rsqrt_double(%arg0 : f64) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
+ // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm.double) -> !llvm.double
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm.double
+ %0 = rsqrt %arg0 : f64
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_vector(
+// CHECK-SAME: !llvm<"<4 x float>">
+func @rsqrt_vector(%arg0 : vector<4xf32>) {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : !llvm<"<4 x float>">
+ // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%arg0) : (!llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm<"<4 x float>">
+ %0 = rsqrt %arg0 : vector<4xf32>
+ std.return
+}
+
+// -----
+
+// CHECK-LABEL: func @rsqrt_multidim_vector(
+// CHECK-SAME: !llvm<"[4 x <3 x float>]">
+func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %arg0[0] : !llvm<"[4 x <3 x float>]">
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : !llvm<"<3 x float>">
+ // CHECK: %[[SQRT:.*]] = "llvm.intr.sqrt"(%[[EXTRACT]]) : (!llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : !llvm<"<3 x float>">
+ // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %0[0] : !llvm<"[4 x <3 x float>]">
+ %0 = rsqrt %arg0 : vector<4x3xf32>
+ std.return
+}
+
+// -----
+
// This should not crash. The first operation cannot be converted, so the
// second should not match. This attempts to convert `return` to `llvm.return`
// and complains about non-LLVM types.
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
index 28ef33edba06..e2f6f57ba0a1 100644
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -512,6 +512,9 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
// CHECK: = fptrunc {{.*}} : vector<4xf32> to vector<4xf16>
%144 = fptrunc %vcf32 : vector<4xf32> to vector<4xf16>
+ // CHECK: %{{[0-9]+}} = rsqrt %arg1 : f32
+ %145 = rsqrt %f : f32
+
return
}
More information about the Mlir-commits
mailing list