[Mlir-commits] [mlir] 9b5a3d1 - [mlir][vector] Add helper that builds a scalar reduction according to CombiningKind

Matthias Springer llvmlistbot at llvm.org
Thu Feb 10 05:35:52 PST 2022


Author: Matthias Springer
Date: 2022-02-10T22:35:43+09:00
New Revision: 9b5a3d14b2c3bed697f3f5c873bef82bead27818

URL: https://github.com/llvm/llvm-project/commit/9b5a3d14b2c3bed697f3f5c873bef82bead27818
DIFF: https://github.com/llvm/llvm-project/commit/9b5a3d14b2c3bed697f3f5c873bef82bead27818.diff

LOG: [mlir][vector] Add helper that builds a scalar reduction according to CombiningKind

Differential Revision: https://reviews.llvm.org/D119433

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
    mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index f06ef2bc38ad..f6b84f1e28cd 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Support/LLVM.h"
 
@@ -30,12 +31,14 @@ class VectorType;
 class VectorTransferOpInterface;
 
 namespace vector {
-class TransferWriteOp;
-class TransferReadOp;
-
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
+
+/// Return the result of combining `v1` and `v2` (scalars or vectors) with the
+/// arith op for `kind`. Built with createOrFold, so the op may fold away.
+Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
+                         Value v1, Value v2);
 } // namespace vector
 
 /// Return the number of elements of basis, `0` if empty.

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index 1d37a3345827..db5c667a4935 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -243,47 +243,8 @@ struct TwoDimMultiReductionToElementWise
     for (int64_t i = 1; i < srcShape[0]; i++) {
       auto operand =
           rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
-      switch (multiReductionOp.kind()) {
-      case vector::CombiningKind::ADD:
-        if (elementType.isIntOrIndex())
-          result = rewriter.create<arith::AddIOp>(loc, operand, result);
-        else
-          result = rewriter.create<arith::AddFOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MUL:
-        if (elementType.isIntOrIndex())
-          result = rewriter.create<arith::MulIOp>(loc, operand, result);
-        else
-          result = rewriter.create<arith::MulFOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MINUI:
-        result = rewriter.create<arith::MinUIOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MINSI:
-        result = rewriter.create<arith::MinSIOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MINF:
-        result = rewriter.create<arith::MinFOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MAXUI:
-        result = rewriter.create<arith::MaxUIOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MAXSI:
-        result = rewriter.create<arith::MaxSIOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::MAXF:
-        result = rewriter.create<arith::MaxFOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::AND:
-        result = rewriter.create<arith::AndIOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::OR:
-        result = rewriter.create<arith::OrIOp>(loc, operand, result);
-        break;
-      case vector::CombiningKind::XOR:
-        result = rewriter.create<arith::XOrIOp>(loc, operand, result);
-        break;
-      }
+      result = makeArithReduction(rewriter, loc, multiReductionOp.kind(),
+                                  operand, result);
     }
 
     rewriter.replaceOp(multiReductionOp, result);

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 155beb8d91fd..226faccaba96 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -10,6 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+
 #include <type_traits>
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -18,8 +20,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
-
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -514,40 +515,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
     if (!acc)
       return Optional<Value>(mul);
 
-    Value combinedResult;
-    switch (kind) {
-    case CombiningKind::ADD:
-      combinedResult = rewriter.create<arith::AddIOp>(loc, mul, acc);
-      break;
-    case CombiningKind::MUL:
-      combinedResult = rewriter.create<arith::MulIOp>(loc, mul, acc);
-      break;
-    case CombiningKind::MINUI:
-      combinedResult = rewriter.create<arith::MinUIOp>(loc, mul, acc);
-      break;
-    case CombiningKind::MINSI:
-      combinedResult = rewriter.create<arith::MinSIOp>(loc, mul, acc);
-      break;
-    case CombiningKind::MAXUI:
-      combinedResult = rewriter.create<arith::MaxUIOp>(loc, mul, acc);
-      break;
-    case CombiningKind::MAXSI:
-      combinedResult = rewriter.create<arith::MaxSIOp>(loc, mul, acc);
-      break;
-    case CombiningKind::AND:
-      combinedResult = rewriter.create<arith::AndIOp>(loc, mul, acc);
-      break;
-    case CombiningKind::OR:
-      combinedResult = rewriter.create<arith::OrIOp>(loc, mul, acc);
-      break;
-    case CombiningKind::XOR:
-      combinedResult = rewriter.create<arith::XOrIOp>(loc, mul, acc);
-      break;
-    case CombiningKind::MINF: // Only valid for floating point types.
-    case CombiningKind::MAXF: // Only valid for floating point types.
+    // MINF/MAXF are only valid for floating point types.
+    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
       return Optional<Value>();
-    }
-    return Optional<Value>(combinedResult);
+
+    return makeArithReduction(rewriter, loc, kind, mul, acc);
   }
 
   static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
@@ -565,28 +537,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
     if (!acc)
       return Optional<Value>(mul);
 
-    Value combinedResult;
-    switch (kind) {
-    case CombiningKind::MUL:
-      combinedResult = rewriter.create<arith::MulFOp>(loc, mul, acc);
-      break;
-    case CombiningKind::MINF:
-      combinedResult = rewriter.create<arith::MinFOp>(loc, mul, acc);
-      break;
-    case CombiningKind::MAXF:
-      combinedResult = rewriter.create<arith::MaxFOp>(loc, mul, acc);
-      break;
-    case CombiningKind::ADD:   // Already handled this special case above.
-    case CombiningKind::AND:   // Only valid for integer types.
-    case CombiningKind::MINUI: // Only valid for integer types.
-    case CombiningKind::MINSI: // Only valid for integer types.
-    case CombiningKind::MAXUI: // Only valid for integer types.
-    case CombiningKind::MAXSI: // Only valid for integer types.
-    case CombiningKind::OR:    // Only valid for integer types.
-    case CombiningKind::XOR:   // Only valid for integer types.
+    // Only MUL, MINF and MAXF are combined here: ADD was handled above and the
+    // remaining kinds are integer-only. Listing the accepted kinds (rather than
+    // the rejected ones) keeps any future CombiningKind out of this path.
+    if (kind != CombiningKind::MUL && kind != CombiningKind::MINF &&
+        kind != CombiningKind::MAXF)
       return Optional<Value>();
-    }
-    return Optional<Value>(combinedResult);
+
+    return makeArithReduction(rewriter, loc, kind, mul, acc);
   }
 };
 

diff  --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 800e19ceda71..5388b7bd6ca1 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/MathExtras.h"
 #include <numeric>
@@ -42,6 +43,56 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
   llvm_unreachable("Expected MemRefType or TensorType");
 }
 
+Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
+                                       CombiningKind kind, Value v1, Value v2) {
+  Type t1 = getElementTypeOrSelf(v1.getType()); // Looks through vectors.
+  Type t2 = getElementTypeOrSelf(v2.getType());
+  switch (kind) { // No default: -Wswitch flags unhandled kinds.
+  case CombiningKind::ADD: // Integer or float operands.
+    if (t1.isIntOrIndex() && t2.isIntOrIndex())
+      return b.createOrFold<arith::AddIOp>(loc, v1, v2);
+    if (t1.isa<FloatType>() && t2.isa<FloatType>())
+      return b.createOrFold<arith::AddFOp>(loc, v1, v2);
+    llvm_unreachable("invalid value types for ADD reduction");
+  case CombiningKind::AND:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::AndIOp>(loc, v1, v2);
+  case CombiningKind::MAXF:
+    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+           "expected float values");
+    return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
+  case CombiningKind::MINF:
+    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+           "expected float values");
+    return b.createOrFold<arith::MinFOp>(loc, v1, v2);
+  case CombiningKind::MAXSI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
+  case CombiningKind::MINSI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
+  case CombiningKind::MAXUI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
+  case CombiningKind::MINUI:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
+  case CombiningKind::MUL: // Integer or float operands.
+    if (t1.isIntOrIndex() && t2.isIntOrIndex())
+      return b.createOrFold<arith::MulIOp>(loc, v1, v2);
+    if (t1.isa<FloatType>() && t2.isa<FloatType>())
+      return b.createOrFold<arith::MulFOp>(loc, v1, v2);
+    llvm_unreachable("invalid value types for MUL reduction");
+  case CombiningKind::OR:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::OrIOp>(loc, v1, v2);
+  case CombiningKind::XOR:
+    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
+    return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
+  }
+  llvm_unreachable("unknown CombiningKind");
+}
+
 /// Return the number of elements of basis, `0` if empty.
 int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
   if (basis.empty())


        


More information about the Mlir-commits mailing list