-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[flang][CUDA] Unify element size computation in CUF helpers #167398
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Zhen Wang (wangzpgi) ChangesRefactor computeWidth from CUFOpConversion into a shared helper function computeElementByteSize in CUFCommon. Full diff: https://github.com/llvm/llvm-project/pull/167398.diff 3 Files Affected:
diff --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h
index 5c56dd6b695f8..6e2442745f9a0 100644
--- a/flang/include/flang/Optimizer/Builder/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h
@@ -18,6 +18,7 @@ static constexpr llvm::StringRef cudaSharedMemSuffix = "__shared_mem";
namespace fir {
class FirOpBuilder;
+class KindMapping;
} // namespace fir
namespace cuf {
@@ -34,6 +35,10 @@ bool isRegisteredDeviceAttr(std::optional<cuf::DataAttribute> attr);
void genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder);
+int computeElementByteSize(mlir::Location loc, mlir::Type type,
+ fir::KindMapping &kindMap,
+ bool emitErrorOnFailure = true);
+
} // namespace cuf
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp
index cf7588f275d22..461deb8e4cb55 100644
--- a/flang/lib/Optimizer/Builder/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp
@@ -9,6 +9,7 @@
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -91,3 +92,25 @@ void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
}
}
}
+
+int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
+ fir::KindMapping &kindMap,
+ bool emitErrorOnFailure) {
+ auto eleTy = fir::unwrapSequenceType(type);
+ if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
+ return t.getWidth() / 8;
+ if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
+ return t.getWidth() / 8;
+ if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
+ return kindMap.getLogicalBitsize(t.getFKind()) / 8;
+ if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
+ int elemSize =
+ mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
+ return 2 * elemSize;
+ }
+ if (auto t{mlir::dyn_cast<fir::CharacterType>(eleTy)})
+ return kindMap.getCharacterBitsize(t.getFKind()) / 8;
+ if (emitErrorOnFailure)
+ mlir::emitError(loc, "unsupported type");
+ return 0;
+}
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 8d00272b09f42..3c3782cc234f8 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -263,28 +263,6 @@ static bool inDeviceContext(mlir::Operation *op) {
return false;
}
-static int computeWidth(mlir::Location loc, mlir::Type type,
- fir::KindMapping &kindMap) {
- auto eleTy = fir::unwrapSequenceType(type);
- if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
- return t.getWidth() / 8;
- if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
- return t.getWidth() / 8;
- if (eleTy.isInteger(1))
- return 1;
- if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
- return kindMap.getLogicalBitsize(t.getFKind()) / 8;
- if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
- int elemSize =
- mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
- return 2 * elemSize;
- }
- if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
- return kindMap.getCharacterBitsize(t.getFKind()) / 8;
- mlir::emitError(loc, "unsupported type");
- return 0;
-}
-
struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
using OpRewritePattern::OpRewritePattern;
@@ -320,7 +298,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
mlir::Value bytes;
fir::KindMapping kindMap{fir::getKindMapping(mod)};
if (fir::isa_trivial(op.getInType())) {
- int width = computeWidth(loc, op.getInType(), kindMap);
+ int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
bytes =
builder.createIntegerConstant(loc, builder.getIndexType(), width);
} else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
@@ -330,7 +308,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
size = dl->getTypeSizeInBits(structTy) / 8;
} else {
- size = computeWidth(loc, seqTy.getEleTy(), kindMap);
+ size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
}
mlir::Value width =
builder.createIntegerConstant(loc, builder.getIndexType(), size);
@@ -619,8 +597,8 @@ struct CUFDataTransferOpConversion
const mlir::SymbolTable &symtab,
mlir::DataLayout *dl,
const fir::LLVMTypeConverter *typeConverter)
- : OpRewritePattern(context), symtab{symtab}, dl{dl},
- typeConverter{typeConverter} {}
+ : OpRewritePattern(context), symtab{symtab}, dl{dl}, typeConverter{
+ typeConverter} {}
mlir::LogicalResult
matchAndRewrite(cuf::DataTransferOp op,
@@ -704,7 +682,7 @@ struct CUFDataTransferOpConversion
typeConverter->convertType(fir::unwrapSequenceType(dstTy));
width = dl->getTypeSizeInBits(structTy) / 8;
} else {
- width = computeWidth(loc, dstTy, kindMap);
+ width = cuf::computeElementByteSize(loc, dstTy, kindMap);
}
mlir::Value widthValue = mlir::arith::ConstantOp::create(
rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
clementval
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Refactor computeWidth from CUFOpConversion into a shared helper function computeElementByteSize in CUFCommon.
851703e to
3961076
Compare
060ac13 to
c68eaba
Compare
Refactor computeWidth from CUFOpConversion into a shared helper function computeElementByteSize in CUFCommon.