44 #include <sphinxbase/priority_queue.h>
47 #include "lm_trie_quant.h"
50 base_size(uint32 entries, uint32 max_vocab, uint8 remaining_bits)
57 return ((1 + entries) * total_bits + 7) / 8 +
sizeof(uint64);
61 middle_size(uint8 quant_bits, uint32 entries, uint32 max_vocab,
64 return base_size(entries, max_vocab,
69 longest_size(uint8 quant_bits, uint32 entries, uint32 max_vocab)
71 return base_size(entries, max_vocab, quant_bits);
75 base_init(
base_t * base,
void *base_mem, uint32 max_vocab,
79 base->word_mask = (1
U << base->word_bits) - 1
U;
80 if (base->word_bits > 25)
82 (
"Sorry, word indices more than %d are not implemented. Edit util/bit_packing.hh and fix the bit packing functions\n",
84 base->total_bits = base->word_bits + remaining_bits;
86 base->base = (uint8 *) base_mem;
87 base->insert_index = 0;
88 base->max_vocab = max_vocab;
92 middle_init(
middle_t * middle,
void *base_mem, uint8 quant_bits,
93 uint32 entries, uint32 max_vocab, uint32 max_next,
96 middle->quant_bits = quant_bits;
98 middle->next_source = next_source;
99 if (entries + 1 >= (1
U << 25) || (max_next >= (1
U << 25)))
101 (
"Sorry, this does not support more than %d n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions\n",
103 base_init(&middle->base, base_mem, max_vocab,
104 quant_bits + middle->next_mask.bits);
108 longest_init(
longest_t * longest,
void *base_mem, uint8 quant_bits,
111 base_init(&longest->base, base_mem, max_vocab, quant_bits);
115 middle_insert(
middle_t * middle, uint32 word,
int order,
int max_order)
120 assert(word <= middle->base.word_mask);
121 address.base = middle->base.base;
122 address.offset = middle->base.insert_index * middle->base.total_bits;
124 address.offset += middle->base.word_bits;
125 at_pointer = address.offset;
126 address.offset += middle->quant_bits;
127 if (order == max_order - 1) {
128 next = ((
longest_t *) middle->next_source)->base.insert_index;
131 next = ((
middle_t *) middle->next_source)->base.insert_index;
135 middle->base.insert_index++;
136 address.offset = at_pointer;
141 longest_insert(
longest_t * longest, uint32 index)
144 assert(index <= longest->base.word_mask);
145 address.base = longest->base.base;
146 address.offset = longest->base.insert_index * longest->base.total_bits;
148 address.offset += longest->base.word_bits;
149 longest->base.insert_index++;
154 middle_finish_loading(
middle_t * middle, uint32 next_end)
157 address.base = middle->base.base;
159 (middle->base.insert_index + 1) * middle->base.total_bits -
160 middle->next_mask.bits;
165 unigram_next(
lm_trie_t * trie,
int order)
168 2 ? trie->longest->base.insert_index : trie->middle_begin->base.
174 uint32 * counts,
int order)
176 uint32 unigram_idx = 0;
179 const uint32 unigram_count = (uint32) counts[0];
181 priority_queue_create(order, &ngram_ord_comparator);
183 uint32 *raw_ngrams_ptr;
186 words = (uint32 *)
ckd_calloc(order,
sizeof(*words));
187 probs = (
float *)
ckd_calloc(order - 1,
sizeof(*probs));
190 ngram->instance.words = &unigram_idx;
191 priority_queue_add(ngrams, ngram);
193 (uint32 *)
ckd_calloc(order - 1,
sizeof(*raw_ngrams_ptr));
194 for (i = 2; i <= order; ++i) {
197 if (counts[i - 1] <= 0)
201 tmp_ngram->order = i;
202 raw_ngrams_ptr[i - 2] = 0;
203 tmp_ngram->instance = raw_ngrams[i - 2][0];
204 priority_queue_add(ngrams, tmp_ngram);
210 if (top->order == 1) {
211 trie->unigrams[unigram_idx].next = unigram_next(trie, order);
212 words[0] = unigram_idx;
213 probs[0] = trie->unigrams[unigram_idx].prob;
214 if (++unigram_idx == unigram_count + 1) {
218 priority_queue_add(ngrams, top);
221 for (i = 0; i < top->order - 1; i++) {
222 if (words[i] != top->instance.words[i]) {
226 for (j = i; j < top->order - 1; j++) {
227 middle_t *middle = &trie->middle_begin[j - 1];
229 middle_insert(middle, top->instance.words[j],
234 trie->unigrams[top->instance.words[j]].bo;
235 probs[j] = calc_prob;
236 lm_trie_quant_mwrite(trie->quant, address, j - 1,
241 memcpy(words, top->instance.words,
242 top->order *
sizeof(*words));
243 if (top->order == order) {
244 float *weights = top->instance.weights;
246 longest_insert(trie->longest,
247 top->instance.words[top->order - 1]);
248 lm_trie_quant_lwrite(trie->quant, address, weights[0]);
251 float *weights = top->instance.weights;
252 middle_t *middle = &trie->middle_begin[top->order - 2];
254 middle_insert(middle,
255 top->instance.words[top->order - 1],
258 probs[top->order - 1] = weights[0];
259 lm_trie_quant_mwrite(trie->quant, address, top->order - 2,
260 weights[0], weights[1]);
262 raw_ngrams_ptr[top->order - 2]++;
263 if (raw_ngrams_ptr[top->order - 2] < counts[top->order - 1]) {
265 raw_ngrams[top->order -
266 2][raw_ngrams_ptr[top->order - 2]];
267 priority_queue_add(ngrams, top);
274 assert(priority_queue_size(ngrams) == 0);
275 priority_queue_free(ngrams, NULL);
282 lm_trie_init(uint32 unigram_count)
287 memset(trie->prev_hist, -1,
sizeof(trie->prev_hist));
288 memset(trie->backoff, 0,
sizeof(trie->backoff));
291 sizeof(*trie->unigrams));
292 trie->ngram_mem = NULL;
297 lm_trie_create(uint32 unigram_count, lm_trie_quant_type_t quant_type,
300 lm_trie_t *trie = lm_trie_init(unigram_count);
302 (order > 1) ? lm_trie_quant_create(quant_type, order) : 0;
307 lm_trie_read_bin(uint32 * counts,
int order, FILE * fp)
309 lm_trie_t *trie = lm_trie_init(counts[0]);
310 trie->quant = (order > 1) ? lm_trie_quant_read_bin(fp, order) : NULL;
311 fread(trie->unigrams,
sizeof(*trie->unigrams), (counts[0] + 1), fp);
313 lm_trie_alloc_ngram(trie, counts, order);
314 fread(trie->ngram_mem, 1, trie->ngram_mem_size, fp);
320 lm_trie_write_bin(
lm_trie_t * trie, uint32 unigram_count, FILE * fp)
324 lm_trie_quant_write_bin(trie->quant, fp);
325 fwrite(trie->unigrams,
sizeof(*trie->unigrams), (unigram_count + 1),
328 fwrite(trie->ngram_mem, 1, trie->ngram_mem_size, fp);
334 if (trie->ngram_mem) {
340 lm_trie_quant_free(trie->quant);
346 lm_trie_alloc_ngram(
lm_trie_t * trie, uint32 * counts,
int order)
350 uint8 **middle_starts;
352 trie->ngram_mem_size = 0;
353 for (i = 1; i < order - 1; i++) {
354 trie->ngram_mem_size +=
355 middle_size(lm_trie_quant_msize(trie->quant), counts[i],
356 counts[0], counts[i + 1]);
358 trie->ngram_mem_size +=
359 longest_size(lm_trie_quant_lsize(trie->quant), counts[order - 1],
363 sizeof(*trie->ngram_mem));
364 mem_ptr = trie->ngram_mem;
367 trie->middle_end = trie->middle_begin + (order - 2);
369 (uint8 **)
ckd_calloc(order - 2,
sizeof(*middle_starts));
370 for (i = 2; i < order; i++) {
371 middle_starts[i - 2] = mem_ptr;
373 middle_size(lm_trie_quant_msize(trie->quant), counts[i - 1],
374 counts[0], counts[i]);
378 for (i = order - 1; i >= 2; --i) {
379 middle_t *middle_ptr = &trie->middle_begin[i - 2];
380 middle_init(middle_ptr, middle_starts[i - 2],
381 lm_trie_quant_msize(trie->quant), counts[i - 1],
382 counts[0], counts[i],
385 1) ? (
void *) trie->longest : (
void *) &trie->
386 middle_begin[i - 1]);
389 longest_init(trie->longest, mem_ptr, lm_trie_quant_lsize(trie->quant),
398 if (lm_trie_quant_to_train(trie->quant)) {
399 E_INFO(
"Training quantizer\n");
400 for (i = 2; i < order; i++) {
401 lm_trie_quant_train(trie->quant, i, counts[i - 1],
404 lm_trie_quant_train_prob(trie->quant, order, counts[order - 1],
405 raw_ngrams[order - 2]);
407 E_INFO(
"Building LM trie\n");
408 recursive_insert(trie, raw_ngrams, counts, order);
411 if (trie->middle_begin != trie->middle_end) {
413 for (middle_ptr = trie->middle_begin;
414 middle_ptr != trie->middle_end - 1; ++middle_ptr) {
415 middle_t *next_middle_ptr = middle_ptr + 1;
416 middle_finish_loading(middle_ptr,
417 next_middle_ptr->base.insert_index);
419 middle_ptr = trie->middle_end - 1;
420 middle_finish_loading(middle_ptr,
421 trie->longest->base.insert_index);
429 next->begin = ptr->next;
430 next->end = (ptr + 1)->next;
435 calc_pivot(uint32 off, uint32 range, uint32 width)
437 return (
size_t) ((off * width) / (range + 1));
441 uniform_find(
void *base, uint8 total_bits, uint8 key_bits, uint32 key_mask,
442 uint32 before_it, uint32 before_v,
443 uint32 after_it, uint32 after_v, uint32 key, uint32 * out)
447 while (after_it - before_it > 1) {
451 calc_pivot(key - before_v, after_v - before_v,
452 after_it - before_it - 1));
454 address.offset = pivot * (uint32) total_bits;
460 else if (mid > key) {
480 ((
void *) middle->base.base, middle->base.total_bits,
481 middle->base.word_bits, middle->base.word_mask, range->begin - 1,
482 0, range->end, middle->base.max_vocab, word, &at_pointer)) {
488 address.base = middle->base.base;
489 at_pointer *= middle->base.total_bits;
490 at_pointer += middle->base.word_bits;
491 address.offset = at_pointer + middle->quant_bits;
494 middle->next_mask.mask);
495 address.offset += middle->base.total_bits;
498 middle->next_mask.mask);
499 address.offset = at_pointer;
512 ((
void *) longest->base.base, longest->base.total_bits,
513 longest->base.word_bits, longest->base.word_mask,
514 range->begin - 1, 0, range->end, longest->base.max_vocab, word,
520 address.base = longest->base.base;
522 at_pointer * longest->base.total_bits + longest->base.word_bits;
527 get_available_prob(
lm_trie_t * trie, int32 wid, int32 * hist,
528 int max_order, int32 n_hist, int32 * n_used)
534 uint8 independent_left;
535 int32 *hist_iter, *hist_end;
538 prob = unigram_find(trie->unigrams, wid, &node)->prob;
545 independent_left = (node.begin == node.end);
547 hist_end = hist + n_hist;
548 for (;; order_minus_2++, hist_iter++) {
549 if (hist_iter == hist_end)
551 if (independent_left)
553 if (order_minus_2 == max_order - 2)
557 middle_find(&trie->middle_begin[order_minus_2], *hist_iter,
559 independent_left = (address.base == NULL)
560 || (node.begin == node.end);
563 if (address.base == NULL)
565 prob = lm_trie_quant_mpread(trie->quant, address, order_minus_2);
566 *n_used = order_minus_2 + 2;
569 address = longest_find(trie->longest, *hist_iter, &node);
570 if (address.base != NULL) {
571 prob = lm_trie_quant_lpread(trie->quant, address);
578 get_available_backoff(
lm_trie_t * trie, int32 start, int32 * hist,
581 float backoff = 0.0f;
585 unigram_t *first_hist = unigram_find(trie->unigrams, hist[0], &node);
587 backoff += first_hist->bo;
590 order_minus_2 = start - 2;
591 for (hist_iter = hist + start - 1; hist_iter < hist + n_hist;
592 hist_iter++, order_minus_2++) {
594 middle_find(&trie->middle_begin[order_minus_2], *hist_iter,
596 if (address.base == NULL)
599 lm_trie_quant_mboread(trie->quant, address, order_minus_2);
605 lm_trie_nobo_score(
lm_trie_t * trie, int32 wid, int32 * hist,
606 int max_order, int32 n_hist, int32 * n_used)
609 get_available_prob(trie, wid, hist, max_order, n_hist, n_used);
610 if (n_hist < *n_used)
612 return prob + get_available_backoff(trie, *n_used, hist, n_hist);
616 lm_trie_hist_score(
lm_trie_t * trie, int32 wid, int32 * hist, int32 n_hist,
625 prob = unigram_find(trie->unigrams, wid, &node)->prob;
628 for (i = 0; i < n_hist - 1; i++) {
629 address = middle_find(&trie->middle_begin[i], hist[i], &node);
630 if (address.base == NULL) {
631 for (j = i; j < n_hist; j++) {
632 prob += trie->backoff[j];
638 prob = lm_trie_quant_mpread(trie->quant, address, i);
641 address = longest_find(trie->longest, hist[n_hist - 1], &node);
642 if (address.base == NULL) {
643 return prob + trie->backoff[n_hist - 1];
647 return lm_trie_quant_lpread(trie->quant, address);
652 history_matches(int32 * hist, int32 * prev_hist, int32 n_hist)
655 for (i = 0; i < n_hist; i++) {
656 if (hist[i] != prev_hist[i]) {
664 update_backoff(
lm_trie_t * trie, int32 * hist, int32 n_hist)
670 memset(trie->backoff, 0,
sizeof(trie->backoff));
671 trie->backoff[0] = unigram_find(trie->unigrams, hist[0], &node)->bo;
672 for (i = 1; i < n_hist; i++) {
673 address = middle_find(&trie->middle_begin[i - 1], hist[i], &node);
674 if (address.base == NULL) {
678 lm_trie_quant_mboread(trie->quant, address, i - 1);
680 memcpy(trie->prev_hist, hist, n_hist *
sizeof(*hist));
684 lm_trie_score(
lm_trie_t * trie,
int order, int32 wid, int32 * hist,
685 int32 n_hist, int32 * n_used)
687 if (n_hist < order - 1) {
688 return lm_trie_nobo_score(trie, wid, hist, order, n_hist, n_used);
691 assert(n_hist == order - 1);
692 if (!history_matches(hist, (int32 *) trie->prev_hist, n_hist)) {
693 update_backoff(trie, hist, n_hist);
695 return lm_trie_hist_score(trie, wid, hist, n_hist, n_used);
#define E_INFO(...)
Print logging information to standard error stream.
#define ckd_calloc(n, sz)
Macros to simplify the use of above functions.
SPHINXBASE_EXPORT void bitarr_mask_from_max(bitarr_mask_t *bit_mask, uint32 max_value)
Fills the mask for an integer range determined by the provided maximum value.
#define E_ERROR(...)
Print error message to error log.
SPHINXBASE_EXPORT uint8 bitarr_required_bits(uint32 max_value)
Computes the number of bits required to store integers up to the provided value.
Sphinx's memory allocation/deallocation routines.
Basic type definitions used in Sphinx.
SPHINXBASE_EXPORT void bitarr_write_int25(bitarr_address_t address, uint8 length, uint32 value)
Write specified value into bit array.
SPHINXBASE_EXPORT void ckd_free(void *ptr)
Test and free a 1-D array.
Structure that stores address of certain value in bit array.
SPHINXBASE_EXPORT uint32 bitarr_read_int25(bitarr_address_t address, uint8 length, uint32 mask)
Read uint32 value from bit array.
Implementation of logging routines.