Mirror of https://github.com/facebook/zstd.git
Move XXH64_update() into worker threads
* Computes the XXH hash in the worker threads.
* Workers get a sequence number and wait until their number comes up. On error, a worker ensures that its sequence is marked finished, so future threads don't get blocked.
* Sets up for LDM integration, which will go in the same spot.
This commit is contained in:
parent f2562c02e4
commit 2253d01b27
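The heart of the change is a small "serial section" gate: each worker compresses its job in parallel, but the XXH64 update must happen in job order, so a worker waits on a condition variable until nextJobID reaches its own job number, does the serial work, then increments the counter and broadcasts. The sketch below illustrates that pattern with plain pthreads; it is a minimal stand-in, not the zstd code itself. ticket_t, ticket_update(), ticket_skip(), and the worker/main driver are illustrative names and scaffolding that mirror serialState_t, ZSTDMT_serialState_update(), and ZSTDMT_serialState_ensureFinished() from the diff below, with printf standing in for the XXH64_update() call.

#include <pthread.h>
#include <stdio.h>

/* Illustrative stand-in for serialState_t: a mutex/condvar pair plus the
 * sequence number of the next job allowed to run its serial step. */
typedef struct {
    pthread_mutex_t mutex;
    pthread_cond_t  cond;
    unsigned nextJobID;
} ticket_t;

/* Mirrors ZSTDMT_serialState_update(): wait for our turn, run the serial
 * work (XXH64_update() in the real code), then hand the turn to the next job.
 * A failed later job may already have skipped past us, hence the == check. */
static void ticket_update(ticket_t* t, unsigned jobID)
{
    pthread_mutex_lock(&t->mutex);
    while (t->nextJobID < jobID)
        pthread_cond_wait(&t->cond, &t->mutex);
    if (t->nextJobID == jobID)
        printf("serial step for job %u\n", jobID);
    t->nextJobID++;
    pthread_cond_broadcast(&t->cond);
    pthread_mutex_unlock(&t->mutex);
}

/* Mirrors ZSTDMT_serialState_ensureFinished(): on an error path, make sure
 * this job's turn is marked done so later jobs never block forever. */
static void ticket_skip(ticket_t* t, unsigned jobID)
{
    pthread_mutex_lock(&t->mutex);
    if (t->nextJobID <= jobID) {
        t->nextJobID = jobID + 1;
        pthread_cond_broadcast(&t->cond);
    }
    pthread_mutex_unlock(&t->mutex);
}

typedef struct { ticket_t* t; unsigned id; } workerArg_t;

static void* worker(void* p)
{
    workerArg_t* a = (workerArg_t*)p;
    /* ...parallel per-job work would happen here... */
    ticket_update(a->t, a->id);   /* serial step runs strictly in job order */
    /* an error path would call ticket_skip(a->t, a->id) instead */
    return NULL;
}

int main(void)
{
    enum { NB_JOBS = 4 };
    ticket_t t = { PTHREAD_MUTEX_INITIALIZER, PTHREAD_COND_INITIALIZER, 0 };
    pthread_t th[NB_JOBS];
    workerArg_t args[NB_JOBS];
    for (unsigned i = 0; i < NB_JOBS; i++) {
        args[i].t = &t;
        args[i].id = i;
        pthread_create(&th[i], NULL, worker, &args[i]);
    }
    for (unsigned i = 0; i < NB_JOBS; i++)
        pthread_join(th[i], NULL);
    return 0;   /* prints "serial step for job 0..3", always in order */
}

As in the diff, the gate broadcasts rather than signals: several later jobs may be waiting at once, and each one re-checks nextJobID itself after waking.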
@@ -304,16 +304,81 @@ static void ZSTDMT_releaseCCtx(ZSTDMT_CCtxPool* pool, ZSTD_CCtx* cctx)
    ZSTD_pthread_mutex_unlock(&pool->poolMutex);
}


/* ------------------------------------------ */
/* ===== Worker thread ===== */
/* ------------------------------------------ */
/* ==== Serial State ==== */

typedef struct {
    void const* start;
    size_t size;
} range_t;

typedef struct {
    ZSTD_pthread_mutex_t mutex;
    ZSTD_pthread_cond_t cond;
    ZSTD_CCtx_params params;
    XXH64_state_t xxhState;
    unsigned nextJobID;
} serialState_t;

static void ZSTDMT_serialState_reset(serialState_t* serialState, ZSTD_CCtx_params params)
{
    serialState->nextJobID = 0;
    if (params.fParams.checksumFlag)
        XXH64_reset(&serialState->xxhState, 0);
    serialState->params = params;
}

static int ZSTDMT_serialState_init(serialState_t* serialState)
{
    int initError = 0;
    initError |= ZSTD_pthread_mutex_init(&serialState->mutex, NULL);
    initError |= ZSTD_pthread_cond_init(&serialState->cond, NULL);
    return initError;
}

static void ZSTDMT_serialState_free(serialState_t* serialState)
{
    ZSTD_pthread_mutex_destroy(&serialState->mutex);
    ZSTD_pthread_cond_destroy(&serialState->cond);
}

static void ZSTDMT_serialState_update(serialState_t* serialState, range_t src, unsigned jobID)
{
    /* Wait for our turn */
    ZSTD_PTHREAD_MUTEX_LOCK(&serialState->mutex);
    while (serialState->nextJobID < jobID) {
        ZSTD_pthread_cond_wait(&serialState->cond, &serialState->mutex);
    }
    /* A future job may error and skip our job */
    if (serialState->nextJobID == jobID) {
        /* It is now our turn, do any processing necessary */
        if (serialState->params.fParams.checksumFlag && src.size > 0)
            XXH64_update(&serialState->xxhState, src.start, src.size);
    }
    /* Now it is the next jobs turn */
    serialState->nextJobID++;
    ZSTD_pthread_cond_broadcast(&serialState->cond);
    ZSTD_pthread_mutex_unlock(&serialState->mutex);
}

static void ZSTDMT_serialState_ensureFinished(serialState_t* serialState,
                                              unsigned jobID, size_t cSize)
{
    ZSTD_PTHREAD_MUTEX_LOCK(&serialState->mutex);
    if (serialState->nextJobID <= jobID) {
        assert(ZSTD_isError(cSize)); (void)cSize;
        DEBUGLOG(5, "Skipping past job %u because of error", jobID);
        serialState->nextJobID = jobID + 1;
        ZSTD_pthread_cond_broadcast(&serialState->cond);
    }
    ZSTD_pthread_mutex_unlock(&serialState->mutex);
}


/* ------------------------------------------ */
/* ===== Worker thread ===== */
/* ------------------------------------------ */

static const range_t kNullRange = { NULL, 0 };

typedef struct {
@@ -323,9 +388,11 @@ typedef struct {
    ZSTD_pthread_cond_t job_cond; /* Thread-safe - used by mtctx and worker */
    ZSTDMT_CCtxPool* cctxPool; /* Thread-safe - used by mtctx and (all) workers */
    ZSTDMT_bufferPool* bufPool; /* Thread-safe - used by mtctx and (all) workers */
    serialState_t* serial; /* Thread-safe - used by mtctx and (all) workers */
    buffer_t dstBuff; /* set by worker (or mtctx), then read by worker & mtctx, then modified by mtctx => no barrier */
    range_t prefix; /* set by mtctx, then read by worker & mtctx => no barrier */
    range_t src; /* set by mtctx, then read by worker & mtctx => no barrier */
    unsigned jobID; /* set by mtctx, then read by worker => no barrier */
    unsigned firstJob; /* set by mtctx, then read by worker => no barrier */
    unsigned lastJob; /* set by mtctx, then read by worker => no barrier */
    ZSTD_CCtx_params params; /* set by mtctx, then read by worker => no barrier */
@@ -339,9 +406,13 @@ typedef struct {
void ZSTDMT_compressionJob(void* jobDescription)
{
    ZSTDMT_jobDescription* const job = (ZSTDMT_jobDescription*)jobDescription;
    ZSTD_CCtx_params jobParams = job->params; /* do not modify job->params ! copy it, modify the copy */
    ZSTD_CCtx* const cctx = ZSTDMT_getCCtx(job->cctxPool);
    buffer_t dstBuff = job->dstBuff;

    /* Don't compute the checksum for chunks, but write it in the header */
    if (job->jobID != 0) jobParams.fParams.checksumFlag = 0;

    /* ressources */
    if (cctx==NULL) {
        job->cSize = ERROR(memory_allocation);
@@ -358,12 +429,11 @@ void ZSTDMT_compressionJob(void* jobDescription)

    /* init */
    if (job->cdict) {
-       size_t const initError = ZSTD_compressBegin_advanced_internal(cctx, NULL, 0, ZSTD_dm_auto, job->cdict, job->params, job->fullFrameSize);
+       size_t const initError = ZSTD_compressBegin_advanced_internal(cctx, NULL, 0, ZSTD_dm_auto, job->cdict, jobParams, job->fullFrameSize);
        assert(job->firstJob); /* only allowed for first job */
        if (ZSTD_isError(initError)) { job->cSize = initError; goto _endJob; }
    } else { /* srcStart points at reloaded section */
        U64 const pledgedSrcSize = job->firstJob ? job->fullFrameSize : job->src.size;
        ZSTD_CCtx_params jobParams = job->params; /* do not modify job->params ! copy it, modify the copy */
        { size_t const forceWindowError = ZSTD_CCtxParam_setParameter(&jobParams, ZSTD_p_forceMaxWindow, !job->firstJob);
          if (ZSTD_isError(forceWindowError)) {
              job->cSize = forceWindowError;
@@ -377,6 +447,10 @@ void ZSTDMT_compressionJob(void* jobDescription)
            job->cSize = initError;
            goto _endJob;
    } } }

    /* Perform serial step as early as possible */
    ZSTDMT_serialState_update(job->serial, job->src, job->jobID);

    if (!job->firstJob) { /* flush and overwrite frame header when it's not first job */
        size_t const hSize = ZSTD_compressContinue(cctx, dstBuff.start, dstBuff.capacity, job->src.start, 0);
        if (ZSTD_isError(hSize)) { job->cSize = hSize; /* save error code */ goto _endJob; }
@@ -425,6 +499,7 @@ void ZSTDMT_compressionJob(void* jobDescription)
    } }

_endJob:
    ZSTDMT_serialState_ensureFinished(job->serial, job->jobID, job->cSize);
    if (job->prefix.size > 0)
        DEBUGLOG(5, "Finished with prefix: %zx", (size_t)job->prefix.start);
    DEBUGLOG(5, "Finished with source: %zx", (size_t)job->src.start);
@@ -475,7 +550,7 @@ struct ZSTDMT_CCtx_s {
    roundBuff_t roundBuff;
    inBuff_t inBuff;
    int jobReady; /* 1 => one job is already prepared, but pool has shortage of workers. Don't create another one. */
-   XXH64_state_t xxhState;
+   serialState_t serial;
    unsigned singleBlockingThread;
    unsigned jobIDMask;
    unsigned doneJobID;
@@ -540,6 +615,7 @@ ZSTDMT_CCtx* ZSTDMT_createCCtx_advanced(unsigned nbWorkers, ZSTD_customMem cMem)
{
    ZSTDMT_CCtx* mtctx;
    U32 nbJobs = nbWorkers + 2;
    int initError;
    DEBUGLOG(3, "ZSTDMT_createCCtx_advanced (nbWorkers = %u)", nbWorkers);

    if (nbWorkers < 1) return NULL;
@@ -559,8 +635,9 @@ ZSTDMT_CCtx* ZSTDMT_createCCtx_advanced(unsigned nbWorkers, ZSTD_customMem cMem)
    mtctx->jobIDMask = nbJobs - 1;
    mtctx->bufPool = ZSTDMT_createBufferPool(nbWorkers, cMem);
    mtctx->cctxPool = ZSTDMT_createCCtxPool(nbWorkers, cMem);
    initError = ZSTDMT_serialState_init(&mtctx->serial);
    mtctx->roundBuff = kNullRoundBuff;
-   if (!mtctx->factory | !mtctx->jobs | !mtctx->bufPool | !mtctx->cctxPool) {
+   if (!mtctx->factory | !mtctx->jobs | !mtctx->bufPool | !mtctx->cctxPool | initError) {
        ZSTDMT_freeCCtx(mtctx);
        return NULL;
    }
@@ -615,6 +692,7 @@ size_t ZSTDMT_freeCCtx(ZSTDMT_CCtx* mtctx)
    ZSTDMT_freeJobsTable(mtctx->jobs, mtctx->jobIDMask+1, mtctx->cMem);
    ZSTDMT_freeBufferPool(mtctx->bufPool);
    ZSTDMT_freeCCtxPool(mtctx->cctxPool);
    ZSTDMT_serialState_free(&mtctx->serial);
    ZSTD_freeCDict(mtctx->cdictLocal);
    if (mtctx->roundBuff.buffer)
        ZSTD_free(mtctx->roundBuff.buffer, mtctx->cMem);
@@ -779,7 +857,6 @@ static size_t ZSTDMT_compress_advanced_internal(
    size_t remainingSrcSize = srcSize;
    unsigned const compressWithinDst = (dstCapacity >= ZSTD_compressBound(srcSize)) ? nbJobs : (unsigned)(dstCapacity / ZSTD_compressBound(avgJobSize)); /* presumes avgJobSize >= 256 KB, which should be the case */
    size_t frameStartPos = 0, dstBufferPos = 0;
    XXH64_state_t xxh64;
    assert(jobParams.nbWorkers == 0);
    assert(mtctx->cctxPool->totalCCtx == params.nbWorkers);

@@ -795,7 +872,7 @@ static size_t ZSTDMT_compress_advanced_internal(

    assert(avgJobSize >= 256 KB); /* condition for ZSTD_compressBound(A) + ZSTD_compressBound(B) <= ZSTD_compressBound(A+B), required to compress directly into Dst (no additional buffer) */
    ZSTDMT_setBufferSize(mtctx->bufPool, ZSTD_compressBound(avgJobSize) );
-   XXH64_reset(&xxh64, 0);
+   ZSTDMT_serialState_reset(&mtctx->serial, params);

    if (nbJobs > mtctx->jobIDMask+1) { /* enlarge job table */
        U32 jobsTableSize = nbJobs;
@@ -825,17 +902,14 @@ static size_t ZSTDMT_compress_advanced_internal(
            mtctx->jobs[u].fullFrameSize = srcSize;
            mtctx->jobs[u].params = jobParams;
            /* do not calculate checksum within sections, but write it in header for first section */
            if (u!=0) mtctx->jobs[u].params.fParams.checksumFlag = 0;
            mtctx->jobs[u].dstBuff = dstBuffer;
            mtctx->jobs[u].cctxPool = mtctx->cctxPool;
            mtctx->jobs[u].bufPool = mtctx->bufPool;
            mtctx->jobs[u].serial = &mtctx->serial;
            mtctx->jobs[u].jobID = u;
            mtctx->jobs[u].firstJob = (u==0);
            mtctx->jobs[u].lastJob = (u==nbJobs-1);

            if (params.fParams.checksumFlag) {
                XXH64_update(&xxh64, srcStart + frameStartPos, jobSize);
            }

            DEBUGLOG(5, "ZSTDMT_compress_advanced_internal: posting job %u (%u bytes)", u, (U32)jobSize);
            DEBUG_PRINTHEX(6, mtctx->jobs[u].prefix.start, 12);
            POOL_add(mtctx->factory, ZSTDMT_compressionJob, &mtctx->jobs[u]);
@@ -876,7 +950,7 @@ static size_t ZSTDMT_compress_advanced_internal(

    DEBUGLOG(4, "checksumFlag : %u ", params.fParams.checksumFlag);
    if (params.fParams.checksumFlag) {
-       U32 const checksum = (U32)XXH64_digest(&xxh64);
+       U32 const checksum = (U32)XXH64_digest(&mtctx->serial.xxhState);
        if (dstPos + 4 > dstCapacity) {
            error = ERROR(dstSize_tooSmall);
        } else {
@@ -1016,7 +1090,7 @@ size_t ZSTDMT_initCStream_internal(
    mtctx->allJobsCompleted = 0;
    mtctx->consumed = 0;
    mtctx->produced = 0;
-   if (params.fParams.checksumFlag) XXH64_reset(&mtctx->xxhState, 0);
+   ZSTDMT_serialState_reset(&mtctx->serial, params);
    return 0;
}

@@ -1113,21 +1187,18 @@ static size_t ZSTDMT_createCompressionJob(ZSTDMT_CCtx* mtctx, size_t srcSize, ZS
    mtctx->jobs[jobID].consumed = 0;
    mtctx->jobs[jobID].cSize = 0;
    mtctx->jobs[jobID].params = mtctx->params;
    /* do not calculate checksum within sections, but write it in header for first section */
    if (mtctx->nextJobID) mtctx->jobs[jobID].params.fParams.checksumFlag = 0;
    mtctx->jobs[jobID].cdict = mtctx->nextJobID==0 ? mtctx->cdict : NULL;
    mtctx->jobs[jobID].fullFrameSize = mtctx->frameContentSize;
    mtctx->jobs[jobID].dstBuff = g_nullBuffer;
    mtctx->jobs[jobID].cctxPool = mtctx->cctxPool;
    mtctx->jobs[jobID].bufPool = mtctx->bufPool;
    mtctx->jobs[jobID].serial = &mtctx->serial;
    mtctx->jobs[jobID].jobID = mtctx->nextJobID;
    mtctx->jobs[jobID].firstJob = (mtctx->nextJobID==0);
    mtctx->jobs[jobID].lastJob = endFrame;
    mtctx->jobs[jobID].frameChecksumNeeded = endFrame && (mtctx->nextJobID>0) && mtctx->params.fParams.checksumFlag;
    mtctx->jobs[jobID].dstFlushed = 0;

    if (mtctx->params.fParams.checksumFlag && srcSize > 0)
        XXH64_update(&mtctx->xxhState, src, srcSize);

    /* Update the round buffer pos and clear the input buffer to be reset */
    mtctx->roundBuff.pos += srcSize;
    mtctx->inBuff.buffer = g_nullBuffer;
@@ -1214,7 +1285,7 @@ static size_t ZSTDMT_flushProduced(ZSTDMT_CCtx* mtctx, ZSTD_outBuffer* output, u
        assert(srcConsumed <= srcSize);
        if ( (srcConsumed == srcSize) /* job completed -> worker no longer active */
          && mtctx->jobs[wJobID].frameChecksumNeeded ) {
-           U32 const checksum = (U32)XXH64_digest(&mtctx->xxhState);
+           U32 const checksum = (U32)XXH64_digest(&mtctx->serial.xxhState);
            DEBUGLOG(4, "ZSTDMT_flushProduced: writing checksum : %08X \n", checksum);
            MEM_writeLE32((char*)mtctx->jobs[wJobID].dstBuff.start + mtctx->jobs[wJobID].cSize, checksum);
            cSize += 4;