diff --git a/CMakeLists.txt b/CMakeLists.txt index 059729c..5b80dad 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,21 +11,24 @@ include_directories(includes/log) aux_source_directory(${CMAKE_SOURCE_DIR}/src/coroutine SRC_LIST) -# set(ASM_FILES ${CMAKE_SOURCE_DIR}/src/swap.S) +set(ASM_FILES ${CMAKE_SOURCE_DIR}/src/coroutine/coctx_swap.S) set(EXECUTABLE_OUTPUT_PATH ${CMAKE_SOURCE_DIR}/bin) -set(LIBRARY_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/lib) +set(LIBRARY_OUTPUT_PATH ${CMAKE_SOURCE_DIR}/lib) +find_library(TINYRPC_LIB tinyrpc ${CMAKE_SOURCE_DIR}/lib) add_library(tinyrpc ${SRC_LIST} - # ${ASM_FILES} - + ${ASM_FILES} ) -aux_source_directory(${CMAKE_SOURCE_DIR}/test/logtest TEST_SRC_LIST) + +aux_source_directory(${CMAKE_SOURCE_DIR}/test/coroutine TEST_SRC_LIST) + + add_executable(test_tinyrpc ${TEST_SRC_LIST} - ) +target_link_libraries(test_tinyrpc ${TINYRPC_LIB}) \ No newline at end of file diff --git a/includes/coroutine/coctx.h b/includes/coroutine/coctx.h index 15d0706..8c56708 100644 --- a/includes/coroutine/coctx.h +++ b/includes/coroutine/coctx.h @@ -1,14 +1,15 @@ #pragma once namespace tinyrpc { - enum class reg : unsigned int { // https://wiki.osdev.org/System_V_ABI - kRBP = 6, // 栈底指针 - kRDI = 7, // rdi,调用函数时的第一个参数 - kRSI = 8, // rsi, 调用函数时的第二个参数 这两个是 根据调用约定确定的 - kRETAddr = 9, // 下一个要执行的命令地址,它将被分配给 rip - kRSP = 13, // 堆栈顶部指针 - /* + namespace reg { + enum { // https://wiki.osdev.org/System_V_ABI + kRBP = 6, // 栈底指针 + kRDI = 7, // rdi,调用函数时的第一个参数 + kRSI = 8, // rsi, 调用函数时的第二个参数 这两个是 根据调用约定确定的 + kRETAddr = 9, // 下一个要执行的命令地址,它将被分配给 rip + kRSP = 13, // 堆栈顶部指针 + /* High memory ----------------- | 调用者的 rbp | <- 被调用函数栈帧的起点 @@ -19,7 +20,11 @@ namespace tinyrpc { ----------------- Low memory */ + }; + } + + struct coctx { // Coroutine Context void* regs[14]{}; // 初始化为 0 diff --git a/includes/coroutine/coroutine.hpp b/includes/coroutine/coroutine.hpp index e786f09..1825b4f 100644 --- a/includes/coroutine/coroutine.hpp +++ b/includes/coroutine/coroutine.hpp @@ -5,21 +5,31 @@ namespace tinyrpc { class Coroutine { + friend void coFunction(Coroutine* co); private: Coroutine(); public: - Coroutine(std::size_t stack_size, void* stack_sp); - Coroutine(std::size_t stack_size, void* stack_sp, std::function cb); - + Coroutine(std::size_t stack_size, char* stack_sp); + Coroutine(std::size_t stack_size, char* stack_sp, std::function cb); int getCorID() const {return m_cor_id;} + void operator()() const { + m_callback(); + } + + // coctx* getContext() {return &m_ctx;} + + static void yeild(); + void resume(); + + ~Coroutine(); private: coctx m_ctx {}; // 这个协程的上下文信息 int m_cor_id {0}; // 这个协程的 id char* m_stack_sp {nullptr}; // 这个协程的栈空间指针 std::size_t m_stack_size {0}; bool m_is_in_cofunc {true}; // 调用 CoFunction 时为真,CoFunction 完成时为假。 - std::function m_callback {}; + std::function m_callback {}; // 这个协程的回调 }; diff --git a/includes/log/logger.h b/includes/log/logger.h index 02d61f3..f3aaef4 100644 --- a/includes/log/logger.h +++ b/includes/log/logger.h @@ -8,11 +8,13 @@ struct logger { logger() = default; template - std::ostream& operator <<(T msg) { - return std::cout << __FILE__ << ":" << __LINE__ << " " << msg; + std::ostream& operator << (T&& msg) { + return std::cout << msg; } ~logger() { std::cout << std::endl; } -}; \ No newline at end of file +}; + +#define logger() (logger() << __FILE__ << ":" << __LINE__ << " ") \ No newline at end of file diff --git a/src/coroutine/coroutine.cc b/src/coroutine/coroutine.cc new file mode 100644 index 0000000..05ec6d2 --- /dev/null +++ b/src/coroutine/coroutine.cc @@ -0,0 +1,78 @@ +#include "coroutine.hpp" +#include "coctx.h" +#include "logger.h" +#include +#include + +namespace tinyrpc { + static thread_local Coroutine* t_main_coroutine = nullptr; // thread_local: 每个线程有一个主协程 + static thread_local Coroutine* t_curr_coroutine = nullptr; + static std::atomic_int t_coroutine_count {0}; + + void coFunction(Coroutine* co) { + if (co != nullptr) { + co->m_is_in_cofunc = true; + (*co)(); + co->m_is_in_cofunc = false; + } + Coroutine::yeild(); + } + + Coroutine::Coroutine() { // 构造主协程 + m_cor_id = t_coroutine_count++; + // t_main_coroutine = this; + t_main_coroutine = t_curr_coroutine = this; + + logger() << "main coroutine has built"; + } + + Coroutine::Coroutine(std::size_t stack_size, char* stack_sp, std::function cb) : + m_stack_sp(stack_sp), + m_stack_size(stack_size), + m_callback(cb) + + { // 构造协程 + m_cor_id = t_coroutine_count++; + + if (t_main_coroutine == nullptr) { + t_main_coroutine = new Coroutine(); + } + + char* top = stack_sp + stack_size; + top = reinterpret_cast((reinterpret_cast(top) & (~0xfull))); // 8字节对齐 + + m_ctx.regs[reg::kRBP] = top; + m_ctx.regs[reg::kRSP] = top; + m_ctx.regs[reg::kRDI] = this; + m_ctx.regs[reg::kRETAddr] = reinterpret_cast(&coFunction); + m_ctx.regs[reg::kRDI] = reinterpret_cast(this); + logger() << "user coroutine has built"; + } + + void Coroutine::yeild() { + + if (t_curr_coroutine == t_main_coroutine) { + logger() << "current coroutine is main coroutine !"; + return; + } + Coroutine* cur = t_curr_coroutine; + t_curr_coroutine = t_main_coroutine; + coctx_swap(&(cur->m_ctx), &(t_main_coroutine->m_ctx)); + + } + + void Coroutine::resume() { + if (t_curr_coroutine != t_main_coroutine) { + logger() << "swap error, current coroutine must be main coroutine !"; + return; + } + t_curr_coroutine = this; + coctx_swap(&(t_main_coroutine->m_ctx), &(this->m_ctx)); + + } + + Coroutine::~Coroutine() { + free(m_stack_sp); + } + +} \ No newline at end of file diff --git a/test/coroutine/main.cc b/test/coroutine/main.cc new file mode 100644 index 0000000..4ed679b --- /dev/null +++ b/test/coroutine/main.cc @@ -0,0 +1,44 @@ +#include "coroutine.hpp" +#include + +using namespace std; +using namespace tinyrpc; + +Coroutine* co1; +Coroutine* co2; + +void coro1() { + + cout << "this is coro1 begin" << endl; + + Coroutine::yeild(); + + cout << "this is coro1 end" << endl; + +} + +void coro2() { + + cout << "this is coro2 begin" << endl; + + Coroutine::yeild(); + + cout << "this is coro2 end" << endl; + +} + + +int main() { + int stk_size = 4 * 1024 * 1024; + char* stk1 = static_cast(malloc(stk_size)); + char* stk2 = static_cast(malloc(stk_size)); + co1 = new Coroutine(stk_size, stk1, coro1); + co2 = new Coroutine(stk_size, stk2, coro2); + co1->resume(); + co2->resume(); + co1->resume(); + co2->resume(); +} + + +