[Mlir-commits] [mlir] 04ed07b - [mlir] StandardToLLVM: clean up conversion patterns for vector operations

Alex Zinenko llvmlistbot at llvm.org
Thu Mar 26 10:24:39 PDT 2020


Author: Alex Zinenko
Date: 2020-03-26T18:24:10+01:00
New Revision: 04ed07bc174149d61c8a4ed131f0838578bdcaa5

URL: https://github.com/llvm/llvm-project/commit/04ed07bc174149d61c8a4ed131f0838578bdcaa5
DIFF: https://github.com/llvm/llvm-project/commit/04ed07bc174149d61c8a4ed131f0838578bdcaa5.diff

LOG: [mlir] StandardToLLVM: clean up conversion patterns for vector operations

Summary:
Provide a public VectorConvertToLLVMPattern utility class to implement
conversions with automatic unrolling of operations on multidimensional vectors
to lists of operations on single-dimensional vectors when lowering to the LLVM
dialect. Drop the template-based check on the number of operands since the
actual implementation no longer depends on the number of operands. This check
only created spurious concepts (UnaryOpLLVMOpLowering, BinaryOpLLVMOpLowering, etc).

Differential Revision: https://reviews.llvm.org/D76865

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index 95da9805606b..d2c7d9fb2abd 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -416,6 +416,11 @@ LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
                               ValueRange operands,
                               LLVMTypeConverter &typeConverter,
                               ConversionPatternRewriter &rewriter);
+
+LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
+                                    ValueRange operands,
+                                    LLVMTypeConverter &typeConverter,
+                                    ConversionPatternRewriter &rewriter);
 } // namespace detail
 } // namespace LLVM
 
@@ -441,6 +446,29 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   }
 };
 
+/// Basic lowering implementation for rewriting from Ops to LLVM Dialect Ops
+/// with one result. This supports higher-dimensional vector types.
+template <typename SourceOp, typename TargetOp>
+class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
+public:
+  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
+  using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    static_assert(
+        std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
+        "expected single result op");
+    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+                                  SourceOp>::value,
+                  "expected same operands and result type");
+    return LLVM::detail::vectorOneToOneRewrite(op, TargetOp::getOperationName(),
+                                               operands, this->typeConverter,
+                                               rewriter);
+  }
+};
+
 /// Derived class that automatically populates legalization information for
 /// different LLVM ops.
 class LLVMConversionTarget : public ConversionTarget {

diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 474a4f08b9f6..8bc27ab3340e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1148,9 +1148,10 @@ template <typename SourceOp, unsigned OpCount>
 void ValidateOpCount() {
   OpCountValidator<SourceOp, OpCount>();
 }
+} // namespace
 
-static LogicalResult HandleMultidimensionalVectors(
-    Operation *op, ArrayRef<Value> operands, LLVMTypeConverter &typeConverter,
+static LogicalResult handleMultidimensionalVectors(
+    Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
     std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter) {
   auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
@@ -1179,139 +1180,125 @@ static LogicalResult HandleMultidimensionalVectors(
   return success();
 }
 
-// Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
-// Ops for N-ary ops with one result. This supports higher-dimensional vector
-// types.
-template <typename SourceOp, typename TargetOp, unsigned OpCount>
-struct NaryOpLLVMOpLowering : public ConvertOpToLLVMPattern<SourceOp> {
-  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
-  using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
-
-  // Convert the type of the result to an LLVM type, pass operands as is,
-  // preserve attributes.
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    ValidateOpCount<SourceOp, OpCount>();
-    static_assert(
-        std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
-        "expected single result op");
-    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
-                                  SourceOp>::value,
-                  "expected same operands and result type");
-
-    // Cannot convert ops if their operands are not of LLVM type.
-    for (Value operand : operands) {
-      if (!operand || !operand.getType().isa<LLVM::LLVMType>())
-        return failure();
-    }
+LogicalResult LLVM::detail::vectorOneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
+  assert(!operands.empty());
 
-    auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
+  // Cannot convert ops if their operands are not of LLVM type.
+  if (!llvm::all_of(operands.getTypes(),
+                    [](Type t) { return t.isa<LLVM::LLVMType>(); }))
+    return failure();
 
-    if (!llvmArrayTy.isArrayTy()) {
-      auto newOp = rewriter.create<TargetOp>(
-          op->getLoc(), operands[0].getType(), operands, op->getAttrs());
-      rewriter.replaceOp(op, newOp.getResult());
-      return success();
-    }
+  auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
+  if (!llvmArrayTy.isArrayTy())
+    return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
 
-    if (succeeded(HandleMultidimensionalVectors(
-            op, operands, this->typeConverter,
-            [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
-              return rewriter.create<TargetOp>(op->getLoc(), llvmVectorTy,
-                                               operands, op->getAttrs());
-            },
-            rewriter)))
-      return success();
-    return failure();
-  }
-};
+  auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
+                                            ValueRange operands) {
+    OperationState state(op->getLoc(), targetOp);
+    state.addTypes(llvmVectorTy);
+    state.addOperands(operands);
+    state.addAttributes(op->getAttrs());
+    return rewriter.createOperation(state)->getResult(0);
+  };
 
-template <typename SourceOp, typename TargetOp>
-using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
-template <typename SourceOp, typename TargetOp>
-using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;
+  return handleMultidimensionalVectors(op, operands, typeConverter, callback,
+                                       rewriter);
+}
 
+namespace {
 // Specific lowerings.
 // FIXME: this should be tablegen'ed.
-struct AbsFOpLowering : public UnaryOpLLVMOpLowering<AbsFOp, LLVM::FAbsOp> {
+struct AbsFOpLowering
+    : public VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp> {
   using Super::Super;
 };
-struct CeilFOpLowering : public UnaryOpLLVMOpLowering<CeilFOp, LLVM::FCeilOp> {
+struct CeilFOpLowering
+    : public VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp> {
   using Super::Super;
 };
-struct CosOpLowering : public UnaryOpLLVMOpLowering<CosOp, LLVM::CosOp> {
+struct CosOpLowering : public VectorConvertToLLVMPattern<CosOp, LLVM::CosOp> {
   using Super::Super;
 };
-struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::ExpOp> {
+struct ExpOpLowering : public VectorConvertToLLVMPattern<ExpOp, LLVM::ExpOp> {
   using Super::Super;
 };
-struct LogOpLowering : public UnaryOpLLVMOpLowering<LogOp, LLVM::LogOp> {
+struct LogOpLowering : public VectorConvertToLLVMPattern<LogOp, LLVM::LogOp> {
   using Super::Super;
 };
-struct Log10OpLowering : public UnaryOpLLVMOpLowering<Log10Op, LLVM::Log10Op> {
+struct Log10OpLowering
+    : public VectorConvertToLLVMPattern<Log10Op, LLVM::Log10Op> {
   using Super::Super;
 };
-struct Log2OpLowering : public UnaryOpLLVMOpLowering<Log2Op, LLVM::Log2Op> {
+struct Log2OpLowering
+    : public VectorConvertToLLVMPattern<Log2Op, LLVM::Log2Op> {
   using Super::Super;
 };
-struct NegFOpLowering : public UnaryOpLLVMOpLowering<NegFOp, LLVM::FNegOp> {
+struct NegFOpLowering
+    : public VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp> {
   using Super::Super;
 };
-struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
+struct AddIOpLowering : public VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp> {
   using Super::Super;
 };
-struct SubIOpLowering : public BinaryOpLLVMOpLowering<SubIOp, LLVM::SubOp> {
+struct SubIOpLowering : public VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp> {
   using Super::Super;
 };
-struct MulIOpLowering : public BinaryOpLLVMOpLowering<MulIOp, LLVM::MulOp> {
+struct MulIOpLowering : public VectorConvertToLLVMPattern<MulIOp, LLVM::MulOp> {
   using Super::Super;
 };
 struct SignedDivIOpLowering
-    : public BinaryOpLLVMOpLowering<SignedDivIOp, LLVM::SDivOp> {
+    : public VectorConvertToLLVMPattern<SignedDivIOp, LLVM::SDivOp> {
   using Super::Super;
 };
-struct SqrtOpLowering : public UnaryOpLLVMOpLowering<SqrtOp, LLVM::SqrtOp> {
+struct SqrtOpLowering
+    : public VectorConvertToLLVMPattern<SqrtOp, LLVM::SqrtOp> {
   using Super::Super;
 };
 struct UnsignedDivIOpLowering
-    : public BinaryOpLLVMOpLowering<UnsignedDivIOp, LLVM::UDivOp> {
+    : public VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp> {
   using Super::Super;
 };
 struct SignedRemIOpLowering
-    : public BinaryOpLLVMOpLowering<SignedRemIOp, LLVM::SRemOp> {
+    : public VectorConvertToLLVMPattern<SignedRemIOp, LLVM::SRemOp> {
   using Super::Super;
 };
 struct UnsignedRemIOpLowering
-    : public BinaryOpLLVMOpLowering<UnsignedRemIOp, LLVM::URemOp> {
+    : public VectorConvertToLLVMPattern<UnsignedRemIOp, LLVM::URemOp> {
   using Super::Super;
 };
-struct AndOpLowering : public BinaryOpLLVMOpLowering<AndOp, LLVM::AndOp> {
+struct AndOpLowering : public VectorConvertToLLVMPattern<AndOp, LLVM::AndOp> {
   using Super::Super;
 };
-struct OrOpLowering : public BinaryOpLLVMOpLowering<OrOp, LLVM::OrOp> {
+struct OrOpLowering : public VectorConvertToLLVMPattern<OrOp, LLVM::OrOp> {
   using Super::Super;
 };
-struct XOrOpLowering : public BinaryOpLLVMOpLowering<XOrOp, LLVM::XOrOp> {
+struct XOrOpLowering : public VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp> {
   using Super::Super;
 };
-struct AddFOpLowering : public BinaryOpLLVMOpLowering<AddFOp, LLVM::FAddOp> {
+struct AddFOpLowering
+    : public VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp> {
   using Super::Super;
 };
-struct SubFOpLowering : public BinaryOpLLVMOpLowering<SubFOp, LLVM::FSubOp> {
+struct SubFOpLowering
+    : public VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp> {
   using Super::Super;
 };
-struct MulFOpLowering : public BinaryOpLLVMOpLowering<MulFOp, LLVM::FMulOp> {
+struct MulFOpLowering
+    : public VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp> {
   using Super::Super;
 };
-struct DivFOpLowering : public BinaryOpLLVMOpLowering<DivFOp, LLVM::FDivOp> {
+struct DivFOpLowering
+    : public VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp> {
   using Super::Super;
 };
-struct RemFOpLowering : public BinaryOpLLVMOpLowering<RemFOp, LLVM::FRemOp> {
+struct RemFOpLowering
+    : public VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp> {
   using Super::Super;
 };
 struct CopySignOpLowering
-    : public BinaryOpLLVMOpLowering<CopySignOp, LLVM::CopySignOp> {
+    : public VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp> {
   using Super::Super;
 };
 struct SelectOpLowering
@@ -1695,24 +1682,21 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
     if (!vectorType)
       return failure();
 
-    if (succeeded(HandleMultidimensionalVectors(
-            op, operands, typeConverter,
-            [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
-              auto splatAttr = SplatElementsAttr::get(
-                  mlir::VectorType::get({llvmVectorTy.getUnderlyingType()
-                                             ->getVectorNumElements()},
-                                        floatType),
-                  floatOne);
-              auto one = rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy,
-                                                           splatAttr);
-              auto sqrt =
-                  rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
-              return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one,
-                                                   sqrt);
-            },
-            rewriter)))
-      return success();
-    return failure();
+    return handleMultidimensionalVectors(
+        op, operands, typeConverter,
+        [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
+          auto splatAttr = SplatElementsAttr::get(
+              mlir::VectorType::get(
+                  {llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
+                  floatType),
+              floatOne);
+          auto one =
+              rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
+          auto sqrt =
+              rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
+          return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
+        },
+        rewriter);
   }
 };
 


        


More information about the Mlir-commits mailing list