44 #include "ngram_model_internal.h"
45 #include "lm_trie_quant.h"
47 #define FLOAT_INF (0x7f800000)
55 lm_trie_quant_type_t quant_type;
56 bins_t tables[NGRAM_MAX_ORDER - 1][2];
67 bins_create(
bins_t * bins, uint8 bits,
float *begin)
70 bins->end = bins->begin + (1ULL << bits);
74 lower_bound(
float *first,
const float *last,
float val)
96 bins_encode(
bins_t * bins,
float value)
98 float *above = lower_bound(bins->begin, bins->end, value);
99 if (above == bins->begin)
101 if (above == bins->end)
102 return bins->end - bins->begin - 1;
103 return above - bins->begin - (value - *(above - 1) < *above - value);
107 bins_decode(
bins_t * bins,
size_t off)
109 return bins->begin[off];
/**
 * Compute the number of bytes needed to hold all quantization tables.
 *
 * Each of the (order - 2) middle orders needs one table of 2^prob_bits
 * floats and one of 2^bo_bits floats; the last order needs a single table
 * of 2^prob_bits floats.
 *
 * @param order     N-gram order of the model.
 * @param prob_bits Number of bits used for a quantized probability.
 * @param bo_bits   Number of bits used for a quantized backoff.
 * @return Total size in bytes, or 0 when order < 2 (no tables are needed).
 */
static size_t
quant_apply_size(int order, int prob_bits, int bo_bits)
{
    size_t longest_table, middle_table;

    /* (order - 2) would go negative and wrap around once converted to
     * size_t, yielding a huge bogus allocation size. */
    if (order < 2)
        return 0;

    /* Shift in size_t so the table size is not limited by unsigned int. */
    longest_table = ((size_t) 1 << prob_bits) * sizeof(float);
    middle_table = ((size_t) 1 << bo_bits) * sizeof(float) + longest_table;

    return (size_t) (order - 2) * middle_table + longest_table;
}
122 quant_size(lm_trie_quant_type_t quant_type,
int order)
124 switch (quant_type) {
128 return quant_apply_size(order, 16, 16);
131 E_INFO(
"Unsupported quantatization type\n");
137 lm_trie_quant_create(lm_trie_quant_type_t quant_type,
int order)
143 quant->quant_type = quant_type;
144 quant->mem_size = quant_size(quant_type, order);
146 (uint8 *)
ckd_calloc(quant->mem_size,
sizeof(*quant->mem));
147 switch (quant_type) {
151 quant->prob_bits = 16;
153 quant->prob_mask = (1
U << quant->prob_bits) - 1;
154 quant->bo_mask = (1
U << quant->bo_bits) - 1;
157 E_INFO(
"Unsupported quantization type\n");
160 start = (
float *) (quant->mem);
161 for (i = 0; i < order - 2; i++) {
162 bins_create(&quant->tables[i][0], quant->prob_bits, start);
163 start += (1ULL << quant->prob_bits);
164 bins_create(&quant->tables[i][1], quant->bo_bits, start);
165 start += (1ULL << quant->bo_bits);
167 bins_create(&quant->tables[order - 2][0], quant->prob_bits, start);
168 quant->longest = &quant->tables[order - 2][0];
174 lm_trie_quant_read_bin(FILE * fp,
int order)
177 lm_trie_quant_type_t quant_type;
180 fread(&quant_type_int,
sizeof(quant_type_int), 1, fp);
181 quant_type = (lm_trie_quant_type_t) quant_type_int;
182 quant = lm_trie_quant_create(quant_type, order);
183 fread(quant->mem,
sizeof(*quant->mem), quant->mem_size, fp);
191 int quant_type_int = (int) quant->quant_type;
193 fwrite(&quant_type_int,
sizeof(quant_type_int), 1, fp);
194 fwrite(quant->mem,
sizeof(*quant->mem), quant->mem_size, fp);
208 switch (quant->quant_type) {
215 E_INFO(
"Unsupported quantatization type\n");
223 switch (quant->quant_type) {
230 E_INFO(
"Unsupported quantatization type\n");
238 return quant->quant_type > 0;
/**
 * qsort() comparator that orders float values ascending.
 *
 * The values are compared explicitly rather than by returning the
 * difference cast to int: that cast truncates every difference smaller
 * than 1.0 to 0, so close values compare as "equal" and stay unsorted,
 * and a very large difference does not fit in an int at all.
 *
 * @param a Pointer to the first float.
 * @param b Pointer to the second float.
 * @return Negative, zero or positive if *a is less than, equal to or
 *         greater than *b.
 */
static int
weights_comparator(const void *a, const void *b)
{
    const float lhs = *(const float *) a;
    const float rhs = *(const float *) b;

    return (lhs > rhs) - (lhs < rhs);
}
248 make_bins(
float *values, uint32 values_num,
float *centers, uint32 bins)
250 float *finish, *start;
253 qsort(values, values_num,
sizeof(*values), &weights_comparator);
255 for (i = 0; i < bins; i++, centers++, start = finish) {
256 finish = values + (size_t) ((uint64) values_num * (i + 1) / bins);
257 if (finish == start) {
259 *centers = i ? *(centers - 1) : -FLOAT_INF;
264 for (ptr = start; ptr != finish; ptr++) {
267 *centers = sum / (float) (finish - start);
283 probs = (
float *)
ckd_calloc(counts,
sizeof(*probs));
284 backoffs = (
float *)
ckd_calloc(counts,
sizeof(*backoffs));
285 raw_ngrams_end = raw_ngrams + counts;
287 for (backoff_num = 0, prob_num = 0; raw_ngrams != raw_ngrams_end;
289 float *weights = raw_ngrams->weights;
290 probs[prob_num++] = *weights;
292 backoffs[backoff_num++] = *weights;
295 make_bins(probs, prob_num, quant->tables[order - 2][0].begin,
296 1ULL << quant->prob_bits);
297 centers = quant->tables[order - 2][1].begin;
298 make_bins(backoffs, backoff_num, centers, (1ULL << quant->bo_bits));
304 lm_trie_quant_train_prob(
lm_trie_quant_t * quant,
int order, uint32 counts,
311 probs = (
float *)
ckd_calloc(counts,
sizeof(*probs));
312 raw_ngrams_end = raw_ngrams + counts;
314 for (prob_num = 0; raw_ngrams != raw_ngrams_end; raw_ngrams++) {
315 float *weights = raw_ngrams->weights;
316 probs[prob_num++] = *weights;
319 make_bins(probs, prob_num, quant->tables[order - 2][0].begin,
320 1ULL << quant->prob_bits);
326 int order_minus_2,
float prob,
float backoff)
328 switch (quant->quant_type) {
331 address.offset += 31;
336 (uint64) ((bins_encode
337 (&quant->tables[order_minus_2][0],
339 bo_bits) | bins_encode(&quant->
347 E_INFO(
"Unsupported quantatization type\n");
355 switch (quant->quant_type) {
361 (uint32) bins_encode(quant->longest, prob));
365 E_INFO(
"Unsupported quantization type\n");
373 switch (quant->quant_type) {
375 address.offset += 31;
378 return bins_decode(&quant->tables[order_minus_2][1],
383 E_INFO(
"Unsupported quantatization type\n");
392 switch (quant->quant_type) {
396 address.offset += quant->bo_bits;
397 return bins_decode(&quant->tables[order_minus_2][0],
402 E_INFO(
"Unsupported quantatization type\n");
410 switch (quant->quant_type) {
414 return bins_decode(quant->longest,
419 E_INFO(
"Unsupported quantatization type\n");
#define E_INFO(...)
Print logging information to standard error stream.
#define ckd_calloc(n, sz)
Macros to simplify the use of above functions.
SPHINXBASE_EXPORT void bitarr_write_negfloat(bitarr_address_t address, float value)
Writes a non-positive float32 value to the bit array.
Sphinx's memory allocation/deallocation routines.
Basic type definitions used in Sphinx.
SPHINXBASE_EXPORT float bitarr_read_negfloat(bitarr_address_t address)
Reads a non-positive float32 value from the bit array.
SPHINXBASE_EXPORT void bitarr_write_int25(bitarr_address_t address, uint8 length, uint32 value)
Write specified value into bit array.
SPHINXBASE_EXPORT void ckd_free(void *ptr)
Test and free a 1-D array.
Structure that stores address of certain value in bit array.
SPHINXBASE_EXPORT uint32 bitarr_read_int25(bitarr_address_t address, uint8 length, uint32 mask)
Read uint32 value from bit array.
SPHINXBASE_EXPORT void bitarr_write_float(bitarr_address_t address, float value)
Writes float32 to bit array.
SPHINXBASE_EXPORT void bitarr_write_int57(bitarr_address_t address, uint8 length, uint64 value)
Write specified value into bit array.
Implementation of logging routines.
SPHINXBASE_EXPORT float bitarr_read_float(bitarr_address_t address)
Reads float32 from bit array.