SphinxBase  5prealpha
lm_trie_quant.c
1 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /* ====================================================================
3  * Copyright (c) 2015 Carnegie Mellon University. All rights
4  * reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright
11  * notice, this list of conditions and the following disclaimer.
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright
14  * notice, this list of conditions and the following disclaimer in
15  * the documentation and/or other materials provided with the
16  * distribution.
17  *
18  * This work was supported in part by funding from the Defense Advanced
19  * Research Projects Agency and the National Science Foundation of the
20  * United States of America, and the CMU Sphinx Speech Consortium.
21  *
22  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND
23  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
24  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
25  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
26  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33  *
34  * ====================================================================
35  *
36  */
37 
38 #include <math.h>
39 
40 #include <sphinxbase/prim_type.h>
41 #include <sphinxbase/ckd_alloc.h>
42 #include <sphinxbase/err.h>
43 
44 #include "ngram_model_internal.h"
45 #include "lm_trie_quant.h"
46 
47 #define FLOAT_INF (0x7f800000)
48 
/*
 * A table of quantization bin centres stored as a half-open range
 * [begin, end) of floats.
 */
typedef struct bins_s bins_t;

struct bins_s {
    float *begin;       /* first centre of the table */
    const float *end;   /* one past the last centre of the table */
};
53 
55  lm_trie_quant_type_t quant_type;
56  bins_t tables[NGRAM_MAX_ORDER - 1][2];
57  bins_t *longest;
58  uint8 *mem;
59  size_t mem_size;
60  uint8 prob_bits;
61  uint8 bo_bits;
62  uint32 prob_mask;
63  uint32 bo_mask;
64 };
65 
66 static void
67 bins_create(bins_t * bins, uint8 bits, float *begin)
68 {
69  bins->begin = begin;
70  bins->end = bins->begin + (1ULL << bits);
71 }
72 
/*
 * Binary search over the range [first, last).
 *
 * Returns a pointer to the first element that is not less than val, or
 * last when every element is less than val. The comparison used is
 * `element < val`, so the range is expected to be ordered accordingly.
 */
static float *
lower_bound(float *first, const float *last, float val)
{
    int remaining = last - first;

    while (remaining > 0) {
        int half = remaining / 2;
        float *mid = first + half;

        if (*mid < val) {
            /* Everything up to and including mid is too small. */
            first = mid + 1;
            remaining -= half + 1;
        }
        else {
            /* mid may be the answer; keep searching to its left. */
            remaining = half;
        }
    }
    return first;
}
94 
95 static uint64
96 bins_encode(bins_t * bins, float value)
97 {
98  float *above = lower_bound(bins->begin, bins->end, value);
99  if (above == bins->begin)
100  return 0;
101  if (above == bins->end)
102  return bins->end - bins->begin - 1;
103  return above - bins->begin - (value - *(above - 1) < *above - value);
104 }
105 
106 static float
107 bins_decode(bins_t * bins, size_t off)
108 {
109  return bins->begin[off];
110 }
111 
/*
 * Number of bytes needed for all quantization tables of a model.
 *
 * Each of the (order - 2) middle orders gets a probability table and a
 * backoff table; the longest order gets a probability table only.
 * Unigrams are currently not quantized, so they need no table.
 */
static size_t
quant_apply_size(int order, int prob_bits, int bo_bits)
{
    const size_t prob_table_bytes = (1U << prob_bits) * sizeof(float);
    const size_t bo_table_bytes = (1U << bo_bits) * sizeof(float);
    const size_t middle_bytes = bo_table_bytes + prob_table_bytes;

    return (order - 2) * middle_bytes + prob_table_bytes;
}
120 
121 static size_t
122 quant_size(lm_trie_quant_type_t quant_type, int order)
123 {
124  switch (quant_type) {
125  case NO_QUANT:
126  return 0;
127  case QUANT_16:
128  return quant_apply_size(order, 16, 16);
129  //TODO implement different quantatization stages
130  default:
131  E_INFO("Unsupported quantatization type\n");
132  return 0;
133  }
134 }
135 
137 lm_trie_quant_create(lm_trie_quant_type_t quant_type, int order)
138 {
139  float *start;
140  int i;
141  lm_trie_quant_t *quant =
142  (lm_trie_quant_t *) ckd_calloc(1, sizeof(*quant));
143  quant->quant_type = quant_type;
144  quant->mem_size = quant_size(quant_type, order);
145  quant->mem =
146  (uint8 *) ckd_calloc(quant->mem_size, sizeof(*quant->mem));
147  switch (quant_type) {
148  case NO_QUANT:
149  return quant;
150  case QUANT_16:
151  quant->prob_bits = 16;
152  quant->bo_bits = 16;
153  quant->prob_mask = (1U << quant->prob_bits) - 1;
154  quant->bo_mask = (1U << quant->bo_bits) - 1;
155  break;
156  default:
157  E_INFO("Unsupported quantization type\n");
158  return quant;
159  }
160  start = (float *) (quant->mem);
161  for (i = 0; i < order - 2; i++) {
162  bins_create(&quant->tables[i][0], quant->prob_bits, start);
163  start += (1ULL << quant->prob_bits);
164  bins_create(&quant->tables[i][1], quant->bo_bits, start);
165  start += (1ULL << quant->bo_bits);
166  }
167  bins_create(&quant->tables[order - 2][0], quant->prob_bits, start);
168  quant->longest = &quant->tables[order - 2][0];
169  return quant;
170 }
171 
172 
174 lm_trie_quant_read_bin(FILE * fp, int order)
175 {
176  int quant_type_int;
177  lm_trie_quant_type_t quant_type;
178  lm_trie_quant_t *quant;
179 
180  fread(&quant_type_int, sizeof(quant_type_int), 1, fp);
181  quant_type = (lm_trie_quant_type_t) quant_type_int;
182  quant = lm_trie_quant_create(quant_type, order);
183  fread(quant->mem, sizeof(*quant->mem), quant->mem_size, fp);
184 
185  return quant;
186 }
187 
188 void
189 lm_trie_quant_write_bin(lm_trie_quant_t * quant, FILE * fp)
190 {
191  int quant_type_int = (int) quant->quant_type;
192 
193  fwrite(&quant_type_int, sizeof(quant_type_int), 1, fp);
194  fwrite(quant->mem, sizeof(*quant->mem), quant->mem_size, fp);
195 }
196 
197 void
198 lm_trie_quant_free(lm_trie_quant_t * quant)
199 {
200  if (quant->mem)
201  ckd_free(quant->mem);
202  ckd_free(quant);
203 }
204 
205 uint8
206 lm_trie_quant_msize(lm_trie_quant_t * quant)
207 {
208  switch (quant->quant_type) {
209  case NO_QUANT:
210  return 63;
211  case QUANT_16:
212  return 32; //16 bits for prob + 16 bits for bo
213  //TODO implement different quantatization stages
214  default:
215  E_INFO("Unsupported quantatization type\n");
216  return 0;
217  }
218 }
219 
220 uint8
221 lm_trie_quant_lsize(lm_trie_quant_t * quant)
222 {
223  switch (quant->quant_type) {
224  case NO_QUANT:
225  return 31;
226  case QUANT_16:
227  return 16; //16 bits for probs
228  //TODO implement different quantatization stages
229  default:
230  E_INFO("Unsupported quantatization type\n");
231  return 0;
232  }
233 }
234 
235 uint8
236 lm_trie_quant_to_train(lm_trie_quant_t * quant)
237 {
238  return quant->quant_type > 0;
239 }
240 
/*
 * qsort() comparator ordering floats ascending.
 *
 * Returns a negative, zero or positive value when *a is less than,
 * equal to or greater than *b. The result is derived from comparisons
 * rather than from the truncated difference: truncating to int would
 * report any two values less than 1.0 apart as equal (leaving them
 * unsorted) and could overflow for large differences.
 */
static int
weights_comparator(const void *a, const void *b)
{
    const float lhs = *(const float *) a;
    const float rhs = *(const float *) b;

    return (lhs > rhs) - (lhs < rhs);
}
246 
247 static void
248 make_bins(float *values, uint32 values_num, float *centers, uint32 bins)
249 {
250  float *finish, *start;
251  uint32 i;
252 
253  qsort(values, values_num, sizeof(*values), &weights_comparator);
254  start = values;
255  for (i = 0; i < bins; i++, centers++, start = finish) {
256  finish = values + (size_t) ((uint64) values_num * (i + 1) / bins);
257  if (finish == start) {
258  // zero length bucket.
259  *centers = i ? *(centers - 1) : -FLOAT_INF;
260  }
261  else {
262  float sum = 0.0f;
263  float *ptr;
264  for (ptr = start; ptr != finish; ptr++) {
265  sum += *ptr;
266  }
267  *centers = sum / (float) (finish - start);
268  }
269  }
270 }
271 
272 void
273 lm_trie_quant_train(lm_trie_quant_t * quant, int order, uint32 counts,
274  ngram_raw_t * raw_ngrams)
275 {
276  float *probs;
277  float *backoffs;
278  float *centers;
279  uint32 backoff_num;
280  uint32 prob_num;
281  ngram_raw_t *raw_ngrams_end;
282 
283  probs = (float *) ckd_calloc(counts, sizeof(*probs));
284  backoffs = (float *) ckd_calloc(counts, sizeof(*backoffs));
285  raw_ngrams_end = raw_ngrams + counts;
286 
287  for (backoff_num = 0, prob_num = 0; raw_ngrams != raw_ngrams_end;
288  raw_ngrams++) {
289  float *weights = raw_ngrams->weights;
290  probs[prob_num++] = *weights; //first goes prob
291  weights++; //increment to backoff
292  backoffs[backoff_num++] = *weights;
293  }
294 
295  make_bins(probs, prob_num, quant->tables[order - 2][0].begin,
296  1ULL << quant->prob_bits);
297  centers = quant->tables[order - 2][1].begin;
298  make_bins(backoffs, backoff_num, centers, (1ULL << quant->bo_bits));
299  ckd_free(probs);
300  ckd_free(backoffs);
301 }
302 
303 void
304 lm_trie_quant_train_prob(lm_trie_quant_t * quant, int order, uint32 counts,
305  ngram_raw_t * raw_ngrams)
306 {
307  float *probs;
308  uint32 prob_num;
309  ngram_raw_t *raw_ngrams_end;
310 
311  probs = (float *) ckd_calloc(counts, sizeof(*probs));
312  raw_ngrams_end = raw_ngrams + counts;
313 
314  for (prob_num = 0; raw_ngrams != raw_ngrams_end; raw_ngrams++) {
315  float *weights = raw_ngrams->weights;
316  probs[prob_num++] = *weights;
317  }
318 
319  make_bins(probs, prob_num, quant->tables[order - 2][0].begin,
320  1ULL << quant->prob_bits);
321  ckd_free(probs);
322 }
323 
324 void
325 lm_trie_quant_mwrite(lm_trie_quant_t * quant, bitarr_address_t address,
326  int order_minus_2, float prob, float backoff)
327 {
328  switch (quant->quant_type) {
329  case NO_QUANT:
330  bitarr_write_negfloat(address, prob);
331  address.offset += 31;
332  bitarr_write_float(address, backoff);
333  break;
334  case QUANT_16:
335  bitarr_write_int57(address, quant->prob_bits + quant->bo_bits,
336  (uint64) ((bins_encode
337  (&quant->tables[order_minus_2][0],
338  prob) << quant->
339  bo_bits) | bins_encode(&quant->
340  tables
341  [order_minus_2]
342  [1],
343  backoff)));
344  break;
345  //TODO implement different quantatization stages
346  default:
347  E_INFO("Unsupported quantatization type\n");
348  }
349 }
350 
351 void
352 lm_trie_quant_lwrite(lm_trie_quant_t * quant, bitarr_address_t address,
353  float prob)
354 {
355  switch (quant->quant_type) {
356  case NO_QUANT:
357  bitarr_write_negfloat(address, prob);
358  break;
359  case QUANT_16:
360  bitarr_write_int25(address, quant->prob_bits,
361  (uint32) bins_encode(quant->longest, prob));
362  break;
363  //TODO implement different quantatization stages
364  default:
365  E_INFO("Unsupported quantization type\n");
366  }
367 }
368 
369 float
370 lm_trie_quant_mboread(lm_trie_quant_t * quant, bitarr_address_t address,
371  int order_minus_2)
372 {
373  switch (quant->quant_type) {
374  case NO_QUANT:
375  address.offset += 31;
376  return bitarr_read_float(address);
377  case QUANT_16:
378  return bins_decode(&quant->tables[order_minus_2][1],
379  bitarr_read_int25(address, quant->bo_bits,
380  quant->bo_mask));
381  //TODO implement different quantatization stages
382  default:
383  E_INFO("Unsupported quantatization type\n");
384  return 0.0f;
385  }
386 }
387 
388 float
389 lm_trie_quant_mpread(lm_trie_quant_t * quant, bitarr_address_t address,
390  int order_minus_2)
391 {
392  switch (quant->quant_type) {
393  case NO_QUANT:
394  return bitarr_read_negfloat(address);
395  case QUANT_16:
396  address.offset += quant->bo_bits;
397  return bins_decode(&quant->tables[order_minus_2][0],
398  bitarr_read_int25(address, quant->prob_bits,
399  quant->prob_mask));
400  //TODO implement different quantatization stages
401  default:
402  E_INFO("Unsupported quantatization type\n");
403  return 0.0f;
404  }
405 }
406 
407 float
408 lm_trie_quant_lpread(lm_trie_quant_t * quant, bitarr_address_t address)
409 {
410  switch (quant->quant_type) {
411  case NO_QUANT:
412  return bitarr_read_negfloat(address);
413  case QUANT_16:
414  return bins_decode(quant->longest,
415  bitarr_read_int25(address, quant->prob_bits,
416  quant->prob_mask));
417  //TODO implement different quantatization stages
418  default:
419  E_INFO("Unsupported quantatization type\n");
420  return 0.0f;
421  }
422 }
#define E_INFO(...)
Print logging information to standard error stream.
Definition: err.h:114
#define ckd_calloc(n, sz)
Macros to simplify the use of above functions.
Definition: ckd_alloc.h:248
SPHINXBASE_EXPORT void bitarr_write_negfloat(bitarr_address_t address, float value)
Writes non positive float32 to bit array.
Definition: bitarr.c:150
Sphinx's memory allocation/deallocation routines.
Basic type definitions used in Sphinx.
SPHINXBASE_EXPORT float bitarr_read_negfloat(bitarr_address_t address)
Read non positive float32 from bit array.
Definition: bitarr.c:141
SPHINXBASE_EXPORT void bitarr_write_int25(bitarr_address_t address, uint8 length, uint32 value)
Write specified value into bit array.
Definition: bitarr.c:128
SPHINXBASE_EXPORT void ckd_free(void *ptr)
Test and free a 1-D array.
Definition: ckd_alloc.c:244
Structure that stores address of certain value in bit array.
Definition: bitarr.h:75
SPHINXBASE_EXPORT uint32 bitarr_read_int25(bitarr_address_t address, uint8 length, uint32 mask)
Read uint32 value from bit array.
Definition: bitarr.c:116
SPHINXBASE_EXPORT void bitarr_write_float(bitarr_address_t address, float value)
Writes float32 to bit array.
Definition: bitarr.c:165
SPHINXBASE_EXPORT void bitarr_write_int57(bitarr_address_t address, uint8 length, uint64 value)
Write specified value into bit array.
Definition: bitarr.c:103
Implementation of logging routines.
SPHINXBASE_EXPORT float bitarr_read_float(bitarr_address_t address)
Reads float32 from bit array.
Definition: bitarr.c:158
Definition: dtoa.c:178