21namespace ranges = std::ranges;
35using namespace asthelpers;
38QualType CoroutinesCodeGenerator::GetFramePointerType()
const
40 return Ptr(GetFrameType());
75FieldDecl* CoroutinesCodeGenerator::AddField(std::string_view name, QualType type)
77 return ::clang::insights::AddField(mASTData, name, type);
115 size_t& mSuspendsCount;
116 llvm::DenseMap<VarDecl*, MemberExpr*> mVarNamePrefix{};
120 size_t& suspendsCounter,
122 llvm::DenseMap<VarDecl*, MemberExpr*> varNamePrefix,
123 Stmt* prev =
nullptr)
125 , mASTData{coroutineASTData}
126 , mSuspendsCount{suspendsCounter}
127 , mVarNamePrefix{varNamePrefix}
129 if(
nullptr == mPrevStmt) {
145 for(
auto* child : stmt->body()) {
150 mBodyStmts.
Add(child);
152 if(
const auto* coret = dyn_cast_or_null<CoreturnStmt>(child);
153 coret and (coret->getOperand() ==
nullptr)) {
170 Visit(stmt->getCond());
177 Visit(stmt->getCond());
184 Visit(stmt->getCond());
191 Visit(stmt->getCond());
201 Visit(stmt->getInit());
206 auto* oldInit = stmt->getInit();
207 auto* newInit = mBodyStmts.
mStmts.back();
208 mBodyStmts.
mStmts.pop_back();
215 Visit(stmt->getCond());
217 Visit(stmt->getInc());
226 Visit(stmt->getRangeStmt());
235 if(
auto* vd = dyn_cast_or_null<VarDecl>(stmt->getDecl())) {
236 RETURN_IF(not vd->isLocalVarDeclOrParm() or vd->isStaticLocal() or not
Contains(mVarNamePrefix, vd));
238 auto* memberExpr = mVarNamePrefix[vd];
246 for(
auto* decl : stmt->decls()) {
247 if(
auto* varDecl = dyn_cast_or_null<VarDecl>(decl)) {
248 if(varDecl->isStaticLocal()) {
254 auto* field =
AddField(mASTData,
GetName(*varDecl), varDecl->getType());
256 auto* assign =
Assign(me, field, varDecl->getInit());
258 mVarNamePrefix.insert(std::make_pair(varDecl, me));
260 Visit(varDecl->getInit());
263 mBodyStmts.
Add(assign);
265 }
else if(
const auto* recordDecl = dyn_cast_or_null<CXXRecordDecl>(decl)) {
266 mASTData.
mFrameType->addDecl(
const_cast<CXXRecordDecl*
>(recordDecl));
276 ReplaceNode(mPrevStmt, stmt, indirectThisMemberExpr);
285 auto* tmp = mPrevStmt;
288 for(
auto* arg : stmt->arguments()) {
297 auto* tmp = mPrevStmt;
298 mPrevStmt = stmt->getCallee();
300 Visit(stmt->getCallee());
309 Visit(stmt->getOperand());
310 Visit(stmt->getPromiseCall());
317 if(isa<ExprWithCleanups>(mStaged)) {
318 mBodyStmts.
Add(stmt);
322 Visit(stmt->getOperand());
329 if(
const bool returnsVoid{stmt->getResumeExpr()->getType()->isVoidType()}; returnsVoid) {
330 Visit(stmt->getOperand());
336 mBodyStmts.
Add(stmt);
341 auto* resultVar =
Variable(name, stmt->getType());
346 Visit(stmt->getCommonExpr());
347 Visit(stmt->getOperand());
348 Visit(stmt->getSuspendExpr());
349 Visit(stmt->getReadyExpr());
350 Visit(stmt->getResumeExpr());
355 auto* varDecl = stmt->getPromiseDecl();
360 mVarNamePrefix.insert(std::make_pair(varDecl, me));
373 for(
auto* param : stmt->getParamMoves()) {
374 if(
auto* declStmt = dyn_cast_or_null<DeclStmt>(param)) {
375 if(
auto* varDecl2 = dyn_cast_or_null<VarDecl>(declStmt->getSingleDecl())) {
377 if(
auto* declRef =
FindDeclRef(varDecl2->getAnyInitializer())) {
378 auto* varDecl = dyn_cast<ParmVarDecl>(declRef->getDecl());
380 auto* field =
AddField(mASTData,
GetName(*varDecl), varDecl->getType());
383 mVarNamePrefix.insert(std::make_pair(
const_cast<ParmVarDecl*
>(varDecl), me));
389 Visit(stmt->getBody());
391 Visit(stmt->getReturnStmt());
392 Visit(stmt->getReturnValue());
393 Visit(stmt->getReturnValueInit());
394 Visit(stmt->getExceptionHandler());
395 Visit(stmt->getReturnStmtOnAllocFailure());
396 Visit(stmt->getFallthroughHandler());
397 Visit(stmt->getInitSuspendStmt());
398 Visit(stmt->getFinalSuspendStmt());
403 auto* tmp = mPrevStmt;
406 for(
auto* child : stmt->children()) {
427 if(
const auto* args = fd.getTemplateSpecializationArgs()) {
430 for(
OnceFalse needsUnderscore{};
const auto& arg : args->asArray()) {
431 if(needsUnderscore) {
439 auto str = std::move(ofm.GetString());
446 if(fd.isOverloadedOperator()) {
469 auto resumeFnType =
Ptr(resumeFnFd->getType());
473 auto destroyFnType =
Ptr(destroyFnFd->getType());
481 auto* reicast =
ReinterpretCast(GetFramePointerType(), stmt->getAllocate());
483 coroFrameVar->setInit(reicast);
489 if(stmt->getReturnStmtOnAllocFailure()) {
490 auto* nptr =
new(ctx) CXXNullPtrLiteralExpr({});
501 mASTData, mSuspendsCounter,
const_cast<CoroutineBodyStmt*
>(stmt), llvm::DenseMap<VarDecl*, MemberExpr*>{}};
505 InsertArgWithNull(setSuspendIndexToZero);
508 auto* initializeInitialAwaitResume =
510 InsertArgWithNull(initializeInitialAwaitResume);
513 for(
auto* param : stmt->getParamMoves()) {
514 if(
const auto* declStmt = dyn_cast_or_null<DeclStmt>(param)) {
515 if(
const auto* varDecl = dyn_cast_or_null<VarDecl>(declStmt->getSingleDecl())) {
516 const auto varName =
GetName(*varDecl);
533 ArrayRef<ParmVarDecl*> funParams = fd.parameters();
534 SmallVector<ParmVarDecl*, 16> funParamStorage{};
535 QualType cxxMethodType{};
537 if(
const auto* cxxMethodDecl = dyn_cast_or_null<CXXMethodDecl>(&fd)) {
538 funParamStorage.reserve(funParams.size() + 1);
540 cxxMethodType = cxxMethodDecl->getFunctionObjectParameterType();
546 ranges::copy(funParams, std::back_inserter(funParamStorage));
548 funParams = funParamStorage;
551 auto getNonRefType = [&](
auto* var) -> QualType {
552 if(
const auto* et = var->getType().getNonReferenceType()->template getAs<ElaboratedType>()) {
553 return et->getNamedType();
555 return QualType(var->getType().getNonReferenceType().getTypePtrOrNull(), 0);
559 SmallVector<Expr*, 16> exprs{};
561 for(
auto* promiseTypeRecordDecl = mASTData.
mPromiseField->getType()->getAsCXXRecordDecl();
562 auto* ctor : promiseTypeRecordDecl->ctors()) {
564 if(not ranges::equal(
565 ctor->parameters(), funParams, [&](
auto& a,
auto& b) { return getNonRefType(a) == getNonRefType(b); })) {
573 if(not ctor->param_empty() and
574 (getNonRefType(ctor->getParamDecl(0)) == QualType(cxxMethodType.getTypePtrOrNull(), 0))) {
576 mASTData.
mThisExprs.push_back(CXXThisExpr::Create(ctx, {},
Ptr(cxxMethodType),
false));
579 (void)
static_cast<bool>(derefFirstParam);
582 for(
const auto& fparam : funParams) {
583 if(derefFirstParam) {
591 if(funParams.size()) {
608 auto* ctorArgs =
new(ctx) InitListExpr{ctx, {}, exprs, {}};
612 InsertArgWithNull(newFrame);
634 InsertArgWithNull(assignResumeFn);
637 InsertArgWithNull(assignDestroyFn);
641 R
"A(Call the made up function with the coroutine body for initial suspend.
642 This function will be called subsequently by coroutine_handle<>::resume()
643 which calls __builtin_coro_resume(__handle_))A"sv);
646 InsertArgWithNull(callCoroFSM);
673 if(
auto* dtor = mASTData.
mPromiseField->getType()->getAsCXXRecordDecl()->getDestructor()) {
674 deallocFuncBodyStmts.Add(
Comment(
"Deallocating the coroutine promise type"sv));
677 auto* deallocPromise =
AccessMember(promiseAccess, dtor,
false);
678 auto* dtorCall =
CallMemberFun(deallocPromise, dtor->getType());
679 deallocFuncBodyStmts.Add(dtorCall);
682 deallocFuncBodyStmts.Add(
683 Comment(
"promise_type is trivially destructible, no dtor required."sv));
698 auto* dtorCall =
CallMemberFun(deallocPromise, GetFrameType());
699 deallocFuncBodyStmts.Add(dtorCall);
701 deallocFuncBodyStmts.Add(
Comment(
"Deallocating the coroutine frame"sv));
702 deallocFuncBodyStmts.Add(
703 Comment(
"Note: The actual argument to delete is __builtin_coro_frame with the promise as parameter"sv));
705 deallocFuncBodyStmts.Add(stmt->getDeallocate());
718 auto* initialSuspendCase =
Case(0,
Break());
722 switchBodyStmts.
Add(
Case(i + 1,
Goto(BuildResumeLabelName(i + 1))));
726 sstmt->setBody(switchBody);
729 Comment(
"Create a switch to get to the correct resume point"sv), sstmt, stmt->getInitSuspendStmt()};
732 mState = eState::InitialSuspend;
738 mInsertVarDecl =
false;
739 mSupressRecordDecls =
true;
741 for(
const auto* c : stmt->getBody()->children()) {
742 funcBodyStmts.Add(c);
745 if(
const auto* coReturnVoid = dyn_cast_or_null<CoreturnStmt>(stmt->getFallthroughHandler())) {
746 funcBodyStmts.Add(coReturnVoid);
750 funcBodyStmts.Add(gotoFinalSuspend);
752 auto* body = [&]() ->
Stmt* {
756 if(
const auto* exceptionHandler = stmt->getExceptionHandler()) {
763 return Try(tryBody,
Catch(catchBodyStmts));
775 mState = eState::FinalSuspend;
779 mInsertVarDecl =
true;
785 if(not mSupressRecordDecls) {
806 if(
const auto* callee = dyn_cast_or_null<DeclRefExpr>(stmt->getCallee()->IgnoreCasts())) {
811 }
else if(
GetPlainName(*callee) ==
"__builtin_coro_free"sv) {
815 }
else if(
GetPlainName(*callee) ==
"__builtin_coro_size"sv) {
825static std::optional<std::string>
826FindValue(llvm::DenseMap<
const Expr*, std::pair<const DeclRefExpr*, std::string>>& map,
const Expr* key)
828 if(
const auto& s = map.find(key); s != map.end()) {
829 return s->second.second;
838 const auto* sourceExpr = stmt->getSourceExpr();
840 if(
const auto& s =
FindValue(mOpaqueValues, sourceExpr)) {
856 if(
auto [thisDeref, v] = value; (thisDeref == dref) and (v == lookupName)) {
863 mOpaqueValues.insert(std::make_pair(sourceExpr, std::make_pair(dref, accessName)));
868 auto* promiseField = AddField(name, stmt->getType());
872 codeGenerator.InsertArg(assignPromiseSuspend);
873 ofm.AppendSemiNewLine();
883std::string CoroutinesCodeGenerator::BuildResumeLabelName(
int index)
const
893 if(isa<CoawaitExpr>(stmt)) {
894 return kwCoAwaitSpace;
896 return kwCoYieldSpace;
900 mPosBeforeSuspendExpr = mOutputFormatHelper.CurrentPos();
919 mSupressCasts =
true;
921 auto* il =
Int32(++mSuspendsCount);
922 auto* bop =
Assign(mASTData.mSuspendIndexAccess, mASTData.mSuspendIndexField, il);
927 const bool returnsVoid{stmt->getSuspendExpr()->getType()->isVoidType()};
932 StmtsContainer bodyStmts{};
933 Expr* initializeInitialAwaitResume =
nullptr;
935 const bool canThrow{[&] {
936 if(
const auto* e = dyn_cast_or_null<ExprWithCleanups>(stmt->getSuspendExpr())) {
937 if(
const auto* ce = dyn_cast_or_null<CallExpr>(e->getSubExpr())) {
938 if(
const FunctionDecl* fd = ce->getDirectCallee()) {
939 if(
const FunctionProtoType* fpt = fd->getType()->getAs<FunctionProtoType>()) {
940 return not fpt->isNothrow(
false);
949 auto addInitialAwaitSuspendCalled = [&] {
950 if(eState::InitialSuspend == mState) {
951 mState = eState::Body;
953 initializeInitialAwaitResume =
954 Assign(mASTData.mFrameAccessDeclRef, mASTData.mInitialAwaitResumeCalledField,
Bool(
true));
955 bodyStmts.Add(initializeInitialAwaitResume);
959 auto insertTryCatchIfNecessary = [&](StmtsContainer& cont) {
963 StmtsContainer catchBodyStmts{
964 Assign(mASTData.mSuspendIndexAccess, mASTData.mSuspendIndexField,
Int32(mSuspendsCount - 1)),
Throw()};
967 cont.Add(
Try(tryBody,
Catch(catchBodyStmts)));
973 bodyStmts.Add(stmt->getSuspendExpr());
975 insertTryCatchIfNecessary(bodyStmts);
977 addInitialAwaitSuspendCalled();
980 InsertArg(
If(
Not(stmt->getReadyExpr()), bodyStmts));
983 addInitialAwaitSuspendCalled();
986 auto* ifSuspend =
If(stmt->getSuspendExpr(), bodyStmts);
988 StmtsContainer innerBodyStmts{bop, ifSuspend};
989 insertTryCatchIfNecessary(innerBodyStmts);
991 InsertArg(
If(
Not(stmt->getReadyExpr()), innerBodyStmts));
994 if(not returnsVoid and initializeInitialAwaitResume) {
996 InsertArgWithNull(initializeInitialAwaitResume);
997 mOutputFormatHelper.AppendNewLine();
1000 auto* suspendLabel =
Label(BuildResumeLabelName(mSuspendsCount));
1001 InsertArg(suspendLabel);
1003 if(eState::FinalSuspend == mState) {
1004 auto* memExpr =
AccessMember(mASTData.mFrameAccessDeclRef, mASTData.mDestroyFnField,
true);
1005 auto* callCoroFSM =
Call(memExpr, {mASTData.mFrameAccessDeclRef});
1006 InsertArg(callCoroFSM);
1010 const auto* resumeExpr = stmt->getResumeExpr();
1012 if(not resumeExpr->getType()->isVoidType()) {
1013 const auto* sourceExpr = stmt->getOpaqueValue()->getSourceExpr();
1015 if(
const auto& s =
FindValue(mOpaqueValues, sourceExpr)) {
1019 AddField(fieldName, resumeExpr->getType());
1023 InsertArg(resumeExpr);
1027void CoroutinesCodeGenerator::InsertArg(
const CoreturnStmt* stmt)
1031 if(stmt->getPromiseCall()) {
1032 InsertArg(stmt->getPromiseCall());
1034 if(stmt->isImplicit()) {
1035 mOutputFormatHelper.AppendComment(
"implicit"sv);
1041void CoroutinesCodeGenerator::InsertArgWithNull(
const Stmt* stmt)
const ASTContext & GetGlobalAST()
Get access to the ASTContext.
constexpr std::string_view hlpDestroyFn
constexpr std::string_view kwCoReturnSpace
constexpr std::string_view kwInternalThis
constexpr std::string_view hlpResumeFn
constexpr std::string_view hlpAssing
#define RETURN_IF(cond)
! A helper inspired by https://github.com/Microsoft/wil/wiki/Error-handling-helpers
virtual void InsertArg(const Decl *stmt)
void InsertInstantiationPoint(const SourceManager &sm, const SourceLocation &instLoc, std::string_view text={})
Inserts the instantiation point of a template.
void InsertTemplateArg(const TemplateArgument &arg)
OutputFormatHelper & mOutputFormatHelper
A special container which creates either a CodeGenerator or a CfrontCodeGenerator depending on the co...
A special generator for coroutines. It is only activated, if -show-coroutines-transformation is given...
void InsertArg(const ImplicitCastExpr *stmt) override
void InsertCoroutine(const FunctionDecl &fd, const CoroutineBodyStmt *body)
~CoroutinesCodeGenerator() override
BinaryOperator * Equal(Expr *var, Expr *assignExpr)
UnaryOperator * AddrOf(const Expr *stmt)
CXXReinterpretCastExpr * ReinterpretCast(QualType toType, const Expr *toExpr, bool makePointer)
CXXRecordDecl * Struct(std::string_view name)
FieldDecl * mkFieldDecl(DeclContext *dc, std::string_view name, QualType type)
DeclRefExpr * mkDeclRefExpr(const ValueDecl *vd)
CXXNewExpr * New(ArrayRef< Expr * > placementArgs, const Expr *expr, QualType t)
std::vector< std::pair< std::string_view, QualType > > params_vector
CallExpr * Call(const FunctionDecl *fd, ArrayRef< Expr * > params)
UnaryOperator * Ref(const Expr *e)
VarDecl * Variable(std::string_view name, QualType type, DeclContext *dc)
MemberExpr * AccessMember(const Expr *expr, const ValueDecl *vd, bool isArrow)
DeclRefExpr * mkVarDeclRefExpr(std::string_view name, QualType type)
ReturnStmt * Return(Expr *stmt)
CaseStmt * Case(int value, Stmt *stmt)
Stmt * Comment(std::string_view comment)
IfStmt * If(const Expr *condition, ArrayRef< Stmt * > bodyStmts)
GotoStmt * Goto(std::string_view labelName)
CXXCatchStmt * Catch(ArrayRef< Stmt * > body)
CXXThrowExpr * Throw(const Expr *expr)
void ReplaceNode(Stmt *parent, Stmt *oldNode, Stmt *newNode)
UnaryExprOrTypeTraitExpr * Sizeof(QualType toType)
LabelStmt * Label(std::string_view name)
UnaryOperator * Not(const Expr *stmt)
CXXBoolLiteralExpr * Bool(bool b)
FunctionDecl * Function(std::string_view name, QualType returnType, const params_vector ¶meters)
CompoundStmt * mkCompoundStmt(ArrayRef< Stmt * > bodyStmts, SourceLocation beginLoc, SourceLocation endLoc)
SwitchStmt * Switch(Expr *stmt)
IntegerLiteral * Int32(uint64_t value)
ParmVarDecl * Parameter(const FunctionDecl *fd, std::string_view name, QualType type)
UnaryOperator * Dref(const Expr *stmt)
QualType Ptr(QualType srcType)
CXXStaticCastExpr * StaticCast(QualType toType, const Expr *toExpr, bool makePointer)
CXXMemberCallExpr * CallMemberFun(Expr *memExpr, QualType retType)
BinaryOperator * Assign(const VarDecl *var, Expr *assignExpr)
CXXTryStmt * Try(const Stmt *tryBody, CXXCatchStmt *catchAllBody)
const std::string SUSPEND_INDEX_NAME
const std::string RESUME_LABEL_PREFIX
bool Contains(const std::string_view source, const std::string_view search)
const DeclRefExpr * FindDeclRef(const Stmt *stmt)
Go deep in a Stmt if necessary and look to all childs for a DeclRefExpr.
std::string GetPlainName(const DeclRefExpr &DRE)
static auto * CreateCoroFunctionDecl(std::string funcName, QualType type)
static void SetFunctionBody(FunctionDecl *fd, StmtsContainer &bodyStmts)
constexpr std::string_view CORO_FRAME_NAME
std::string MakeLineColumnName(const SourceManager &sm, const SourceLocation &loc, const std::string_view &prefix)
std::string GetName(const NamedDecl &nd, const QualifiedName qualifiedName)
void ReplaceAll(std::string &str, std::string_view from, std::string_view to)
std::string BuildInternalVarName(const std::string_view &varName)
const std::string CORO_FRAME_ACCESS_THIS
std::string BuildTemplateParamObjectName(std::string name)
const std::string FINAL_SUSPEND_NAME
const std::string INITIAL_AWAIT_SUSPEND_CALLED_NAME
static std::string BuildSuspendVarName(const OpaqueValueExpr *stmt)
static std::optional< std::string > FindValue(llvm::DenseMap< const Expr *, std::pair< const DeclRefExpr *, std::string > > &map, const Expr *key)
void EnableGlobalInsert(GlobalInserts idx)
const std::string CORO_FRAME_ACCESS
static FieldDecl * AddField(CoroutineASTData &astData, std::string_view name, QualType type)
std::string StrCat(const auto &... args)
FieldDecl * mInitialAwaitResumeCalledField
std::vector< const CXXThisExpr * > mThisExprs
FieldDecl * mPromiseField
FieldDecl * mResumeFnField
FieldDecl * mSuspendIndexField
MemberExpr * mSuspendIndexAccess
DeclRefExpr * mFrameAccessDeclRef
CXXRecordDecl * mFrameType
FieldDecl * mDestroyFnField
MemberExpr * mInitialAwaitResumeCalledAccess
! A helper type to have a container for ArrayRef
void Add(const Stmt *stmt)
SmallVector< Stmt *, 64 > mStmts