/******************************************************************************
    Copyright (C) 2015 by Hugh Bailey <obs.jim@gmail.com>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/

#include <windows.h>
#include <dbghelp.h>
#include <shellapi.h>
#include <tlhelp32.h>
#include <inttypes.h>

#include "obs-config.h"
#include "util/dstr.h"
#include "util/platform.h"
#include "util/windows/win-version.h"

/*
 * Function-pointer types for DbgHelp (and shell) exports that are resolved at
 * runtime with GetProcAddress. Each typedef must match the character width of
 * the export that is actually looked up, independent of the project's UNICODE
 * setting, so ANSI exports take char strings rather than TCHAR strings.
 */

/* NOTE(review): this is resolved from "EnumerateLoadedModulesW64", whose
 * callback receives wide-character module names; callers cast their
 * wide-string callbacks to the ANSI callback type declared here. */
typedef BOOL (WINAPI *ENUMERATELOADEDMODULES64)(HANDLE process,
		PENUMLOADED_MODULES_CALLBACK64 enum_loaded_modules_callback,
		PVOID user_context);
typedef DWORD (WINAPI *SYMSETOPTIONS)(DWORD sym_options);
/* "SymInitialize" is the ANSI export (the wide one is SymInitializeW). */
typedef BOOL (WINAPI *SYMINITIALIZE)(HANDLE process, PCSTR user_search_path,
		BOOL invade_process);
typedef BOOL (WINAPI *SYMCLEANUP)(HANDLE process);
typedef BOOL (WINAPI *STACKWALK64)(DWORD machine_type, HANDLE process,
		HANDLE thread, LPSTACKFRAME64 stack_frame,
		PVOID context_record,
		PREAD_PROCESS_MEMORY_ROUTINE64 read_memory_routine,
		PFUNCTION_TABLE_ACCESS_ROUTINE64 function_table_access_routine,
		PGET_MODULE_BASE_ROUTINE64 get_module_base_routine,
		PTRANSLATE_ADDRESS_ROUTINE64 translate_address);
typedef BOOL (WINAPI *SYMREFRESHMODULELIST)(HANDLE process);

typedef PVOID (WINAPI *SYMFUNCTIONTABLEACCESS64)(HANDLE process,
		DWORD64 addr_base);
typedef DWORD64 (WINAPI *SYMGETMODULEBASE64)(HANDLE process, DWORD64 addr);
/* Resolved from "SymFromAddrW", hence the wide SYMBOL_INFOW. */
typedef BOOL (WINAPI *SYMFROMADDR)(HANDLE process, DWORD64 address,
		PDWORD64 displacement, PSYMBOL_INFOW symbol);
typedef BOOL (WINAPI *SYMGETMODULEINFO64)(HANDLE process, DWORD64 addr,
		PIMAGEHLP_MODULE64 module_info);

typedef DWORD64 (WINAPI *SYMLOADMODULE64)(HANDLE process, HANDLE file,
		PSTR image_name, PSTR module_name, DWORD64 base_of_dll,
		DWORD size_of_dll);

typedef BOOL (WINAPI *MINIDUMPWRITEDUMP)(HANDLE process, DWORD process_id,
		HANDLE file, MINIDUMP_TYPE dump_type,
		PMINIDUMP_EXCEPTION_INFORMATION exception_param,
		PMINIDUMP_USER_STREAM_INFORMATION user_stream_param,
		PMINIDUMP_CALLBACK_INFORMATION callback_param);

/* ShellExecuteA is the ANSI variant: all string arguments are char-based. */
typedef HINSTANCE (WINAPI *SHELLEXECUTEA)(HWND hwnd, LPCSTR operation,
		LPCSTR file, LPCSTR parameters, LPCSTR directory,
		INT show_flags);

/* Register state plus StackWalk64 bookkeeping for walking one thread. */
struct stack_trace {
	CONTEXT context;         /* registers the walk starts from */
	DWORD64 instruction_ptr; /* RIP/EIP taken from context */
	STACKFRAME64 frame;      /* in/out state handed to StackWalk64 */
	DWORD image_type;        /* IMAGE_FILE_MACHINE_* for StackWalk64 */
};

/*
 * Everything used while building one crash report. exception_handler
 * zero-initializes it and releases it with exception_handler_data_free.
 */
struct exception_handler_data {
	/* DbgHelp entry points, filled in by get_dbghelp_imports() */
	SYMINITIALIZE sym_initialize;
	SYMCLEANUP sym_cleanup;
	SYMSETOPTIONS sym_set_options;
	SYMFUNCTIONTABLEACCESS64 sym_function_table_access64;
	SYMGETMODULEBASE64 sym_get_module_base64;
	SYMFROMADDR sym_from_addr;
	SYMGETMODULEINFO64 sym_get_module_info64;
	SYMREFRESHMODULELIST sym_refresh_module_list;
	STACKWALK64 stack_walk64;
	ENUMERATELOADEDMODULES64 enumerate_loaded_modules64;
	MINIDUMPWRITEDUMP minidump_write_dump;

	HMODULE dbghelp;                     /* loaded DbgHelp module */
	SYMBOL_INFOW *sym_info;              /* scratch buffer for SymFromAddrW */
	PEXCEPTION_POINTERS exception;       /* the exception being reported */
	struct win_version_info win_version; /* from get_win_ver() */
	SYSTEMTIME time_info;                /* from GetSystemTime() */
	HANDLE process;                      /* from GetCurrentProcess() */

	struct stack_trace main_trace;       /* walk state of the crashed thread */

	struct dstr str;                     /* the report text being built */
	struct dstr cpu_info;                /* CPU name for the header */
	struct dstr module_name;             /* module holding the fault address */
	struct dstr module_list;             /* "Loaded modules" table rows */
};

/*
 * Releases everything owned by the handler data: the symbol buffer, the
 * report strings and the DbgHelp module handle.
 *
 * exception_handler calls this even when handle_exception() bailed out early
 * (e.g. DbgHelp could not be loaded), so any member may still be zero here.
 */
static inline void exception_handler_data_free(
		struct exception_handler_data *data)
{
	/* LocalFree ignores NULL */
	LocalFree(data->sym_info);
	data->sym_info = NULL;

	dstr_free(&data->str);
	dstr_free(&data->cpu_info);
	dstr_free(&data->module_name);
	dstr_free(&data->module_list);

	/* FreeLibrary is not documented to accept NULL */
	if (data->dbghelp) {
		FreeLibrary(data->dbghelp);
		data->dbghelp = NULL;
	}
}

/* Looks up an exported symbol in `module`; yields NULL when it is absent. */
static inline void *get_proc(HMODULE module, const char *func)
{
	FARPROC address = GetProcAddress(module, func);

	return (void *)address;
}

/*
 * Resolves export `func_name` from data->dbghelp into data->member and makes
 * the enclosing bool-returning function return false when it is missing.
 */
#define GET_DBGHELP_IMPORT(member, func_name) \
	do { \
		data->member = get_proc(data->dbghelp, func_name); \
		if (data->member == NULL) \
			return false; \
	} while (false)

/*
 * Loads DbgHelp and resolves every entry point the crash handler uses.
 *
 * Returns false as soon as the library or any single export is unavailable.
 * data->dbghelp keeps whatever LoadLibraryW returned, so the caller's cleanup
 * releases it.
 */
static inline bool get_dbghelp_imports(struct exception_handler_data *data)
{
	HMODULE module = LoadLibraryW(L"DbgHelp");

	data->dbghelp = module;
	if (module == NULL)
		return false;

	/* symbol engine lifetime and configuration */
	GET_DBGHELP_IMPORT(sym_initialize, "SymInitialize");
	GET_DBGHELP_IMPORT(sym_cleanup, "SymCleanup");
	GET_DBGHELP_IMPORT(sym_set_options, "SymSetOptions");

	/* stack walking helpers and symbol lookup */
	GET_DBGHELP_IMPORT(sym_function_table_access64,
			"SymFunctionTableAccess64");
	GET_DBGHELP_IMPORT(sym_get_module_base64, "SymGetModuleBase64");
	GET_DBGHELP_IMPORT(sym_from_addr, "SymFromAddrW");
	GET_DBGHELP_IMPORT(sym_get_module_info64, "SymGetModuleInfo64");
	GET_DBGHELP_IMPORT(sym_refresh_module_list, "SymRefreshModuleList");
	GET_DBGHELP_IMPORT(stack_walk64, "StackWalk64");

	/* module enumeration and minidumps */
	GET_DBGHELP_IMPORT(enumerate_loaded_modules64,
			"EnumerateLoadedModulesW64");
	GET_DBGHELP_IMPORT(minidump_write_dump, "MiniDumpWriteDump");

	return true;
}

/*
 * Seeds trace->frame for StackWalk64 from the registers in trace->context,
 * records the instruction pointer, and selects the machine type for the
 * current build (AMD64 for _WIN64, i386 otherwise).
 */
static inline void init_instruction_data(struct stack_trace *trace)
{
	STACKFRAME64 *frame = &trace->frame;
	const CONTEXT *ctx = &trace->context;

#ifdef _WIN64
	trace->instruction_ptr = ctx->Rip;
	frame->AddrStack.Offset = ctx->Rsp;
	frame->AddrFrame.Offset = ctx->Rbp;
	trace->image_type = IMAGE_FILE_MACHINE_AMD64;
#else
	trace->instruction_ptr = ctx->Eip;
	frame->AddrStack.Offset = ctx->Esp;
	frame->AddrFrame.Offset = ctx->Ebp;
	trace->image_type = IMAGE_FILE_MACHINE_I386;
#endif
	frame->AddrPC.Offset = trace->instruction_ptr;

	/* every address above is a flat linear address */
	frame->AddrPC.Mode = AddrModeFlat;
	frame->AddrFrame.Mode = AddrModeFlat;
	frame->AddrStack.Mode = AddrModeFlat;
}

extern bool sym_initialize_called;

/*
 * Configures the DbgHelp symbol engine for this process and allocates the
 * SYMBOL_INFOW buffer that walk_stack() passes to SymFromAddrW.
 *
 * When sym_initialize_called says SymInitialize already ran elsewhere, only
 * the module list is refreshed instead of initializing a second time.
 * On allocation failure data->sym_info is left NULL.
 */
static inline void init_sym_info(struct exception_handler_data *data)
{
	/* SymFromAddrW interprets MaxNameLen as a count of WCHARs, so the
	 * trailing name buffer must be sized in WCHARs, not bytes. */
	enum { max_name_chars = 256 };

	data->sym_set_options(
			SYMOPT_UNDNAME |
			SYMOPT_FAIL_CRITICAL_ERRORS |
			SYMOPT_LOAD_ANYTHING);

	if (!sym_initialize_called)
		data->sym_initialize(data->process, NULL, true);
	else
		data->sym_refresh_module_list(data->process);

	data->sym_info = LocalAlloc(LPTR, sizeof(*data->sym_info) +
			max_name_chars * sizeof(WCHAR));
	if (!data->sym_info)
		return;

	data->sym_info->SizeOfStruct = sizeof(SYMBOL_INFOW);
	data->sym_info->MaxNameLen = max_name_chars;
}

/* Captures the running Windows version for the report header. */
static inline void init_version_info(struct exception_handler_data *data)
{
	struct win_version_info *version = &data->win_version;

	get_win_ver(version);
}

#define PROCESSOR_REG_KEY L"HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"
#define CPU_ERROR "<unable to query>"

/*
 * Reads processor 0's "ProcessorNameString" registry value into
 * data->cpu_info, or stores CPU_ERROR when it cannot be read.
 */
static inline void init_cpu_info(struct exception_handler_data *data)
{
	wchar_t str[1024];
	/* RegQueryValueExW takes a size in BYTES. Reserve one wchar_t so the
	 * value can always be terminated: registry strings are not guaranteed
	 * to be stored with a trailing NUL. */
	DWORD size = sizeof(str) - sizeof(str[0]);
	DWORD type = REG_NONE;
	HKEY key;
	LSTATUS status;

	status = RegOpenKeyW(HKEY_LOCAL_MACHINE, PROCESSOR_REG_KEY, &key);
	if (status != ERROR_SUCCESS) {
		dstr_copy(&data->cpu_info, CPU_ERROR);
		return;
	}

	status = RegQueryValueExW(key, L"ProcessorNameString", NULL,
			&type, (LPBYTE)str, &size);
	RegCloseKey(key);

	if (status != ERROR_SUCCESS || type != REG_SZ) {
		dstr_copy(&data->cpu_info, CPU_ERROR);
		return;
	}

	/* size <= sizeof(str) - sizeof(wchar_t), so this index is in bounds */
	str[size / sizeof(str[0])] = 0;
	dstr_from_wcs(&data->cpu_info, str);
}

/*
 * EnumerateLoadedModulesW64 callback. It does two things:
 * - Appends one "start-end path" row per module to data->module_list.
 * - Stores the lowercased path of the module containing the faulting
 *   instruction in data->module_name.
 *
 * Always returns true so enumeration continues over every module.
 *
 * The module name is declared PCWSTR rather than PCTSTR because the
 * enumerator is the wide "W64" export, regardless of the UNICODE setting.
 */
static BOOL CALLBACK enum_all_modules(PCWSTR module_name, DWORD64 module_base,
		ULONG module_size, struct exception_handler_data *data)
{
	/* zeroed in case the conversion writes nothing (e.g. NULL name) */
	char name_utf8[MAX_PATH] = {0};
	DWORD64 module_end = module_base + module_size;

	os_wcs_to_utf8(module_name, 0, name_utf8, MAX_PATH);

	if (data->main_trace.instruction_ptr >= module_base &&
	    data->main_trace.instruction_ptr <  module_end) {
		dstr_copy(&data->module_name, name_utf8);
		/* NOTE(review): guard in case dstr_copy leaves the array NULL
		 * for an empty name */
		if (data->module_name.array)
			strlwr(data->module_name.array);
	}

#ifdef _WIN64
	dstr_catf(&data->module_list, "%016"PRIX64"-%016"PRIX64" %s\r\n",
			module_base, module_end, name_utf8);
#else
	dstr_catf(&data->module_list, "%08"PRIX64"-%08"PRIX64" %s\r\n",
			module_base, module_end, name_utf8);
#endif
	return true;
}

/*
 * Enumerates every loaded module once with enum_all_modules, which fills
 * data->module_list and data->module_name.
 */
static inline void init_module_info(struct exception_handler_data *data)
{
	PENUMLOADED_MODULES_CALLBACK64 collect =
		(PENUMLOADED_MODULES_CALLBACK64)enum_all_modules;

	data->enumerate_loaded_modules64(data->process, collect, data);
}

/*
 * Appends the report header to data->str. It contains:
 * - the exception code and faulting address (with its module);
 * - the libobs and Windows versions;
 * - the CPU name.
 */
static inline void write_header(struct exception_handler_data *data)
{
	/* module_name stays unset when the fault address is not inside any
	 * enumerated module; never hand a NULL pointer to %s. */
	const char *module = data->module_name.array ?
		data->module_name.array : "<unknown>";
	const char *cpu = data->cpu_info.array ?
		data->cpu_info.array : "<unknown>";

	/* ExceptionCode is a DWORD (unsigned long), hence %lx */
	dstr_catf(&data->str, "Unhandled exception: %lx\r\n"
			"Fault address: %"PRIX64" (%s)\r\n"
			"libobs version: "OBS_VERSION"\r\n"
			"Windows version: %d.%d build %d (revision %d)\r\n"
			"CPU: %s\r\n\r\n",
			data->exception->ExceptionRecord->ExceptionCode,
			data->main_trace.instruction_ptr,
			module,
			data->win_version.major, data->win_version.minor,
			data->win_version.build, data->win_version.revis,
			cpu);
}

/*
 * Lookup request/result used with enum_module: `addr` is the address to
 * resolve, `name_utf8` receives the lowercased path of the containing module
 * (left as the caller initialized it when no module matches).
 */
struct module_info {
	DWORD64 addr;
	char    name_utf8[MAX_PATH];
};

/*
 * EnumerateLoadedModulesW64 callback that finds the module containing
 * info->addr. On a match it stores the module's lowercased UTF-8 path in
 * info->name_utf8 and returns false to stop enumeration; otherwise it
 * returns true to keep searching.
 *
 * module_name is PCWSTR because the enumerator is the wide "W64" export,
 * independent of the UNICODE setting.
 */
static BOOL CALLBACK enum_module(PCWSTR module_name, DWORD64 module_base,
		ULONG module_size, struct module_info *info)
{
	if (info->addr < module_base ||
	    info->addr >= module_base + module_size)
		return true;

	os_wcs_to_utf8(module_name, 0, info->name_utf8, MAX_PATH);
	strlwr(info->name_utf8);
	return false;
}

/* Resolves info->addr to its containing module via enum_module. */
static inline void get_module_name(struct exception_handler_data *data,
		struct module_info *info)
{
	PENUMLOADED_MODULES_CALLBACK64 find_owner =
		(PENUMLOADED_MODULES_CALLBACK64)enum_module;

	data->enumerate_loaded_modules64(data->process, find_owner, info);
}

/*
 * Advances *trace by one frame with StackWalk64 and appends one line for it
 * to data->str. The line contains:
 * - the stack address and program counter;
 * - the four parameter slots;
 * - "module!symbol+offset", or "module!address" when no symbol is resolved
 *   or the match came only from the export table.
 *
 * Returns false once StackWalk64 reports no further frames.
 */
static inline bool walk_stack(struct exception_handler_data *data,
		HANDLE thread, struct stack_trace *trace)
{
	struct module_info module_info = {0};
	DWORD64 func_offset = 0;
	DWORD64 stack_addr, pc_addr;
	DWORD64 params[4];
	char sym_name[256] = {0};
	char *p;
	bool success;

	success = !!data->stack_walk64(trace->image_type,
			data->process, thread, &trace->frame, &trace->context,
			NULL, data->sym_function_table_access64,
			data->sym_get_module_base64, NULL);
	if (!success)
		return false;

	module_info.addr = trace->frame.AddrPC.Offset;
	get_module_name(data, &module_info);

	/* print only the file-name part of the module path */
	if (module_info.name_utf8[0]) {
		p = strrchr(module_info.name_utf8, '\\');
		p = p ? (p + 1) : module_info.name_utf8;
	} else {
		strcpy(module_info.name_utf8, "<unknown>");
		p = module_info.name_utf8;
	}

	/* sym_info may be NULL if its allocation failed */
	success = data->sym_info != NULL &&
		!!data->sym_from_addr(data->process,
				trace->frame.AddrPC.Offset, &func_offset,
				data->sym_info);

	if (success)
		os_wcs_to_utf8(data->sym_info->Name, 0, sym_name,
				sizeof(sym_name));

	/* Work on copies: trace->frame is StackWalk64's in/out state for the
	 * next call and must not be altered just for display. */
	stack_addr = trace->frame.AddrStack.Offset;
	pc_addr = trace->frame.AddrPC.Offset;
	params[0] = trace->frame.Params[0];
	params[1] = trace->frame.Params[1];
	params[2] = trace->frame.Params[2];
	params[3] = trace->frame.Params[3];

#ifdef _WIN64
#define SUCCESS_FORMAT \
	"%016I64X %016I64X %016I64X %016I64X " \
	"%016I64X %016I64X %s!%s+0x%I64x\r\n"
#define FAIL_FORMAT \
	"%016I64X %016I64X %016I64X %016I64X " \
	"%016I64X %016I64X %s!0x%I64x\r\n"
#else
#define SUCCESS_FORMAT \
	"%08.8I64X %08.8I64X %08.8I64X %08.8I64X " \
	"%08.8I64X %08.8I64X %s!%s+0x%I64x\r\n"
#define FAIL_FORMAT \
	"%08.8I64X %08.8I64X %08.8I64X %08.8I64X " \
	"%08.8I64X %08.8I64X %s!0x%I64x\r\n"

	/* 32-bit build: print only the low 32 bits of each value */
	stack_addr &= 0xFFFFFFFF;
	pc_addr &= 0xFFFFFFFF;
	params[0] &= 0xFFFFFFFF;
	params[1] &= 0xFFFFFFFF;
	params[2] &= 0xFFFFFFFF;
	params[3] &= 0xFFFFFFFF;
#endif

	/* export-table-only matches are printed as a raw address */
	if (success && (data->sym_info->Flags & SYMFLAG_EXPORT) == 0) {
		dstr_catf(&data->str, SUCCESS_FORMAT,
				stack_addr, pc_addr,
				params[0], params[1], params[2], params[3],
				p, sym_name, func_offset);
	} else {
		dstr_catf(&data->str, FAIL_FORMAT,
				stack_addr, pc_addr,
				params[0], params[1], params[2], params[3],
				p, pc_addr);
	}

#undef SUCCESS_FORMAT
#undef FAIL_FORMAT

	return true;
}

#ifdef _WIN64
#define TRACE_TOP \
	"Stack            EIP              Arg0             " \
	"Arg1             Arg2             Arg3             Address\r\n"
#else
#define TRACE_TOP \
	"Stack    EIP      Arg0     " \
	"Arg1     Arg2     Arg3     Address\r\n"
#endif

/*
 * Upper bound on frames printed per thread. On a corrupted stack StackWalk64
 * can keep reporting frames indefinitely, and the crash handler must never
 * hang while building the report.
 */
#define MAX_STACK_FRAMES 1024

/*
 * Appends the stack trace of the thread described by *entry.
 *
 * write_thread_traces() calls this twice per thread. With first_thread true
 * only the crashed (current) thread is written; with false only the other
 * threads are written. Threads of other processes are skipped.
 */
static inline void write_thread_trace(struct exception_handler_data *data,
		THREADENTRY32 *entry, bool first_thread)
{
	bool crash_thread = entry->th32ThreadID == GetCurrentThreadId();
	struct stack_trace trace = {0};
	struct stack_trace *ptrace;
	bool have_context = true;
	int frames = 0;
	HANDLE thread;

	if (first_thread != crash_thread)
		return;

	if (entry->th32OwnerProcessID != GetCurrentProcessId())
		return;

	thread = OpenThread(THREAD_ALL_ACCESS, false, entry->th32ThreadID);
	if (!thread)
		return;

	if (crash_thread) {
		/* the faulting context was captured from the exception record */
		ptrace = &data->main_trace;
	} else {
		trace.context.ContextFlags = CONTEXT_ALL;
		have_context = !!GetThreadContext(thread, &trace.context);
		init_instruction_data(&trace);
		ptrace = &trace;
	}

	dstr_catf(&data->str, "\r\nThread %lX%s\r\n"TRACE_TOP,
			entry->th32ThreadID,
			crash_thread ? " (Crashed)" : "");

	/* without a valid register context there is nothing to walk */
	if (have_context) {
		while (frames < MAX_STACK_FRAMES &&
		       walk_stack(data, thread, ptrace))
			frames++;

		if (frames == MAX_STACK_FRAMES)
			dstr_catf(&data->str,
					"<stack walk stopped after %d frames>\r\n",
					MAX_STACK_FRAMES);
	}

	CloseHandle(thread);
}

/*
 * Appends stack traces for every thread of this process: the crashed thread
 * first, then all remaining threads, using one Toolhelp thread snapshot.
 */
static inline void write_thread_traces(struct exception_handler_data *data)
{
	THREADENTRY32 entry = {0};
	HANDLE snapshot;
	int pass;

	snapshot = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD,
			GetCurrentProcessId());
	if (snapshot == INVALID_HANDLE_VALUE)
		return;

	entry.dwSize = sizeof(entry);

	/* pass 0: only the crashed thread; pass 1: every other thread */
	for (pass = 0; pass < 2; pass++) {
		bool crashed_pass = (pass == 0);
		bool more = !!Thread32First(snapshot, &entry);

		for (; more; more = !!Thread32Next(snapshot, &entry))
			write_thread_trace(data, &entry, crashed_pass);
	}

	CloseHandle(snapshot);
}

/* Appends the "Loaded modules" section built by enum_all_modules. */
static inline void write_module_list(struct exception_handler_data *data)
{
	static const char heading[] =
		"\r\nLoaded modules:\r\n"
#ifdef _WIN64
		"Base Address                      Module\r\n";
#else
		"Base Address      Module\r\n";
#endif

	dstr_cat(&data->str, heading);
	dstr_cat_dstr(&data->str, &data->module_list);
}

/* ------------------------------------------------------------------------- */

/*
 * Builds the full crash report for `exception` into data->str: the header,
 * per-thread stack traces and the loaded-module list.
 *
 * Nothing is written when DbgHelp or any of its required exports cannot be
 * loaded.
 */
static inline void handle_exception(struct exception_handler_data *data,
		PEXCEPTION_POINTERS exception)
{
	if (get_dbghelp_imports(data)) {
		data->exception = exception;
		data->process = GetCurrentProcess();
		data->main_trace.context = *exception->ContextRecord;
		GetSystemTime(&data->time_info);

		/* collect everything the report needs */
		init_sym_info(data);
		init_version_info(data);
		init_cpu_info(data);
		init_instruction_data(&data->main_trace);
		init_module_info(data);

		/* then format it */
		write_header(data);
		write_thread_traces(data);
		write_module_list(data);
	}
}

/*
 * Unhandled-exception filter: builds a crash report and passes it to bcrash.
 * It steps aside (EXCEPTION_CONTINUE_SEARCH) in two cases:
 * - a debugger is attached;
 * - it is re-entered while already building a report.
 */
static LONG CALLBACK exception_handler(PEXCEPTION_POINTERS exception)
{
	struct exception_handler_data data = {0};
	static bool inside_handler = false;

	/* don't use if a debugger is present */
	if (IsDebuggerPresent())
		return EXCEPTION_CONTINUE_SEARCH;

	/* never re-enter while a report is already being generated */
	if (inside_handler)
		return EXCEPTION_CONTINUE_SEARCH;

	inside_handler = true;

	handle_exception(&data, exception);

	/* The report embeds module paths and symbol names, which may contain
	 * '%', so it must never be used as a format string. data.str is
	 * still empty when handle_exception could not load DbgHelp. */
	bcrash("%s", data.str.array ? data.str.array :
			"Unhandled exception (crash report unavailable: "
			"DbgHelp could not be loaded)");
	exception_handler_data_free(&data);

	inside_handler = false;

	return EXCEPTION_CONTINUE_SEARCH;
}

/*
 * Installs exception_handler as the process-wide unhandled exception filter.
 * Only the first call has an effect; later calls return immediately.
 */
void initialize_crash_handler(void)
{
	static bool initialized = false;

	if (initialized)
		return;

	SetUnhandledExceptionFilter(exception_handler);
	initialized = true;
}