[drdynvc] assert and proper cleanup

This commit is contained in:
Armin Novak 2022-12-06 19:44:30 +01:00 committed by David Fort
parent 558d5b5e8d
commit c3e42de5b5

View File

@ -31,6 +31,8 @@
#define TAG CHANNELS_TAG("drdynvc.client")
static void dvcman_channel_free(DVCMAN_CHANNEL* channel);
static UINT dvcman_channel_close(DVCMAN_CHANNEL* channel, BOOL perRequest, BOOL fromHashTableFn);
static void dvcman_free(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChannelMgr);
static UINT drdynvc_write_data(drdynvcPlugin* drdynvc, UINT32 ChannelId, const BYTE* data,
UINT32 dataSize, BOOL* close);
@ -50,6 +52,7 @@ static void dvcman_wtslistener_free(DVCMAN_LISTENER* listener)
*/
static UINT dvcman_get_configuration(IWTSListener* pListener, void** ppPropertyBag)
{
WINPR_ASSERT(ppPropertyBag);
WINPR_UNUSED(pListener);
*ppPropertyBag = NULL;
return ERROR_INTERNAL_ERROR;
@ -68,6 +71,7 @@ static UINT dvcman_create_listener(IWTSVirtualChannelManager* pChannelMgr,
DVCMAN* dvcman = (DVCMAN*)pChannelMgr;
DVCMAN_LISTENER* listener;
WINPR_ASSERT(dvcman);
WLog_DBG(TAG, "create_listener: %" PRIuz ".%s.", HashTable_Count(dvcman->listeners) + 1,
pszChannelName);
listener = (DVCMAN_LISTENER*)calloc(1, sizeof(DVCMAN_LISTENER));
@ -125,8 +129,10 @@ static UINT dvcman_destroy_listener(IWTSVirtualChannelManager* pChannelMgr, IWTS
static UINT dvcman_register_plugin(IDRDYNVC_ENTRY_POINTS* pEntryPoints, const char* name,
IWTSPlugin* pPlugin)
{
WINPR_ASSERT(pEntryPoints);
DVCMAN* dvcman = ((DVCMAN_ENTRY_POINTS*)pEntryPoints)->dvcman;
WINPR_ASSERT(dvcman);
if (!ArrayList_Append(dvcman->plugin_names, _strdup(name)))
return ERROR_INTERNAL_ERROR;
if (!ArrayList_Append(dvcman->plugins, pPlugin))
@ -140,6 +146,7 @@ static IWTSPlugin* dvcman_get_plugin(IDRDYNVC_ENTRY_POINTS* pEntryPoints, const
{
IWTSPlugin* plugin = NULL;
size_t i, nc, pc;
WINPR_ASSERT(pEntryPoints);
DVCMAN* dvcman = ((DVCMAN_ENTRY_POINTS*)pEntryPoints)->dvcman;
if (!dvcman || !pEntryPoints || !name)
return NULL;
@ -167,6 +174,7 @@ static IWTSPlugin* dvcman_get_plugin(IDRDYNVC_ENTRY_POINTS* pEntryPoints, const
static const ADDIN_ARGV* dvcman_get_plugin_data(IDRDYNVC_ENTRY_POINTS* pEntryPoints)
{
WINPR_ASSERT(pEntryPoints);
return ((DVCMAN_ENTRY_POINTS*)pEntryPoints)->args;
}
@ -188,12 +196,14 @@ static rdpSettings* dvcman_get_rdp_settings(IDRDYNVC_ENTRY_POINTS* pEntryPoints)
static UINT32 dvcman_get_channel_id(IWTSVirtualChannel* channel)
{
DVCMAN_CHANNEL* dvc = (DVCMAN_CHANNEL*)channel;
WINPR_ASSERT(dvc);
return dvc->channel_id;
}
static const char* dvcman_get_channel_name(IWTSVirtualChannel* channel)
{
DVCMAN_CHANNEL* dvc = (DVCMAN_CHANNEL*)channel;
WINPR_ASSERT(dvc);
return dvc->channel_name;
}
@ -203,6 +213,7 @@ static DVCMAN_CHANNEL* dvcman_get_channel_by_id(IWTSVirtualChannelManager* pChan
DVCMAN* dvcman = (DVCMAN*)pChannelMgr;
DVCMAN_CHANNEL* dvcChannel;
WINPR_ASSERT(dvcman);
HashTable_Lock(dvcman->channelsById);
dvcChannel = HashTable_GetItemValue(dvcman->channelsById, &ChannelId);
if (dvcChannel)
@ -229,6 +240,7 @@ static void dvcman_plugin_terminate(void* plugin)
{
IWTSPlugin* pPlugin = plugin;
WINPR_ASSERT(pPlugin);
UINT error = IFCALLRESULT(CHANNEL_RC_OK, pPlugin->Terminated, pPlugin);
if (error != CHANNEL_RC_OK)
WLog_ERR(TAG, "Terminated failed with error %" PRIu32 "!", error);
@ -242,19 +254,31 @@ static void wts_listener_free(void* arg)
static BOOL channelIdMatch(const void* k1, const void* k2)
{
WINPR_ASSERT(k1);
WINPR_ASSERT(k2);
return *((UINT32*)k1) == *((UINT32*)k2);
}
static UINT32 channelIdHash(const void* id)
{
WINPR_ASSERT(id);
return *((UINT32*)id);
}
static void channelByIdCleanerFn(void* value)
{
DVCMAN_CHANNEL* channel = (DVCMAN_CHANNEL*)value;
if (channel)
{
dvcman_channel_close(channel, FALSE, TRUE);
dvcman_channel_free(channel);
}
}
static IWTSVirtualChannelManager* dvcman_new(drdynvcPlugin* plugin)
{
wObject* obj;
DVCMAN* dvcman;
dvcman = (DVCMAN*)calloc(1, sizeof(DVCMAN));
DVCMAN* dvcman = (DVCMAN*)calloc(1, sizeof(DVCMAN));
if (!dvcman)
return NULL;
@ -272,8 +296,13 @@ static IWTSVirtualChannelManager* dvcman_new(drdynvcPlugin* plugin)
HashTable_SetHashFunction(dvcman->channelsById, channelIdHash);
obj = HashTable_KeyObject(dvcman->channelsById);
WINPR_ASSERT(obj);
obj->fnObjectEquals = channelIdMatch;
obj = HashTable_ValueObject(dvcman->channelsById);
WINPR_ASSERT(obj);
obj->fnObjectFree = channelByIdCleanerFn;
dvcman->pool = StreamPool_New(TRUE, 10);
if (!dvcman->pool)
goto fail;
@ -345,7 +374,8 @@ static UINT dvcman_load_addin(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager*
static void dvcman_channel_free(DVCMAN_CHANNEL* channel)
{
WINPR_ASSERT(channel);
if (!channel)
return;
if (channel->dvc_data)
Stream_Release(channel->dvc_data);
@ -359,17 +389,17 @@ static void dvcman_channel_unref(DVCMAN_CHANNEL* channel)
{
DVCMAN* dvcman;
WINPR_ASSERT(channel);
if (InterlockedDecrement(&channel->refCounter))
return;
dvcman = channel->dvcman;
HashTable_Remove(dvcman->channelsById, &channel->channel_id);
dvcman_channel_free(channel);
}
static UINT dvcchannel_send_close(DVCMAN_CHANNEL* channel)
{
WINPR_ASSERT(channel);
DVCMAN* dvcman = channel->dvcman;
drdynvcPlugin* drdynvc = dvcman->drdynvc;
wStream* s = StreamPool_Take(dvcman->pool, 5);
@ -385,12 +415,13 @@ static UINT dvcchannel_send_close(DVCMAN_CHANNEL* channel)
return drdynvc_send(drdynvc, s);
}
static UINT dvcman_channel_close(DVCMAN_CHANNEL* channel, BOOL perRequest)
static UINT dvcman_channel_close(DVCMAN_CHANNEL* channel, BOOL perRequest, BOOL fromHashTableFn)
{
UINT error = CHANNEL_RC_OK;
drdynvcPlugin* drdynvc;
DrdynvcClientContext* context;
WINPR_ASSERT(channel);
switch (channel->state)
{
case DVC_CHANNEL_INIT:
@ -429,6 +460,7 @@ static UINT dvcman_channel_close(DVCMAN_CHANNEL* channel, BOOL perRequest)
}
}
if (!fromHashTableFn)
dvcman_channel_unref(channel);
break;
case DVC_CHANNEL_CLOSED:
@ -444,6 +476,8 @@ static DVCMAN_CHANNEL* dvcman_channel_new(drdynvcPlugin* drdynvc,
{
DVCMAN_CHANNEL* channel;
WINPR_ASSERT(drdynvc);
WINPR_ASSERT(pChannelMgr);
channel = (DVCMAN_CHANNEL*)calloc(1, sizeof(DVCMAN_CHANNEL));
if (!channel)
@ -471,6 +505,7 @@ static void dvcman_clear(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pCha
{
DVCMAN* dvcman = (DVCMAN*)pChannelMgr;
WINPR_ASSERT(dvcman);
WINPR_UNUSED(drdynvc);
HashTable_Clear(dvcman->channelsById);
@ -482,6 +517,7 @@ static void dvcman_free(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChan
{
DVCMAN* dvcman = (DVCMAN*)pChannelMgr;
WINPR_ASSERT(dvcman);
WINPR_UNUSED(drdynvc);
ArrayList_Free(dvcman->plugins);
@ -504,6 +540,7 @@ static UINT dvcman_init(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChan
DVCMAN* dvcman = (DVCMAN*)pChannelMgr;
UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(dvcman);
ArrayList_Lock(dvcman->plugins);
for (i = 0; i < ArrayList_Count(dvcman->plugins); i++)
{
@ -545,7 +582,7 @@ static UINT dvcman_write_channel(IWTSVirtualChannel* pChannel, ULONG cbSize, con
LeaveCriticalSection(&(channel->lock));
/* Close delayed, it removes the channel struct */
if (close)
dvcman_channel_close(channel, FALSE);
dvcman_channel_close(channel, FALSE, FALSE);
return status;
}
@ -563,7 +600,7 @@ static UINT dvcman_close_channel_iface(IWTSVirtualChannel* pChannel)
return CHANNEL_RC_BAD_CHANNEL;
WLog_DBG(TAG, "close_channel_iface: id=%" PRIu32 "", channel->channel_id);
return dvcman_channel_close(channel, FALSE);
return dvcman_channel_close(channel, FALSE, FALSE);
}
/**
@ -582,6 +619,7 @@ static DVCMAN_CHANNEL* dvcman_create_channel(drdynvcPlugin* drdynvc,
DVCMAN_LISTENER* listener;
IWTSVirtualChannelCallback* pCallback = NULL;
WINPR_ASSERT(dvcman);
WINPR_ASSERT(res);
HashTable_Lock(dvcman->listeners);
@ -687,6 +725,8 @@ static UINT dvcman_open_channel(drdynvcPlugin* drdynvc, DVCMAN_CHANNEL* channel)
IWTSVirtualChannelCallback* pCallback;
UINT error = CHANNEL_RC_OK;
WINPR_ASSERT(drdynvc);
WINPR_ASSERT(channel);
if (channel->state == DVC_CHANNEL_RUNNING)
{
pCallback = channel->channel_callback;
@ -717,6 +757,8 @@ out:
*/
static UINT dvcman_receive_channel_data_first(DVCMAN_CHANNEL* channel, UINT32 length)
{
WINPR_ASSERT(channel);
WINPR_ASSERT(channel->dvcman);
if (channel->dvc_data)
Stream_Release(channel->dvc_data);
@ -744,6 +786,8 @@ static UINT dvcman_receive_channel_data(DVCMAN_CHANNEL* channel, wStream* data,
UINT status = CHANNEL_RC_OK;
size_t dataSize = Stream_GetRemainingLength(data);
WINPR_ASSERT(channel);
WINPR_ASSERT(channel->dvcman);
if (channel->dvc_data)
{
drdynvcPlugin* drdynvc = channel->dvcman->drdynvc;
@ -864,6 +908,7 @@ static UINT drdynvc_write_data(drdynvcPlugin* drdynvc, UINT32 ChannelId, const B
return CHANNEL_RC_BAD_CHANNEL_HANDLE;
dvcman = (DVCMAN*)drdynvc->channel_mgr;
WINPR_ASSERT(dvcman);
WLog_Print(drdynvc->log, WLOG_TRACE, "write_data: ChannelId=%" PRIu32 " size=%" PRIu32 "",
ChannelId, dataSize);
@ -960,6 +1005,8 @@ static UINT drdynvc_send_capability_response(drdynvcPlugin* drdynvc)
return CHANNEL_RC_BAD_CHANNEL_HANDLE;
dvcman = (DVCMAN*)drdynvc->channel_mgr;
WINPR_ASSERT(dvcman);
WLog_Print(drdynvc->log, WLOG_TRACE, "capability_response");
s = StreamPool_Take(dvcman->pool, 4);
@ -1081,6 +1128,8 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c
return CHANNEL_RC_BAD_CHANNEL_HANDLE;
dvcman = (DVCMAN*)drdynvc->channel_mgr;
WINPR_ASSERT(dvcman);
if (drdynvc->state == DRDYNVC_STATE_CAPABILITIES)
{
/**
@ -1182,6 +1231,7 @@ static UINT drdynvc_process_data_first(drdynvcPlugin* drdynvc, int Sp, int cbChI
UINT32 ChannelId;
DVCMAN_CHANNEL* channel;
WINPR_ASSERT(drdynvc);
if (!Stream_CheckAndLogRequiredLength(
TAG, s, drdynvc_cblen_to_bytes(cbChId) + drdynvc_cblen_to_bytes(Sp)))
return ERROR_INVALID_DATA;
@ -1213,7 +1263,7 @@ static UINT drdynvc_process_data_first(drdynvcPlugin* drdynvc, int Sp, int cbChI
status = dvcman_receive_channel_data(channel, s, ThreadingFlags);
if (status != CHANNEL_RC_OK)
status = dvcman_channel_close(channel, FALSE);
status = dvcman_channel_close(channel, FALSE, FALSE);
out:
dvcman_channel_unref(channel);
@ -1232,6 +1282,7 @@ static UINT drdynvc_process_data(drdynvcPlugin* drdynvc, int Sp, int cbChId, wSt
DVCMAN_CHANNEL* channel;
UINT status = CHANNEL_RC_OK;
WINPR_ASSERT(drdynvc);
if (!Stream_CheckAndLogRequiredLength(TAG, s, drdynvc_cblen_to_bytes(cbChId)))
return ERROR_INVALID_DATA;
@ -1256,7 +1307,7 @@ static UINT drdynvc_process_data(drdynvcPlugin* drdynvc, int Sp, int cbChId, wSt
status = dvcman_receive_channel_data(channel, s, ThreadingFlags);
if (status != CHANNEL_RC_OK)
status = dvcman_channel_close(channel, FALSE);
status = dvcman_channel_close(channel, FALSE, FALSE);
out:
dvcman_channel_unref(channel);
@ -1273,6 +1324,7 @@ static UINT drdynvc_process_close_request(drdynvcPlugin* drdynvc, int Sp, int cb
UINT32 ChannelId;
DVCMAN_CHANNEL* channel;
WINPR_ASSERT(drdynvc);
if (!Stream_CheckAndLogRequiredLength(TAG, s, drdynvc_cblen_to_bytes(cbChId)))
return ERROR_INVALID_DATA;
@ -1289,7 +1341,7 @@ static UINT drdynvc_process_close_request(drdynvcPlugin* drdynvc, int Sp, int cb
return CHANNEL_RC_OK;
}
dvcman_channel_close(channel, TRUE);
dvcman_channel_close(channel, TRUE, FALSE);
dvcman_channel_unref(channel);
return CHANNEL_RC_OK;
}
@ -1306,6 +1358,7 @@ static UINT drdynvc_order_recv(drdynvcPlugin* drdynvc, wStream* s, UINT32 Thread
int Sp;
int cbChId;
WINPR_ASSERT(drdynvc);
if (!Stream_CheckAndLogRequiredLength(TAG, s, 1))
return ERROR_INVALID_DATA;
@ -1349,6 +1402,7 @@ static UINT drdynvc_virtual_channel_event_data_received(drdynvcPlugin* drdynvc,
{
wStream* data_in;
WINPR_ASSERT(drdynvc);
if ((dataFlags & CHANNEL_FLAG_SUSPEND) || (dataFlags & CHANNEL_FLAG_RESUME))
{
return CHANNEL_RC_OK;
@ -1425,6 +1479,7 @@ static void VCAPITYPE drdynvc_virtual_channel_open_event_ex(LPVOID lpUserParam,
UINT error = CHANNEL_RC_OK;
drdynvcPlugin* drdynvc = (drdynvcPlugin*)lpUserParam;
WINPR_ASSERT(drdynvc);
switch (event)
{
case CHANNEL_EVENT_DATA_RECEIVED:
@ -1459,14 +1514,6 @@ static void VCAPITYPE drdynvc_virtual_channel_open_event_ex(LPVOID lpUserParam,
"drdynvc_virtual_channel_open_event reported an error");
}
static BOOL channelByIdCleanerFn(const void* key, void* value, void* arg)
{
DVCMAN_CHANNEL* channel = (DVCMAN_CHANNEL*)value;
dvcman_channel_close(channel, FALSE);
return TRUE;
}
static DWORD WINAPI drdynvc_virtual_channel_client_thread(LPVOID arg)
{
/* TODO: rewrite this */
@ -1521,7 +1568,7 @@ static DWORD WINAPI drdynvc_virtual_channel_client_thread(LPVOID arg)
* event handlers. */
DVCMAN* drdynvcMgr = (DVCMAN*)drdynvc->channel_mgr;
HashTable_Foreach(drdynvcMgr->channelsById, channelByIdCleanerFn, drdynvc);
HashTable_Clear(drdynvcMgr->channelsById);
}
if (error && drdynvc->rdpcontext)
@ -1695,7 +1742,7 @@ static UINT drdynvc_virtual_channel_event_disconnected(drdynvcPlugin* drdynvc)
* event handlers. */
DVCMAN* drdynvcMgr = (DVCMAN*)drdynvc->channel_mgr;
HashTable_Foreach(drdynvcMgr->channelsById, channelByIdCleanerFn, drdynvc);
HashTable_Clear(drdynvcMgr->channelsById);
}
}
@ -1886,7 +1933,9 @@ static VOID VCAPITYPE drdynvc_virtual_channel_init_event_ex(LPVOID lpUserParam,
static int drdynvc_get_version(DrdynvcClientContext* context)
{
WINPR_ASSERT(context);
drdynvcPlugin* drdynvc = (drdynvcPlugin*)context->handle;
WINPR_ASSERT(drdynvc);
return drdynvc->version;
}
@ -1901,6 +1950,7 @@ BOOL VCAPITYPE VirtualChannelEntryEx(PCHANNEL_ENTRY_POINTS_EX pEntryPoints, PVOI
CHANNEL_ENTRY_POINTS_FREERDP_EX* pEntryPointsEx;
drdynvc = (drdynvcPlugin*)calloc(1, sizeof(drdynvcPlugin));
WINPR_ASSERT(pEntryPoints);
if (!drdynvc)
{
WLog_ERR(TAG, "calloc failed!");