[flang-commits] [flang] 09472ba - [flang] add hlfir.matmul operation

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Thu Feb 16 07:32:12 PST 2023


Author: Tom Eccles
Date: 2023-02-16T15:30:46Z
New Revision: 09472ba315043748db01b092fdaca54daf19a513

URL: https://github.com/llvm/llvm-project/commit/09472ba315043748db01b092fdaca54daf19a513
DIFF: https://github.com/llvm/llvm-project/commit/09472ba315043748db01b092fdaca54daf19a513.diff

LOG: [flang] add hlfir.matmul operation

Add a HLFIR operation for the MATMUL transformational intrinsic,
according to the design set out in flang/docs/HighLevelFIR.md

Differential Revision: https://reviews.llvm.org/D144094

Added: 
    flang/test/HLFIR/matmul.fir

Modified: 
    flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
    flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
    flang/include/flang/Optimizer/HLFIR/HLFIROps.td
    flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
    flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
    flang/test/HLFIR/invalid.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 6a9acb443f9d8..a0bcd33f57e06 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -71,6 +71,7 @@ inline bool isBoxAddressOrValueType(mlir::Type type) {
 
 bool isFortranScalarNumericalType(mlir::Type);
 bool isFortranNumericalArrayObject(mlir::Type);
+/// Is \p type an array-like object whose Fortran element-or-sequence type is
+/// a fir.array of scalar numerical or fir.logical elements? Box address types
+/// are rejected.
+bool isFortranNumericalOrLogicalArrayObject(mlir::Type);
 bool isPassByRefOrIntegerType(mlir::Type);
 bool isI1Type(mlir::Type);
 // scalar i1 or logical, or sequence of logical (via (boxed?) array or expr)

diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
index eca5f05d3d02d..e210de3ce33eb 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
@@ -113,6 +113,11 @@ def IsFortranNumericalArrayObjectPred
 def AnyFortranNumericalArrayObject : Type<IsFortranNumericalArrayObjectPred,
     "any array-like object containing a numerical type">;
 
+// Predicate backed by hlfir::isFortranNumericalOrLogicalArrayObject: accepts
+// array-like objects whose elements are numerical or logical.
+def IsFortranNumericalOrLogicalArrayObjectPred
+        : CPred<"::hlfir::isFortranNumericalOrLogicalArrayObject($_self)">;
+// Operand type constraint built from the predicate above (used by MATMUL).
+def AnyFortranNumericalOrLogicalArrayObject
+    : Type<IsFortranNumericalOrLogicalArrayObjectPred,
+           "any array-like object containing a numerical or logical type">;
+
 def IsPassByRefOrIntegerTypePred
         : CPred<"::hlfir::isPassByRefOrIntegerType($_self)">;
 def AnyPassByRefOrIntegerType : Type<IsPassByRefOrIntegerTypePred,

diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index d929850d9ff28..9f187ac6afca8 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -289,6 +289,33 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
   let hasVerifier = 1;
 }
 
+def hlfir_MatmulOp : hlfir_Op<"matmul",
+    [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let summary = "MATMUL transformational intrinsic";
+  let description = [{
+    Matrix multiplication, as performed by the Fortran MATMUL intrinsic.
+
+    LHS and RHS must be rank 1 or rank 2 arrays of numerical or logical
+    type, at least one of them must have rank 2, and either both or neither
+    must be logical. The last extent of LHS must agree with the first extent
+    of RHS. The result is an hlfir.expr whose shape is:
+      - (n, k) for LHS(n, m) and RHS(m, k);
+      - (n)    for LHS(n, m) and RHS(m);
+      - (k)    for LHS(m)    and RHS(m, k).
+    The result element type is logical if and only if the arguments are
+    logical. Unknown extents are compatible with any extent.
+
+    The optional fastmath attribute defaults to `none`.
+  }];
+
+  let arguments = (ins
+    AnyFortranNumericalOrLogicalArrayObject:$lhs,
+    AnyFortranNumericalOrLogicalArrayObject:$rhs,
+    DefaultValuedAttr<Arith_FastMathAttr,
+                      "::mlir::arith::FastMathFlags::none">:$fastmath
+  );
+
+  let results = (outs hlfir_ExprType);
+
+  let assemblyFormat = [{
+    $lhs $rhs attr-dict `:` functional-type(operands, results)
+  }];
+
+  let builders = [OpBuilder<(ins "mlir::Value":$lhs,
+                                 "mlir::Value":$rhs,
+                                 "mlir::Type":$resultType)>];
+
+  let hasVerifier = 1;
+}
+
 def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<fir_FortranVariableOpInterface>]> {
   let summary = "Create a variable from an expression value";

diff  --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index f23be5de3be14..3dd757f993848 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -115,6 +115,18 @@ bool hlfir::isFortranNumericalArrayObject(mlir::Type type) {
   return false;
 }
 
+/// Accepts array-like objects (anything whose Fortran element-or-sequence
+/// type is a fir.array) holding numerical or logical elements. Box address
+/// types are rejected up front.
+bool hlfir::isFortranNumericalOrLogicalArrayObject(mlir::Type type) {
+  if (isBoxAddressType(type))
+    return false;
+  auto seqTy =
+      getFortranElementOrSequenceType(type).dyn_cast<fir::SequenceType>();
+  // Scalars (non-sequence types) are not array objects.
+  if (!seqTy)
+    return false;
+  mlir::Type elementType = seqTy.getEleTy();
+  if (mlir::isa<fir::LogicalType>(elementType))
+    return true;
+  return isFortranScalarNumericalType(elementType);
+}
+
 bool hlfir::isPassByRefOrIntegerType(mlir::Type type) {
   mlir::Type unwrappedType = fir::unwrapPassByRefType(type);
   return fir::isa_integer(unwrappedType);

diff  --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 39b36e7c4a3a1..5fdc909f7500d 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -510,6 +510,76 @@ void hlfir::SumOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
   build(builder, result, resultType, array, dim, mask);
 }
 
+//===----------------------------------------------------------------------===//
+// MatmulOp
+//===----------------------------------------------------------------------===//
+
+/// Verify the hlfir.matmul operation.
+///
+/// Checks the MATMUL constraints that are visible in the IR types:
+///  - each argument has rank 1 or rank 2, and at least one has rank 2;
+///  - either both arguments have logical elements or neither does, and the
+///    result has logical elements exactly when the arguments do;
+///  - the last extent of LHS agrees with the first extent of RHS;
+///  - the result rank and extents agree with the shape implied by the
+///    arguments.
+/// An unknown extent (`?`) is compatible with any other extent, so only two
+/// known extents that differ are reported. The same rule applies to the
+/// LHS/RHS inner extents and to the result shape.
+mlir::LogicalResult hlfir::MatmulOp::verify() {
+  mlir::Value lhs = getLhs();
+  mlir::Value rhs = getRhs();
+  // The operand type constraint (AnyFortranNumericalOrLogicalArrayObject)
+  // guarantees that these are fir.array types.
+  fir::SequenceType lhsTy =
+      hlfir::getFortranElementOrSequenceType(lhs.getType())
+          .cast<fir::SequenceType>();
+  fir::SequenceType rhsTy =
+      hlfir::getFortranElementOrSequenceType(rhs.getType())
+          .cast<fir::SequenceType>();
+  llvm::ArrayRef<int64_t> lhsShape = lhsTy.getShape();
+  llvm::ArrayRef<int64_t> rhsShape = rhsTy.getShape();
+  std::size_t lhsRank = lhsShape.size();
+  std::size_t rhsRank = rhsShape.size();
+  mlir::Type lhsEleTy = lhsTy.getEleTy();
+  mlir::Type rhsEleTy = rhsTy.getEleTy();
+  hlfir::ExprType resultTy = getResult().getType().cast<hlfir::ExprType>();
+  llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
+  mlir::Type resultEleTy = resultTy.getEleTy();
+
+  constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
+  // Two extents conflict only if both are known and they differ.
+  auto extentsConflict = [&](int64_t a, int64_t b) {
+    return a != b && a != unknownExtent && b != unknownExtent;
+  };
+
+  if (((lhsRank != 1) && (lhsRank != 2)) || ((rhsRank != 1) && (rhsRank != 2)))
+    return emitOpError("array must have either rank 1 or rank 2");
+
+  if ((lhsRank == 1) && (rhsRank == 1))
+    return emitOpError("at least one array must have rank 2");
+
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(rhsEleTy))
+    return emitOpError("if one array is logical, so should the other be");
+
+  // The contracted dimension: last extent of LHS, first extent of RHS.
+  if (extentsConflict(lhsShape[lhsRank - 1], rhsShape[0]))
+    return emitOpError(
+        "the last dimension of LHS should match the first dimension of RHS");
+
+  if (mlir::isa<fir::LogicalType>(lhsEleTy) !=
+      mlir::isa<fir::LogicalType>(resultEleTy))
+    return emitOpError("the result type should be a logical only if the "
+                       "argument types are logical");
+
+  // Expected result shape (at least one argument has rank 2 at this point):
+  //   LHS(n, m) x RHS(m, k) -> (n, k)
+  //   LHS(n, m) x RHS(m)    -> (n)
+  //   LHS(m)    x RHS(m, k) -> (k)
+  llvm::SmallVector<int64_t, 2> expectedResultShape;
+  if (lhsRank == 2)
+    expectedResultShape.push_back(lhsShape[0]);
+  if (rhsRank == 2)
+    expectedResultShape.push_back(rhsShape[1]);
+
+  if (resultShape.size() != expectedResultShape.size())
+    return emitOpError("incorrect result shape");
+  // Like the contracted-dimension check above, tolerate unknown extents on
+  // either side instead of demanding exact equality.
+  for (std::size_t i = 0; i < resultShape.size(); ++i)
+    if (extentsConflict(resultShape[i], expectedResultShape[i]))
+      return emitOpError("incorrect result shape");
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // AssociateOp
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index dd801634c6f7d..b4afc8fe48a11 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -319,3 +319,59 @@ func.func @bad_sum4(%arg0: !hlfir.expr<?xi32>, %arg1: i32, %arg2: !fir.box<!fir.
   // expected-error at +1 {{'hlfir.sum' op result rank must be one less than ARRAY}}
   %0 = hlfir.sum %arg0 dim %arg1 mask %arg2 : (!hlfir.expr<?xi32>, i32, !fir.box<!fir.logical<4>>) -> !hlfir.expr<?x?xi32>
 }
+
+// -----
+// LHS has rank 3, but MATMUL arguments must have rank 1 or rank 2.
+func.func @bad_matmul1(%arg0: !hlfir.expr<?x?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op array must have either rank 1 or rank 2}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+// Both arguments are rank 1; MATMUL requires at least one rank 2 argument.
+func.func @bad_matmul2(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op at least one array must have rank 2}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+// Logical LHS combined with an integer RHS.
+func.func @bad_matmul3(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op if one array is logical, so should the other be}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+// Known contracted extents disagree: LHS columns (2) vs RHS rows (200).
+func.func @bad_matmul4(%arg0: !hlfir.expr<?x2xi32>, %arg1: !hlfir.expr<200x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op the last dimension of LHS should match the first dimension of RHS}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x2xi32>, !hlfir.expr<200x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+
+// -----
+// Integer arguments cannot produce a logical result.
+func.func @bad_matmul5(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op the result type should be a logical only if the argument types are logical}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?x!fir.logical<4>>
+  return
+}
+
+// -----
+// 1x2 times 2x3 gives 1x3, not the declared 10x30.
+func.func @bad_matmul6(%arg0: !hlfir.expr<1x2xi32>, %arg1: !hlfir.expr<2x3xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op incorrect result shape}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<1x2xi32>, !hlfir.expr<2x3xi32>) -> !hlfir.expr<10x30xi32>
+  return
+}
+
+// -----
+// Matrix times vector gives a rank 1 result, but a rank 2 result is declared.
+func.func @bad_matmul7(%arg0: !hlfir.expr<1x2xi32>, %arg1: !hlfir.expr<2xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op incorrect result shape}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<1x2xi32>, !hlfir.expr<2xi32>) -> !hlfir.expr<1x3xi32>
+  return
+}
+
+// -----
+// Vector times matrix gives a rank 1 result, but a rank 2 result is declared.
+func.func @bad_matmul8(%arg0: !hlfir.expr<2xi32>, %arg1: !hlfir.expr<2x3xi32>) {
+  // expected-error@+1 {{'hlfir.matmul' op incorrect result shape}}
+  %0 = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<2xi32>, !hlfir.expr<2x3xi32>) -> !hlfir.expr<1x3xi32>
+  return
+}

diff  --git a/flang/test/HLFIR/matmul.fir b/flang/test/HLFIR/matmul.fir
new file mode 100644
index 0000000000000..eaf85c3969737
--- /dev/null
+++ b/flang/test/HLFIR/matmul.fir
@@ -0,0 +1,99 @@
+// Test hlfir.matmul operation parse, verify (no errors), and unparse
+
+// RUN: fir-opt %s | fir-opt | FileCheck %s
+
+// Both arguments are rank 2 expressions with known extents; the result shape
+// is (rows of LHS) x (columns of RHS).
+func.func @matmul0(%arg0: !hlfir.expr<2x2xi32>, %arg1: !hlfir.expr<2x2xi32>) {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul0
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<2x2xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<2x2xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// Both arguments are rank 2 expressions whose extents are all unknown (`?`).
+func.func @matmul1(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul1
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?x?xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?x?xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// Only some extents are known (#1): the contracted extents are unknown, while
+// the extents that form the result (LHS rows, RHS columns) are known.
+func.func @matmul2(%arg0: !hlfir.expr<2x?xi32>, %arg1: !hlfir.expr<?x2xi32>) {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<2x?xi32>, !hlfir.expr<?x2xi32>) -> !hlfir.expr<2x2xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul2
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<2x?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?x2xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<2x?xi32>, !hlfir.expr<?x2xi32>) -> !hlfir.expr<2x2xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// Only some extents are known (#2): the contracted extents are known (2),
+// while the extents that form the result are unknown.
+func.func @matmul3(%arg0: !hlfir.expr<?x2xi32>, %arg1: !hlfir.expr<2x?xi32>) {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x2xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr<?x?xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul3
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x2xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<2x?xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x2xi32>, !hlfir.expr<2x?xi32>) -> !hlfir.expr<?x?xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// Both arguments are logical arrays, so the result is a logical array too.
+func.func @matmul4(%arg0: !hlfir.expr<?x?x!fir.logical<4>>, %arg1: !hlfir.expr<?x?x!fir.logical<4>>) {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>>
+  return
+}
+// CHECK-LABEL: func.func @matmul4
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x?x!fir.logical<4>>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?x?x!fir.logical<4>>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?x!fir.logical<4>>, !hlfir.expr<?x?x!fir.logical<4>>) -> !hlfir.expr<?x?x!fir.logical<4>>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// LHS is rank 1 (vector times matrix), so the result is rank 1.
+func.func @matmul5(%arg0: !hlfir.expr<?xi32>, %arg1: !hlfir.expr<?x?xi32>) {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul5
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?x?xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?xi32>, !hlfir.expr<?x?xi32>) -> !hlfir.expr<?xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// RHS is rank 1 (matrix times vector), so the result is rank 1.
+func.func @matmul6(%arg0: !hlfir.expr<?x?xi32>, %arg1: !hlfir.expr<?xi32>) {
+  %res = hlfir.matmul %arg0 %arg1 : (!hlfir.expr<?x?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?xi32>
+  return
+}
+// CHECK-LABEL: func.func @matmul6
+// CHECK:           %[[ARG0:.*]]: !hlfir.expr<?x?xi32>,
+// CHECK:           %[[ARG1:.*]]: !hlfir.expr<?xi32>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!hlfir.expr<?x?xi32>, !hlfir.expr<?xi32>) -> !hlfir.expr<?xi32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+// Both arguments are boxed arrays (fir.box) rather than hlfir.expr values;
+// the result is still an hlfir.expr.
+func.func @matmul7(%arg0: !fir.box<!fir.array<2x2xf32>>, %arg1: !fir.box<!fir.array<2x2xf32>>) {
+  %res = hlfir.matmul %arg0 %arg1 : (!fir.box<!fir.array<2x2xf32>>, !fir.box<!fir.array<2x2xf32>>) -> !hlfir.expr<2x2xf32>
+  return
+}
+// CHECK-LABEL: func.func @matmul7
+// CHECK:           %[[ARG0:.*]]: !fir.box<!fir.array<2x2xf32>>,
+// CHECK:           %[[ARG1:.*]]: !fir.box<!fir.array<2x2xf32>>) {
+// CHECK-NEXT:    %[[RES:.*]] = hlfir.matmul %[[ARG0]] %[[ARG1]] : (!fir.box<!fir.array<2x2xf32>>, !fir.box<!fir.array<2x2xf32>>) -> !hlfir.expr<2x2xf32>
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }


        


More information about the flang-commits mailing list