Skip to content

Conversation

@wangzpgi
Copy link
Contributor

Refactor computeWidth from CUFOpConversion into a shared helper function computeElementByteSize in CUFCommon.

@wangzpgi wangzpgi requested a review from clementval November 10, 2025 21:48
@wangzpgi wangzpgi self-assigned this Nov 10, 2025
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Nov 10, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 10, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Zhen Wang (wangzpgi)

Changes

Refactor computeWidth from CUFOpConversion into a shared helper function computeElementByteSize in CUFCommon.


Full diff: https://github.com/llvm/llvm-project/pull/167398.diff

3 Files Affected:

  • (modified) flang/include/flang/Optimizer/Builder/CUFCommon.h (+5)
  • (modified) flang/lib/Optimizer/Builder/CUFCommon.cpp (+23)
  • (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+5-27)
diff --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h
index 5c56dd6b695f8..6e2442745f9a0 100644
--- a/flang/include/flang/Optimizer/Builder/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h
@@ -18,6 +18,7 @@ static constexpr llvm::StringRef cudaSharedMemSuffix = "__shared_mem";
 
 namespace fir {
 class FirOpBuilder;
+class KindMapping;
 } // namespace fir
 
 namespace cuf {
@@ -34,6 +35,10 @@ bool isRegisteredDeviceAttr(std::optional<cuf::DataAttribute> attr);
 
 void genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder);
 
+int computeElementByteSize(mlir::Location loc, mlir::Type type,
+                           fir::KindMapping &kindMap,
+                           bool emitErrorOnFailure = true);
+
 } // namespace cuf
 
 #endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp
index cf7588f275d22..461deb8e4cb55 100644
--- a/flang/lib/Optimizer/Builder/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp
@@ -9,6 +9,7 @@
 #include "flang/Optimizer/Builder/CUFCommon.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -91,3 +92,25 @@ void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
     }
   }
 }
+
+int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
+                                fir::KindMapping &kindMap,
+                                bool emitErrorOnFailure) {
+  auto eleTy = fir::unwrapSequenceType(type);
+  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
+    return t.getWidth() / 8;
+  if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
+    return t.getWidth() / 8;
+  if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
+    return kindMap.getLogicalBitsize(t.getFKind()) / 8;
+  if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
+    int elemSize =
+        mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
+    return 2 * elemSize;
+  }
+  if (auto t{mlir::dyn_cast<fir::CharacterType>(eleTy)})
+    return kindMap.getCharacterBitsize(t.getFKind()) / 8;
+  if (emitErrorOnFailure)
+    mlir::emitError(loc, "unsupported type");
+  return 0;
+}
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 8d00272b09f42..3c3782cc234f8 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -263,28 +263,6 @@ static bool inDeviceContext(mlir::Operation *op) {
   return false;
 }
 
-static int computeWidth(mlir::Location loc, mlir::Type type,
-                        fir::KindMapping &kindMap) {
-  auto eleTy = fir::unwrapSequenceType(type);
-  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
-    return t.getWidth() / 8;
-  if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
-    return t.getWidth() / 8;
-  if (eleTy.isInteger(1))
-    return 1;
-  if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
-    return kindMap.getLogicalBitsize(t.getFKind()) / 8;
-  if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
-    int elemSize =
-        mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
-    return 2 * elemSize;
-  }
-  if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
-    return kindMap.getCharacterBitsize(t.getFKind()) / 8;
-  mlir::emitError(loc, "unsupported type");
-  return 0;
-}
-
 struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -320,7 +298,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
       mlir::Value bytes;
       fir::KindMapping kindMap{fir::getKindMapping(mod)};
       if (fir::isa_trivial(op.getInType())) {
-        int width = computeWidth(loc, op.getInType(), kindMap);
+        int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
         bytes =
             builder.createIntegerConstant(loc, builder.getIndexType(), width);
       } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
@@ -330,7 +308,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
           mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
           size = dl->getTypeSizeInBits(structTy) / 8;
         } else {
-          size = computeWidth(loc, seqTy.getEleTy(), kindMap);
+          size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
         }
         mlir::Value width =
             builder.createIntegerConstant(loc, builder.getIndexType(), size);
@@ -619,8 +597,8 @@ struct CUFDataTransferOpConversion
                               const mlir::SymbolTable &symtab,
                               mlir::DataLayout *dl,
                               const fir::LLVMTypeConverter *typeConverter)
-      : OpRewritePattern(context), symtab{symtab}, dl{dl},
-        typeConverter{typeConverter} {}
+      : OpRewritePattern(context), symtab{symtab}, dl{dl}, typeConverter{
+                                                               typeConverter} {}
 
   mlir::LogicalResult
   matchAndRewrite(cuf::DataTransferOp op,
@@ -704,7 +682,7 @@ struct CUFDataTransferOpConversion
             typeConverter->convertType(fir::unwrapSequenceType(dstTy));
         width = dl->getTypeSizeInBits(structTy) / 8;
       } else {
-        width = computeWidth(loc, dstTy, kindMap);
+        width = cuf::computeElementByteSize(loc, dstTy, kindMap);
       }
       mlir::Value widthValue = mlir::arith::ConstantOp::create(
           rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));

@github-actions
Copy link

github-actions bot commented Nov 10, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@clementval clementval left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Refactor computeWidth from CUFOpConversion into a shared helper
function computeElementByteSize in CUFCommon.
@wangzpgi wangzpgi merged commit d5125b3 into llvm:main Nov 10, 2025
10 checks passed
@wangzpgi wangzpgi deleted the helper-upstream branch November 10, 2025 22:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

flang:fir-hlfir flang Flang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants