[Mlir-commits] [mlir] 921d91f - [mlir] Support multi-dimensional vectors in MathToLibm conversion.
Adrian Kuegel
llvmlistbot at llvm.org
Tue Nov 16 02:14:04 PST 2021
Author: Adrian Kuegel
Date: 2021-11-16T11:13:52+01:00
New Revision: 921d91f3aca33ab28bdc145c668ccc0726a0ca30
URL: https://github.com/llvm/llvm-project/commit/921d91f3aca33ab28bdc145c668ccc0726a0ca30
DIFF: https://github.com/llvm/llvm-project/commit/921d91f3aca33ab28bdc145c668ccc0726a0ca30.diff
LOG: [mlir] Support multi-dimensional vectors in MathToLibm conversion.
Differential Revision: https://reviews.llvm.org/D113969
Added:
Modified:
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 9447bed434efc..8026c6b5b41e6 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/PatternMatch.h"
@@ -58,21 +59,23 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
if (!vecType.hasRank())
return failure();
auto shape = vecType.getShape();
- // TODO: support multidimensional vectors
- if (shape.size() != 1)
- return failure();
+ int64_t numElements = vecType.getNumElements();
Value result = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(
vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
- for (auto i = 0; i < shape.front(); ++i) {
+ SmallVector<int64_t> ones(shape.size(), 1);
+ SmallVector<int64_t> strides = computeStrides(shape, ones);
+ for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+ SmallVector<int64_t> positions = delinearize(strides, linearIndex);
SmallVector<Value> operands;
for (auto input : op->getOperands())
operands.push_back(
- rewriter.create<vector::ExtractElementOp>(loc, input, i));
+ rewriter.create<vector::ExtractOp>(loc, input, positions));
Value scalarOp =
rewriter.create<Op>(loc, vecType.getElementType(), operands);
- result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i);
+ result =
+ rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
}
rewriter.replaceOp(op, {result});
return success();
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index 51c5d4a235322..003155b0fe44a 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -68,20 +68,39 @@ func @expm1_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector
// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
// CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
-// CHECK: %[[IN0_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C0]] : i32] : vector<2xf32>
+// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32>
// CHECK: %[[OUT0_F32:.*]] = call @expm1f(%[[IN0_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_8:.*]] = vector.insertelement %[[OUT0_F32]], %[[CVF]]{{\[}}%[[C0]] : i32] : vector<2xf32>
-// CHECK: %[[IN1_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C1]] : i32] : vector<2xf32>
+// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
+// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32>
// CHECK: %[[OUT1_F32:.*]] = call @expm1f(%[[IN1_F32]]) : (f32) -> f32
-// CHECK: %[[VAL_11:.*]] = vector.insertelement %[[OUT1_F32]], %[[VAL_8]]{{\[}}%[[C1]] : i32] : vector<2xf32>
-// CHECK: %[[IN0_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C0]] : i32] : vector<2xf64>
+// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64>
// CHECK: %[[OUT0_F64:.*]] = call @expm1(%[[IN0_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_14:.*]] = vector.insertelement %[[OUT0_F64]], %[[CVD]]{{\[}}%[[C0]] : i32] : vector<2xf64>
-// CHECK: %[[IN1_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C1]] : i32] : vector<2xf64>
+// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
+// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64>
// CHECK: %[[OUT1_F64:.*]] = call @expm1(%[[IN1_F64]]) : (f64) -> f64
-// CHECK: %[[VAL_17:.*]] = vector.insertelement %[[OUT1_F64]], %[[VAL_14]]{{\[}}%[[C1]] : i32] : vector<2xf64>
+// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
// CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
// CHECK: }
+func @expm1_multidim_vec_caller(%float: vector<2x2xf32>) -> (vector<2x2xf32>) {
+ %result = math.expm1 %float : vector<2x2xf32>
+ return %result : vector<2x2xf32>
+}
+// CHECK-LABEL: func @expm1_multidim_vec_caller(
+// CHECK-SAME: %[[VAL:.*]]: vector<2x2xf32>
+// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[IN0_0_F32:.*]] = vector.extract %[[VAL]][0, 0] : vector<2x2xf32>
+// CHECK: %[[OUT0_0_F32:.*]] = call @expm1f(%[[IN0_0_F32]]) : (f32) -> f32
+// CHECK: %[[VAL_1:.*]] = vector.insert %[[OUT0_0_F32]], %[[CVF]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[IN0_1_F32:.*]] = vector.extract %[[VAL]][0, 1] : vector<2x2xf32>
+// CHECK: %[[OUT0_1_F32:.*]] = call @expm1f(%[[IN0_1_F32]]) : (f32) -> f32
+// CHECK: %[[VAL_2:.*]] = vector.insert %[[OUT0_1_F32]], %[[VAL_1]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[IN1_0_F32:.*]] = vector.extract %[[VAL]][1, 0] : vector<2x2xf32>
+// CHECK: %[[OUT1_0_F32:.*]] = call @expm1f(%[[IN1_0_F32]]) : (f32) -> f32
+// CHECK: %[[VAL_3:.*]] = vector.insert %[[OUT1_0_F32]], %[[VAL_2]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[IN1_1_F32:.*]] = vector.extract %[[VAL]][1, 1] : vector<2x2xf32>
+// CHECK: %[[OUT1_1_F32:.*]] = call @expm1f(%[[IN1_1_F32]]) : (f32) -> f32
+// CHECK: %[[VAL_4:.*]] = vector.insert %[[OUT1_1_F32]], %[[VAL_3]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: return %[[VAL_4]] : vector<2x2xf32>
+// CHECK: }
More information about the Mlir-commits mailing list