[flang-commits] [flang] [flang] Added definition of hlfir.cshift operation. (PR #118732)

via flang-commits flang-commits at lists.llvm.org
Wed Dec 4 18:28:12 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

<details>
<summary>Changes</summary>

The CSHIFT intrinsic will be lowered to this operation, which
can then be optimized into an inline sequence or lowered into
a runtime call.


---
Full diff: https://github.com/llvm/llvm-project/pull/118732.diff


6 Files Affected:

- (modified) flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h (+3) 
- (modified) flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td (+6) 
- (modified) flang/include/flang/Optimizer/HLFIR/HLFIROps.td (+21) 
- (modified) flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp (+9) 
- (modified) flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp (+85) 
- (modified) flang/test/HLFIR/invalid.fir (+59) 


``````````diff
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 447d5fbab89998..15296aa7e8c75c 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -136,6 +136,9 @@ mlir::Value genExprShape(mlir::OpBuilder &builder, const mlir::Location &loc,
 /// This has to be cleaned up, when HLFIR is the default.
 bool mayHaveAllocatableComponent(mlir::Type ty);
 
+/// Integer scalar or array (value, ref, box or expr); box addresses excluded.
+bool isFortranIntegerScalarOrArrayObject(mlir::Type type);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
index d967a407a75880..404ab5f633bf78 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
@@ -155,6 +155,12 @@ def IsPolymorphicObjectPred
 def AnyPolymorphicObject : Type<IsPolymorphicObjectPred,
     "any polymorphic object">;
 
+def IsFortranIntegerScalarOrArrayPred
+    : CPred<"::hlfir::isFortranIntegerScalarOrArrayObject($_self)">;
+def AnyFortranIntegerScalarOrArrayObject
+    : Type<IsFortranIntegerScalarOrArrayPred,
+           "any integer scalar or array object">;
+
 def hlfir_CharExtremumPredicateAttr : I32EnumAttr<
     "CharExtremumPredicate", "",
     [
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index a9826543f48b69..f11162dc0d95e1 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -699,6 +699,27 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
   let hasVerifier = 1;
 }
 
+def hlfir_CShiftOp : hlfir_Op<"cshift",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "CSHIFT transformational intrinsic";
+  let description = [{
+    Circular shift of the elements of ARRAY along dimension DIM (1 if absent)
+    by SHIFT positions; SHIFT is a scalar or an array of rank(ARRAY)-1.
+  }];
+
+  let arguments = (ins AnyFortranArrayObject:$array,
+      AnyFortranIntegerScalarOrArrayObject:$shift,
+      Optional<AnyIntegerType>:$dim);
+
+  let results = (outs hlfir_ExprType);
+
+  let assemblyFormat = [{
+    $array $shift (`dim` $dim^)? attr-dict `:` functional-type(operands, results)
+  }];
+
+  let hasVerifier = 1;
+}
+
 // An allocation effect is needed because the value produced by the associate
 // is "deallocated" by hlfir.end_associate (the end_associate must not be
 // removed, and there must be only one hlfir.end_associate).
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index d67b5fa6598075..cb77aef74acd56 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -228,3 +228,12 @@ mlir::Type hlfir::getExprType(mlir::Type variableType) {
   return hlfir::ExprType::get(variableType.getContext(), typeShape, type,
                               isPolymorphic);
 }
+
+bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
+  // Box addresses are rejected outright. Any other type qualifies when, after
+  // peeling reference and box wrappers, its Fortran element type is an
+  // integer.
+  return !isBoxAddressType(type) &&
+         mlir::isa<mlir::IntegerType>(getFortranElementType(
+             fir::unwrapPassByRefType(fir::unwrapRefType(type))));
+}
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 3a172d1b8b5400..af8a81d42aac3b 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -1341,6 +1341,91 @@ void hlfir::MatmulTransposeOp::getEffects(
   getIntrinsicEffects(getOperation(), effects);
 }
 
+//===----------------------------------------------------------------------===//
+// CShiftOp
+//===----------------------------------------------------------------------===//
+
+/// Verify hlfir.cshift operand/result types, a constant DIM's range, and the
+/// rank and (for a constant DIM) the shape of an array SHIFT.
+llvm::LogicalResult hlfir::CShiftOp::verify() {
+  fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
+      hlfir::getFortranElementOrSequenceType(getArray().getType()));
+  llvm::ArrayRef<int64_t> inShape = arrayTy.getShape();
+  std::size_t arrayRank = inShape.size();
+  mlir::Type eleTy = arrayTy.getEleTy();
+  hlfir::ExprType resultTy = mlir::cast<hlfir::ExprType>(getResult().getType());
+  llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
+  std::size_t resultRank = resultShape.size();
+  mlir::Type resultEleTy = resultTy.getEleTy();
+  mlir::Type shiftTy =
+      hlfir::getFortranElementOrSequenceType(getShift().getType());
+
+  if (eleTy != resultEleTy)
+    return emitOpError(
+        "input and output arrays should have the same element type");
+
+  if (arrayRank != resultRank)
+    return emitOpError("input and output arrays should have the same rank");
+
+  // Static extents may only conflict when both are known.
+  constexpr int64_t unknownExtent = fir::SequenceType::getUnknownExtent();
+  for (auto [inDim, resultDim] : llvm::zip(inShape, resultShape))
+    if (inDim != unknownExtent && resultDim != unknownExtent &&
+        inDim != resultDim)
+      return emitOpError(
+          "output array's shape conflicts with the input array's shape");
+
+  // DIM defaults to 1 when absent; std::nullopt means DIM is not a constant.
+  // (Not a -1 sentinel: a constant DIM of -1 must still be diagnosed.)
+  std::optional<int64_t> dimVal = 1;
+  if (getDim())
+    dimVal = fir::getIntIfConstant(getDim());
+
+  if (dimVal && *dimVal < 1)
+    return emitOpError("DIM must be >= 1");
+  if (dimVal && *dimVal > static_cast<int64_t>(arrayRank))
+    return emitOpError("DIM must be <= input array's rank");
+
+  if (auto shiftSeqTy = mlir::dyn_cast<fir::SequenceType>(shiftTy)) {
+    // SHIFT is an array. Verify the rank and the shape (if DIM is constant).
+    llvm::ArrayRef<int64_t> shiftShape = shiftSeqTy.getShape();
+    std::size_t shiftRank = shiftShape.size();
+    if (shiftRank != arrayRank - 1)
+      return emitOpError(
+          "SHIFT's rank must be 1 less than the input array's rank");
+
+    if (dimVal) {
+      // SHIFT's shape must be [d(1), d(2), ..., d(DIM-1), d(DIM+1), ..., d(n)],
+      // where [d(1), d(2), ..., d(n)] is the shape of the ARRAY.
+      int64_t arrayDimIdx = 0;
+      int64_t shiftDimIdx = 0;
+      for (auto shiftDim : shiftShape) {
+        if (arrayDimIdx == *dimVal - 1) // skip ARRAY's DIM dimension
+          ++arrayDimIdx;
+
+        if (inShape[arrayDimIdx] != unknownExtent &&
+            shiftDim != unknownExtent && inShape[arrayDimIdx] != shiftDim)
+          return emitOpError("SHAPE(ARRAY)(" + llvm::Twine(arrayDimIdx + 1) +
+                             ") must be equal to SHAPE(SHIFT)(" +
+                             llvm::Twine(shiftDimIdx + 1) +
+                             "): " + llvm::Twine(inShape[arrayDimIdx]) +
+                             " != " + llvm::Twine(shiftDim));
+        ++arrayDimIdx;
+        ++shiftDimIdx;
+      }
+    }
+  }
+
+  return mlir::success();
+}
+
+void hlfir::CShiftOp::getEffects(
+    llvm::SmallVectorImpl<mlir::SideEffects::EffectInstance<
+        mlir::MemoryEffects::Effect>> &effects) {
+  // Delegates to the shared intrinsic-op helper, as MatmulTransposeOp does.
+  getIntrinsicEffects(this->getOperation(), effects);
+}
+
 //===----------------------------------------------------------------------===//
 // AssociateOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/HLFIR/invalid.fir b/flang/test/HLFIR/invalid.fir
index 5c5db7aac06970..d0ac6864630e0c 100644
--- a/flang/test/HLFIR/invalid.fir
+++ b/flang/test/HLFIR/invalid.fir
@@ -1348,3 +1348,62 @@ func.func @bad_eval_in_mem_3() {
   }
   return
 }
+
+// -----
+
+func.func @bad_cshift1(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32) {
+  // expected-error@+1 {{'hlfir.cshift' op input and output arrays should have the same element type}}
+  %0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x?xi32>, i32) -> !hlfir.expr<?x?xf32>
+  return
+}
+
+// -----
+
+func.func @bad_cshift2(%arg0: !hlfir.expr<?x?xi32>, %arg1: i32) {
+  // expected-error@+1 {{'hlfir.cshift' op input and output arrays should have the same rank}}
+  %0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<?x?xi32>, i32) -> !hlfir.expr<?xi32>
+  return
+}
+
+// -----
+
+func.func @bad_cshift3(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
+  // expected-error@+1 {{'hlfir.cshift' op output array's shape conflicts with the input array's shape}}
+  %0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<2x2xi32>, i32) -> !hlfir.expr<2x3xi32>
+  return
+}
+
+// -----
+
+func.func @bad_cshift4(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{'hlfir.cshift' op DIM must be >= 1}}
+  %0 = hlfir.cshift %arg0 %arg1 dim %c0 : (!hlfir.expr<2x2xi32>, i32, index) -> !hlfir.expr<2x2xi32>
+  return
+}
+
+// -----
+
+func.func @bad_cshift5(%arg0: !hlfir.expr<2x2xi32>, %arg1: i32) {
+  %c10 = arith.constant 10 : index
+  // expected-error@+1 {{'hlfir.cshift' op DIM must be <= input array's rank}}
+  %0 = hlfir.cshift %arg0 %arg1 dim %c10 : (!hlfir.expr<2x2xi32>, i32, index) -> !hlfir.expr<2x2xi32>
+  return
+}
+
+// -----
+
+func.func @bad_cshift6(%arg0: !hlfir.expr<2x2xi32>, %arg1: !hlfir.expr<2x2xi32>) {
+  // expected-error@+1 {{'hlfir.cshift' op SHIFT's rank must be 1 less than the input array's rank}}
+  %0 = hlfir.cshift %arg0 %arg1 : (!hlfir.expr<2x2xi32>, !hlfir.expr<2x2xi32>) -> !hlfir.expr<2x2xi32>
+  return
+}
+
+// -----
+
+func.func @bad_cshift7(%arg0: !hlfir.expr<?x2xi32>, %arg1: !hlfir.expr<3xi32>) {
+  %c1 = arith.constant 1 : index
+  // expected-error@+1 {{'hlfir.cshift' op SHAPE(ARRAY)(2) must be equal to SHAPE(SHIFT)(1): 2 != 3}}
+  %0 = hlfir.cshift %arg0 %arg1 dim %c1 : (!hlfir.expr<?x2xi32>, !hlfir.expr<3xi32>, index) -> !hlfir.expr<2x2xi32>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/118732


More information about the flang-commits mailing list