#include "captions-mssapi.hpp"

/* Logging helpers: every message is routed through blog() with a
 * "[Captions] " prefix. `format` must be a string literal because it is
 * concatenated with the prefix at compile time. error()/debug() are
 * shorthands for LOG_ERROR / LOG_DEBUG levels. */
#define do_log(type, format, ...) \
	blog(type, "[Captions] " format, ##__VA_ARGS__)

#define error(format, ...) do_log(LOG_ERROR, format, ##__VA_ARGS__)
#define debug(format, ...) do_log(LOG_DEBUG, format, ##__VA_ARGS__)

/* Builds a SAPI in-process recognizer for the locale name `lang`
 * (converted to an LCID and used as a required "language=<hex>" token
 * attribute), connects it to a CaptionStream as its audio input, loads the
 * dictation grammar and finally launches main_thread().
 *
 * Any failure is logged with the failing step and reported to the caller by
 * throwing CAPTIONS_ERROR_GENERIC_FAIL. */
mssapi_captions::mssapi_captions(captions_cb callback, const std::string &lang)
try : captions_handler(callback, AUDIO_FORMAT_16BIT, 16000) {
	HRESULT hr;

	/* Per-character widening, identical to casting each char to wchar_t;
	 * adequate for the ASCII locale names expected here. */
	std::wstring wlang(lang.begin(), lang.end());

	LCID lang_id = LocaleNameToLCID(wlang.c_str(), 0);
	/* 0 means the locale name was not recognised; report it here instead
	 * of letting a bogus "language=0" filter fail inside SpFindBestToken. */
	if (lang_id == 0)
		throw "Failed to convert locale name to LCID";

	/* Zero-filled so the buffer is always terminated: _snwprintf does not
	 * write a terminator when the output is truncated to the count. */
	wchar_t lang_str[32] = {};
	_snwprintf(lang_str, 31, L"language=%x", (int)lang_id);

	/* Signalled by the destructor to make main_thread() exit. */
	stop = CreateEvent(nullptr, false, false, nullptr);
	if (!stop.Valid())
		throw "Failed to create event";

	hr = SpFindBestToken(SPCAT_RECOGNIZERS, lang_str, nullptr, &token);
	if (FAILED(hr))
		throw HRError("SpFindBestToken failed", hr);

	hr = CoCreateInstance(CLSID_SpInprocRecognizer, nullptr, CLSCTX_ALL,
			      __uuidof(ISpRecognizer), (void **)&recognizer);
	if (FAILED(hr))
		throw HRError("CoCreateInstance for recognizer failed", hr);

	hr = recognizer->SetRecognizer(token);
	if (FAILED(hr))
		throw HRError("SetRecognizer failed", hr);

	/* Kept inactive until main_thread() switches it to SPRST_ACTIVE. */
	hr = recognizer->SetRecoState(SPRST_INACTIVE);
	if (FAILED(hr))
		throw HRError("SetRecoState(SPRST_INACTIVE) failed", hr);

	hr = recognizer->CreateRecoContext(&context);
	if (FAILED(hr))
		throw HRError("CreateRecoContext failed", hr);

	/* Only recognitions and end-of-stream are consumed by main_thread(). */
	ULONGLONG interest = SPFEI(SPEI_RECOGNITION) |
			     SPFEI(SPEI_END_SR_STREAM);
	hr = context->SetInterest(interest, interest);
	if (FAILED(hr))
		throw HRError("SetInterest failed", hr);

	hr = context->SetNotifyWin32Event();
	if (FAILED(hr))
		throw HRError("SetNotifyWin32Event", hr);

	notify = context->GetNotifyEventHandle();
	if (notify == INVALID_HANDLE_VALUE)
		throw HRError("GetNotifyEventHandle failed", E_NOINTERFACE);

	size_t sample_rate = audio_output_get_sample_rate(obs_get_audio());
	audio = new CaptionStream((DWORD)sample_rate, this);
	/* NOTE(review): presumably `audio` is a COM smart pointer whose
	 * assignment takes its own reference, so this Release() drops the
	 * initial one from `new` — confirm against the member's type. */
	audio->Release();

	hr = recognizer->SetInput(audio, false);
	if (FAILED(hr))
		throw HRError("SetInput failed", hr);

	hr = context->CreateGrammar(1, &grammar);
	if (FAILED(hr))
		throw HRError("CreateGrammar failed", hr);

	hr = grammar->LoadDictation(nullptr, SPLO_STATIC);
	if (FAILED(hr))
		throw HRError("LoadDictation failed", hr);

	/* std::thread reports creation failure by throwing; map it onto this
	 * constructor's const char * error path. */
	try {
		t = std::thread([this]() { main_thread(); });
	} catch (...) {
		throw "Failed to create thread";
	}

} catch (const char *err) {
	blog(LOG_WARNING, "%s: %s", __FUNCTION__, err);
	throw CAPTIONS_ERROR_GENERIC_FAIL;

} catch (const HRError &err) {
	blog(LOG_WARNING, "%s: %s (%lX)", __FUNCTION__, err.str, err.hr);
	throw CAPTIONS_ERROR_GENERIC_FAIL;
}

/* Shuts down the recognition thread, if one was started: signal the stop
 * event it waits on, then wait for it to finish. */
mssapi_captions::~mssapi_captions()
{
	if (!t.joinable())
		return;

	SetEvent(stop);
	t.join();
}

void mssapi_captions::main_thread()
try {
	HRESULT hr;

	os_set_thread_name(__FUNCTION__);

	hr = grammar->SetDictationState(SPRS_ACTIVE);
	if (FAILED(hr))
		throw HRError("SetDictationState failed", hr);

	hr = recognizer->SetRecoState(SPRST_ACTIVE);
	if (FAILED(hr))
		throw HRError("SetRecoState(SPRST_ACTIVE) failed", hr);

	HANDLE events[] = {notify, stop};

	started = true;

	for (;;) {
		DWORD ret = WaitForMultipleObjects(2, events, false, INFINITE);
		if (ret != WAIT_OBJECT_0)
			break;

		CSpEvent event;
		bool exit = false;

		while (event.GetFrom(context) == S_OK) {
			if (event.eEventId == SPEI_RECOGNITION) {
				ISpRecoResult *result = event.RecoResult();

				CoTaskMemPtr<wchar_t> text;
				hr = result->GetText((ULONG)-1, (ULONG)-1, true,
						     &text, nullptr);
				if (FAILED(hr))
					continue;

				char text_utf8[512];
				os_wcs_to_utf8(text, 0, text_utf8, 512);

				callback(text_utf8);

				blog(LOG_DEBUG, "\"%s\"", text_utf8);

			} else if (event.eEventId == SPEI_END_SR_STREAM) {
				exit = true;
				break;
			}
		}

		if (exit)
			break;
	}

	audio->Stop();

} catch (HRError err) {
	blog(LOG_WARNING, "%s failed: %s (%lX)", __FUNCTION__, err.str, err.hr);
}

/* Audio sink: once the recognition thread has flagged itself as started,
 * every block of `frames` frames is handed to the caption stream; before
 * that, incoming audio is dropped. */
void mssapi_captions::pcm_data(const void *data, size_t frames)
{
	if (!started)
		return;

	audio->PushAudio(data, frames);
}

captions_handler_info mssapi_info = {
	[]() -> std::string { return "Microsoft Speech-to-Text"; },
	[](captions_cb cb, const std::string &lang) -> captions_handler * {
		return new mssapi_captions(cb, lang);
	}};