[Mlir-commits] [mlir] 86306df - Extract common code to deal with multidimensional vectors.

Adrian Kuegel llvmlistbot at llvm.org
Fri Mar 6 04:55:30 PST 2020


Author: Adrian Kuegel
Date: 2020-03-06T13:54:54+01:00
New Revision: 86306df7dd2a8e60d88c6306956080b53ac95589

URL: https://github.com/llvm/llvm-project/commit/86306df7dd2a8e60d88c6306956080b53ac95589
DIFF: https://github.com/llvm/llvm-project/commit/86306df7dd2a8e60d88c6306956080b53ac95589.diff

LOG: Extract common code to deal with multidimensional vectors.

Summary: Also replace dyn_cast_or_null with dyn_cast when possible.

Differential Revision: https://reviews.llvm.org/D75733

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index e38322943cfe..c7c479bf6779 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/Functional.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/Utils.h"
@@ -32,6 +33,7 @@
 #include "llvm/IR/Type.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
+#include <functional>
 
 using namespace mlir;
 
@@ -1165,6 +1167,36 @@ void ValidateOpCount() {
   OpCountValidator<SourceOp, OpCount>();
 }
 
+static LogicalResult HandleMultidimensionalVectors(
+    Operation *op, ArrayRef<Value> operands, LLVMTypeConverter &typeConverter,
+    std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
+    ConversionPatternRewriter &rewriter) {
+  auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
+  if (!vectorType)
+    return failure();
+  auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
+  auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
+  auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
+  if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
+    return failure();
+
+  auto loc = op->getLoc();
+  Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
+  nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
+    // For this unrolled `position`, extract the `llvmVectorTy` value found at
+    // that position in each operand.
+    SmallVector<Value, 4> extractedOperands;
+    for (auto operand : operands)
+      extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmVectorTy, operand, position));
+    Value newVal = createOperand(llvmVectorTy, extractedOperands);
+    desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal,
+                                                position);
+  });
+  rewriter.replaceOp(op, desc);
+  return success();
+}
+
 // Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
 // Ops for N-ary ops with one result. This supports higher-dimensional vector
 // types.
@@ -1192,7 +1224,6 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
         return this->matchFailure();
     }
 
-    auto loc = op->getLoc();
     auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
 
     if (!llvmArrayTy.isArrayTy()) {
@@ -1202,31 +1233,15 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
       return this->matchSuccess();
     }
 
-    auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
-    if (!vectorType)
-      return this->matchFailure();
-    auto vectorTypeInfo =
-        extractNDVectorTypeInfo(vectorType, this->typeConverter);
-    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
-    if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
-      return this->matchFailure();
-
-    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
-    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
-      // For this unrolled `position` corresponding to the `linearIndex`^th
-      // element, extract operand vectors
-      SmallVector<Value, OpCount> extractedOperands;
-      for (unsigned i = 0; i < OpCount; ++i) {
-        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
-            loc, llvmVectorTy, operands[i], position));
-      }
-      Value newVal = rewriter.create<TargetOp>(
-          loc, llvmVectorTy, extractedOperands, op->getAttrs());
-      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
-                                                  newVal, position);
-    });
-    rewriter.replaceOp(op, desc);
-    return this->matchSuccess();
+    if (succeeded(HandleMultidimensionalVectors(
+            op, operands, this->typeConverter,
+            [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
+              return rewriter.create<TargetOp>(op->getLoc(), llvmVectorTy,
+                                               operands, op->getAttrs());
+            },
+            rewriter)))
+      return this->matchSuccess();
+    return this->matchFailure();
   }
 };
 
@@ -1673,7 +1688,7 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
                   ConversionPatternRewriter &rewriter) const override {
     OperandAdaptor<RsqrtOp> transformed(operands);
     auto operandType =
-        transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+        transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
 
     if (!operandType)
       return matchFailure();
@@ -1694,41 +1709,31 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern<RsqrtOp> {
       }
       auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
       rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
-      return matchSuccess();
+      return this->matchSuccess();
     }
 
     auto vectorType = resultType.dyn_cast<VectorType>();
     if (!vectorType)
       return this->matchFailure();
 
-    auto vectorTypeInfo =
-        extractNDVectorTypeInfo(vectorType, this->typeConverter);
-    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
-    if (!llvmVectorTy || operandType != vectorTypeInfo.llvmArrayTy)
-      return this->matchFailure();
-
-    Value desc = rewriter.create<LLVM::UndefOp>(loc, operandType);
-    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
-      // For this unrolled `position` corresponding to the `linearIndex`^th
-      // element, extract operand vectors
-      auto extractedOperand = rewriter.create<LLVM::ExtractValueOp>(
-          loc, llvmVectorTy, operands[0], position);
-      auto splatAttr = SplatElementsAttr::get(
-          mlir::VectorType::get(
-              {llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
-              floatType),
-          floatOne);
-      auto one =
-          rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
-      auto sqrt =
-          rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, extractedOperand);
-      auto div = rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
-      desc = rewriter.create<LLVM::InsertValueOp>(loc, operandType, desc, div,
-                                                  position);
-    });
-    rewriter.replaceOp(op, desc);
-
-    return matchSuccess();
+    if (succeeded(HandleMultidimensionalVectors(
+            op, operands, typeConverter,
+            [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
+              auto splatAttr = SplatElementsAttr::get(
+                  mlir::VectorType::get({llvmVectorTy.getUnderlyingType()
+                                             ->getVectorNumElements()},
+                                        floatType),
+                  floatOne);
+              auto one = rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy,
+                                                           splatAttr);
+              auto sqrt =
+                  rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
+              return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one,
+                                                   sqrt);
+            },
+            rewriter)))
+      return this->matchSuccess();
+    return this->matchFailure();
   }
 };
 
@@ -1745,7 +1750,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
 
     OperandAdaptor<TanhOp> transformed(operands);
     LLVMTypeT operandType =
-        transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
+        transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
 
     if (!operandType)
       return matchFailure();


        


More information about the Mlir-commits mailing list