[Mlir-commits] [mlir] 921d91f - [mlir] Support multi-dimensional vectors in MathToLibm conversion.

Adrian Kuegel llvmlistbot at llvm.org
Tue Nov 16 02:14:04 PST 2021


Author: Adrian Kuegel
Date: 2021-11-16T11:13:52+01:00
New Revision: 921d91f3aca33ab28bdc145c668ccc0726a0ca30

URL: https://github.com/llvm/llvm-project/commit/921d91f3aca33ab28bdc145c668ccc0726a0ca30
DIFF: https://github.com/llvm/llvm-project/commit/921d91f3aca33ab28bdc145c668ccc0726a0ca30.diff

LOG: [mlir] Support multi-dimensional vectors in MathToLibm conversion.

Differential Revision: https://reviews.llvm.org/D113969

Added:
    (none)

Modified: 
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/test/Conversion/MathToLibm/convert-to-libm.mlir

Removed:
    (none)


################################################################################
diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 9447bed434efc..8026c6b5b41e6 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/PatternMatch.h"
 
@@ -58,21 +59,23 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   if (!vecType.hasRank())
     return failure();
   auto shape = vecType.getShape();
-  // TODO: support multidimensional vectors
-  if (shape.size() != 1)
-    return failure();
+  int64_t numElements = vecType.getNumElements();
 
   Value result = rewriter.create<arith::ConstantOp>(
       loc, DenseElementsAttr::get(
                vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
-  for (auto i = 0; i < shape.front(); ++i) {
+  SmallVector<int64_t> ones(shape.size(), 1);
+  SmallVector<int64_t> strides = computeStrides(shape, ones);
+  for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
+    SmallVector<int64_t> positions = delinearize(strides, linearIndex);
     SmallVector<Value> operands;
     for (auto input : op->getOperands())
       operands.push_back(
-          rewriter.create<vector::ExtractElementOp>(loc, input, i));
+          rewriter.create<vector::ExtractOp>(loc, input, positions));
     Value scalarOp =
         rewriter.create<Op>(loc, vecType.getElementType(), operands);
-    result = rewriter.create<vector::InsertElementOp>(loc, scalarOp, result, i);
+    result =
+        rewriter.create<vector::InsertOp>(loc, scalarOp, result, positions);
   }
   rewriter.replaceOp(op, {result});
   return success();

diff  --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index 51c5d4a235322..003155b0fe44a 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -68,20 +68,39 @@ func @expm1_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector
 // CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
 // CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
 // CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : i32
-// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : i32
-// CHECK:           %[[IN0_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C0]] : i32] : vector<2xf32>
+// CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32>
 // CHECK:           %[[OUT0_F32:.*]] = call @expm1f(%[[IN0_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_8:.*]] = vector.insertelement %[[OUT0_F32]], %[[CVF]]{{\[}}%[[C0]] : i32] : vector<2xf32>
-// CHECK:           %[[IN1_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C1]] : i32] : vector<2xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
+// CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32>
 // CHECK:           %[[OUT1_F32:.*]] = call @expm1f(%[[IN1_F32]]) : (f32) -> f32
-// CHECK:           %[[VAL_11:.*]] = vector.insertelement %[[OUT1_F32]], %[[VAL_8]]{{\[}}%[[C1]] : i32] : vector<2xf32>
-// CHECK:           %[[IN0_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C0]] : i32] : vector<2xf64>
+// CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+// CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64>
 // CHECK:           %[[OUT0_F64:.*]] = call @expm1(%[[IN0_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_14:.*]] = vector.insertelement %[[OUT0_F64]], %[[CVD]]{{\[}}%[[C0]] : i32] : vector<2xf64>
-// CHECK:           %[[IN1_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C1]] : i32] : vector<2xf64>
+// CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
+// CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64>
 // CHECK:           %[[OUT1_F64:.*]] = call @expm1(%[[IN1_F64]]) : (f64) -> f64
-// CHECK:           %[[VAL_17:.*]] = vector.insertelement %[[OUT1_F64]], %[[VAL_14]]{{\[}}%[[C1]] : i32] : vector<2xf64>
+// CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
 // CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
 // CHECK:         }
 
+func @expm1_multidim_vec_caller(%float: vector<2x2xf32>) -> (vector<2x2xf32>) {
+  %result = math.expm1 %float : vector<2x2xf32>
+  return %result : vector<2x2xf32>
+}
+// CHECK-LABEL:   func @expm1_multidim_vec_caller(
+// CHECK-SAME:                           %[[VAL:.*]]: vector<2x2xf32>
+// CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK:           %[[IN0_0_F32:.*]] = vector.extract %[[VAL]][0, 0] : vector<2x2xf32>
+// CHECK:           %[[OUT0_0_F32:.*]] = call @expm1f(%[[IN0_0_F32]]) : (f32) -> f32
+// CHECK:           %[[VAL_1:.*]] = vector.insert %[[OUT0_0_F32]], %[[CVF]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK:           %[[IN0_1_F32:.*]] = vector.extract %[[VAL]][0, 1] : vector<2x2xf32>
+// CHECK:           %[[OUT0_1_F32:.*]] = call @expm1f(%[[IN0_1_F32]]) : (f32) -> f32
+// CHECK:           %[[VAL_2:.*]] = vector.insert %[[OUT0_1_F32]], %[[VAL_1]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK:           %[[IN1_0_F32:.*]] = vector.extract %[[VAL]][1, 0] : vector<2x2xf32>
+// CHECK:           %[[OUT1_0_F32:.*]] = call @expm1f(%[[IN1_0_F32]]) : (f32) -> f32
+// CHECK:           %[[VAL_3:.*]] = vector.insert %[[OUT1_0_F32]], %[[VAL_2]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK:           %[[IN1_1_F32:.*]] = vector.extract %[[VAL]][1, 1] : vector<2x2xf32>
+// CHECK:           %[[OUT1_1_F32:.*]] = call @expm1f(%[[IN1_1_F32]]) : (f32) -> f32
+// CHECK:           %[[VAL_4:.*]] = vector.insert %[[OUT1_1_F32]], %[[VAL_3]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK:           return %[[VAL_4]] : vector<2x2xf32>
+// CHECK:         }


        


More information about the Mlir-commits mailing list