[flang-commits] [flang] 91cbc3f - [flang] lower matmul intrinsic to hlfir.matmul operation

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Thu Feb 16 07:32:13 PST 2023


Author: Tom Eccles
Date: 2023-02-16T15:30:46Z
New Revision: 91cbc3f2d83dfcf064238e807b47c58279509ff7

URL: https://github.com/llvm/llvm-project/commit/91cbc3f2d83dfcf064238e807b47c58279509ff7
DIFF: https://github.com/llvm/llvm-project/commit/91cbc3f2d83dfcf064238e807b47c58279509ff7.diff

LOG: [flang] lower matmul intrinsic to hlfir.matmul operation

Differential Revision: https://reviews.llvm.org/D144096

Added: 
    flang/test/Lower/HLFIR/matmul.f90

Modified: 
    flang/include/flang/Optimizer/HLFIR/HLFIROps.td
    flang/lib/Lower/ConvertCall.cpp
    flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 9f187ac6afca8..ef0d1d1f0462f 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -280,12 +280,6 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
     $array (`dim` $dim^)? (`mask` $mask^)? attr-dict `:` functional-type(operands, results)
   }];
 
-  // dim and mask can be NULL, array must not be.
-  let builders = [OpBuilder<(ins "mlir::Value":$array,
-                                 "mlir::Value":$dim,
-                                 "mlir::Value":$mask,
-                                 "mlir::Type":$resultType)>];
-
   let hasVerifier = 1;
 }
 
@@ -309,10 +303,6 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
     $lhs $rhs attr-dict `:` functional-type(operands, results)
   }];
 
-  let builders = [OpBuilder<(ins "mlir::Value":$lhs,
-                                 "mlir::Value":$rhs,
-                                 "mlir::Type":$resultType)>];
-
   let hasVerifier = 1;
 }
 

diff  --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 5978441517df4..53bd066965583 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -1249,38 +1249,74 @@ genHLFIRIntrinsicRefCore(PreparedActualArguments &loweredActuals,
   if (!useHlfirIntrinsicOps)
     return genIntrinsicRefCore(loweredActuals, intrinsic, argLowering, callContext);
 
-  auto getOperandVector =
-    [](PreparedActualArguments &loweredActuals) {
-      llvm::SmallVector<mlir::Value> operands;
-      operands.reserve(loweredActuals.size());
-      for (auto arg : llvm::enumerate(loweredActuals)) {
-        if (!arg.value()) {
-          operands.emplace_back();
-          continue;
-        }
-        hlfir::Entity actual = arg.value()->getOriginalActual();
-        operands.emplace_back(actual.getBase());
-      }
-      return operands;
-    };
-
   fir::FirOpBuilder &builder = callContext.getBuilder();
   mlir::Location loc = callContext.loc;
 
+  auto getOperandVector = [&](PreparedActualArguments &loweredActuals) {
+    llvm::SmallVector<mlir::Value> operands;
+    operands.reserve(loweredActuals.size());
+    // Absent optional arguments are represented by a null mlir::Value.
+    for (size_t i = 0; i < loweredActuals.size(); ++i) {
+      auto &arg = loweredActuals[i]; // bind by reference: no copy needed
+      if (!arg) {
+        operands.emplace_back();
+        continue;
+      }
+      hlfir::Entity actual = arg->getOriginalActual();
+      std::optional<fir::ArgLoweringRule> argRules;
+      if (argLowering)
+        argRules = fir::lowerIntrinsicArgumentAs(*argLowering, i);
+      // Inquired/dynamically optional args keep their base; all others (and
+      // every arg when no lowering rules exist) are dereferenced.
+      mlir::Value valArg = actual.getBase();
+      if (!argRules ||
+          (!argRules->handleDynamicOptional &&
+           argRules->lowerAs != fir::LowerIntrinsicArgAs::Inquired))
+        valArg = hlfir::derefPointersAndAllocatables(loc, builder, actual);
+      operands.emplace_back(valArg);
+    }
+    return operands;
+  };
+
+  // Compute the !hlfir.expr result type from the statement result type:
+  // an array type keeps its shape and element type, any other type becomes
+  // the element type of a scalar (empty-shape) expression.
+  auto computeResultType = [&](mlir::Type stmtResultType) -> mlir::Type {
+    hlfir::ExprType::Shape resultShape;
+    mlir::Type normalisedResult =
+        hlfir::getFortranElementOrSequenceType(stmtResultType);
+    mlir::Type elementType = normalisedResult;
+    if (auto array = normalisedResult.dyn_cast<fir::SequenceType>()) {
+      resultShape = hlfir::ExprType::Shape{array.getShape()};
+      elementType = array.getEleTy();
+    }
+    return hlfir::ExprType::get(builder.getContext(), resultShape, elementType,
+                                /*polymorphic=*/false);
+  };
+
   if (intrinsic.name == "sum") {
     llvm::SmallVector<mlir::Value> operands = getOperandVector(loweredActuals);
     assert(operands.size() == 3);
-    mlir::Value array = hlfir::derefPointersAndAllocatables(
-        loc, builder, hlfir::Entity{operands[0]});
+    mlir::Value array = operands[0];
     mlir::Value dim = operands[1];
     if (dim)
       dim = hlfir::loadTrivialScalar(loc, builder, hlfir::Entity{dim});
     mlir::Value mask = operands[2];
+    mlir::Type resultTy = computeResultType(*callContext.resultType);
     // dim, mask can be NULL if these arguments were not given
-    hlfir::SumOp sumOp = builder.create<hlfir::SumOp>(loc, array, dim, mask,
-                                                      *callContext.resultType);
+    hlfir::SumOp sumOp =
+        builder.create<hlfir::SumOp>(loc, resultTy, array, dim, mask);
     return {hlfir::EntityWithAttributes{sumOp.getResult()}};
   }
+  if (intrinsic.name == "matmul") {
+    llvm::SmallVector<mlir::Value> operands = getOperandVector(loweredActuals);
+    assert(operands.size() == 2 && "matmul expects exactly two arguments");
+    // MATRIX_A and MATRIX_B are both mandatory, so neither operand is null.
+    mlir::Type resultTy = computeResultType(*callContext.resultType);
+    hlfir::MatmulOp matmulOp = builder.create<hlfir::MatmulOp>(
+        loc, resultTy, operands[0], operands[1]);
+    return {hlfir::EntityWithAttributes{matmulOp.getResult()}};
+  }
 
   // TODO add hlfir operations for other transformational intrinsics here
 

diff  --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 5fdc909f7500d..fed7e837a34dc 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -488,28 +488,6 @@ mlir::LogicalResult hlfir::SumOp::verify() {
   return mlir::success();
 }
 
-void hlfir::SumOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
-                         mlir::Value array, mlir::Value dim, mlir::Value mask,
-                         mlir::Type stmtResultType) {
-  assert(array && "array argument is not optional");
-
-  fir::SequenceType arrayTy =
-      hlfir::getFortranElementOrSequenceType(array.getType())
-          .dyn_cast<fir::SequenceType>();
-  assert(arrayTy && "array must be of array type");
-  mlir::Type numTy = arrayTy.getEleTy();
-
-  // get the result shape from the statement context
-  hlfir::ExprType::Shape resultShape;
-  if (auto array = stmtResultType.dyn_cast<fir::SequenceType>()) {
-    resultShape = hlfir::ExprType::Shape{array.getShape()};
-  }
-  mlir::Type resultType = hlfir::ExprType::get(
-      builder.getContext(), resultShape, numTy, /*polymorphic=*/false);
-
-  build(builder, result, resultType, array, dim, mask);
-}
-
 //===----------------------------------------------------------------------===//
 // MatmulOp
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/Lower/HLFIR/matmul.f90 b/flang/test/Lower/HLFIR/matmul.f90
new file mode 100644
index 0000000000000..624cd03e685d8
--- /dev/null
+++ b/flang/test/Lower/HLFIR/matmul.f90
@@ -0,0 +1,19 @@
+! Test lowering of MATMUL intrinsic to HLFIR
+! RUN: bbc -emit-fir -hlfir -o - %s 2>&1 | FileCheck %s
+! Rank-2 assumed-shape INTEGER operands lower to a single hlfir.matmul op.
+subroutine matmul1(lhs, rhs, res)
+  integer :: lhs(:,:), rhs(:,:), res(:,:)
+  res = MATMUL(lhs, rhs)
+endsubroutine
+! CHECK-LABEL: func.func @_QPmatmul1
+! CHECK:           %[[LHS:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lhs"}
+! CHECK:           %[[RHS:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "rhs"}
+! CHECK:           %[[RES:.*]]: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "res"}
+! CHECK-DAG:     %[[LHS_VAR:.*]]:2 = hlfir.declare %[[LHS]]
+! CHECK-DAG:     %[[RHS_VAR:.*]]:2 = hlfir.declare %[[RHS]]
+! CHECK-DAG:     %[[RES_VAR:.*]]:2 = hlfir.declare %[[RES]]
+! CHECK-NEXT:    %[[EXPR:.*]] = hlfir.matmul %[[LHS_VAR]]#0 %[[RHS_VAR]]#0 {fastmath = #arith.fastmath<contract>} : (!fir.box<!fir.array<?x?xi32>>, !fir.box<!fir.array<?x?xi32>>) -> !hlfir.expr<?x?xi32>
+! CHECK-NEXT:    hlfir.assign %[[EXPR]] to %[[RES_VAR]]#0 : !hlfir.expr<?x?xi32>, !fir.box<!fir.array<?x?xi32>>
+! CHECK-NEXT:    hlfir.destroy %[[EXPR]]
+! CHECK-NEXT:    return
+! CHECK-NEXT:   }


        


More information about the flang-commits mailing list