// polars_compute/hyperloglogplus.rs

//! # HyperLogLogPlus
//!
//! The `hyperloglogplus` module contains an implementation of the HyperLogLogPlus
//! algorithm for cardinality estimation, so that the `approx_n_unique` function can
//! be implemented efficiently.
//!
//! This module borrows code from [arrow-datafusion](https://github.com/apache/arrow-datafusion/blob/93771052c5ac31f2cf22b8c25bf938656afe1047/datafusion/physical-expr/src/aggregate/hyperloglog.rs).
//!
//! # Examples
//!
//! ```
//! # use polars_compute::hyperloglogplus::*;
//! let mut hllp = HyperLogLog::new();
//! hllp.add(&12345);
//! hllp.add(&23456);
//!
//! assert_eq!(hllp.count(), 2);
//! ```
19
20use std::hash::{BuildHasher, Hash};
21use std::marker::PhantomData;
22
23use polars_utils::aliases::PlFixedStateQuality;
24
/// Number of hash bits used to select a register. A larger `HLL_P` means more
/// registers and therefore a smaller estimation error, at the cost of memory
/// (one byte per register).
const HLL_P: usize = 14_usize;
/// Number of hash bits left above the `HLL_P` index bits. [`HyperLogLog::add`]
/// derives a register's rank from the number of *trailing* zeros in these bits,
/// so ranks produced by `add` lie in `1..=HLL_Q + 1`.
const HLL_Q: usize = 64_usize - HLL_P;
/// Number of registers, `2^HLL_P`.
const NUM_REGISTERS: usize = 1_usize << HLL_P;
/// Mask selecting the low `HLL_P` bits of a hash, i.e. the register index.
const HLL_P_MASK: u64 = (NUM_REGISTERS as u64) - 1;
32
33#[derive(Clone, Debug)]
34pub struct HyperLogLog<T>
35where
36    T: Hash + ?Sized,
37{
38    registers: [u8; NUM_REGISTERS],
39    phantom: PhantomData<T>,
40}
41
42impl<T> Default for HyperLogLog<T>
43where
44    T: Hash + ?Sized,
45{
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51/// Fixed seed for the hashing so that values are consistent across runs
52///
53/// Note that when we later move on to have serialized HLL register binaries
54/// shared across cluster, this SEED will have to be consistent across all
55/// parties otherwise we might have corruption. So ideally for later this seed
56/// shall be part of the serialized form (or stay unchanged across versions).
57const SEED: PlFixedStateQuality = PlFixedStateQuality::with_seed(0);
58
59impl<T> HyperLogLog<T>
60where
61    T: Hash + ?Sized,
62{
63    /// Creates a new, empty HyperLogLog.
64    pub fn new() -> Self {
65        let registers = [0; NUM_REGISTERS];
66        Self::new_with_registers(registers)
67    }
68
69    /// Creates a HyperLogLog from already populated registers
70    /// note that this method should not be invoked in untrusted environment
71    /// because the internal structure of registers are not examined.
72    pub(crate) fn new_with_registers(registers: [u8; NUM_REGISTERS]) -> Self {
73        Self {
74            registers,
75            phantom: PhantomData,
76        }
77    }
78
79    #[inline]
80    fn hash_value(&self, obj: &T) -> u64 {
81        SEED.hash_one(obj)
82    }
83
84    /// Adds an element to the HyperLogLog.
85    pub fn add(&mut self, obj: &T) {
86        let hash = self.hash_value(obj);
87        let index = (hash & HLL_P_MASK) as usize;
88        let p = ((hash >> HLL_P) | (1_u64 << HLL_Q)).trailing_zeros() + 1;
89        self.registers[index] = self.registers[index].max(p as u8);
90    }
91
92    /// Get the register histogram (each value in register index into
93    /// the histogram; u32 is enough because we only have 2**14=16384 registers
94    #[inline]
95    fn get_histogram(&self) -> [u32; HLL_Q + 2] {
96        let mut histogram = [0; HLL_Q + 2];
97        // hopefully this can be unrolled
98        for r in self.registers {
99            histogram[r as usize] += 1;
100        }
101        histogram
102    }
103
104    /// Merge the other [`HyperLogLog`] into this one
105    pub fn merge(&mut self, other: &HyperLogLog<T>) {
106        assert!(
107            self.registers.len() == other.registers.len(),
108            "unexpected got unequal register size, expect {}, got {}",
109            self.registers.len(),
110            other.registers.len()
111        );
112        for i in 0..self.registers.len() {
113            self.registers[i] = self.registers[i].max(other.registers[i]);
114        }
115    }
116
117    /// Guess the number of unique elements seen by the HyperLogLog.
118    pub fn count(&self) -> usize {
119        let histogram = self.get_histogram();
120        let m = NUM_REGISTERS as f64;
121        let mut z = m * hll_tau((m - histogram[HLL_Q + 1] as f64) / m);
122        for i in histogram[1..=HLL_Q].iter().rev() {
123            z += *i as f64;
124            z *= 0.5;
125        }
126        z += m * hll_sigma(histogram[0] as f64 / m);
127        (0.5 / 2_f64.ln() * m * m / z).round() as usize
128    }
129}
130
131/// Helper function sigma as defined in
132/// "New cardinality estimation algorithms for HyperLogLog sketches"
133/// Otmar Ertl, arXiv:1702.01284
134#[inline]
135fn hll_sigma(x: f64) -> f64 {
136    if x == 1. {
137        f64::INFINITY
138    } else {
139        let mut y = 1.0;
140        let mut z = x;
141        let mut x = x;
142        loop {
143            x *= x;
144            let z_prime = z;
145            z += x * y;
146            y += y;
147            if z_prime == z {
148                break;
149            }
150        }
151        z
152    }
153}
154
/// Ertl's `tau` helper from
/// "New cardinality estimation algorithms for HyperLogLog sketches"
/// (Otmar Ertl, arXiv:1702.01284):
///
/// `tau(x) = (1 - x - sum_{k >= 1} (1 - x^(2^-k))^2 * 2^-k) / 3`,
///
/// summed until subtracting another term no longer changes the result.
/// Returns `0` for `x == 0` and `x == 1`.
#[inline]
fn hll_tau(x: f64) -> f64 {
    if x == 0.0 || x == 1.0 {
        return 0.0;
    }
    let mut root = x;
    let mut scale = 1.0;
    let mut acc = 1.0 - x;
    loop {
        root = root.sqrt();
        scale *= 0.5;
        let next = acc - (1.0 - root).powi(2) * scale;
        if next == acc {
            return acc / 3.0;
        }
        acc = next;
    }
}
178
179impl<T> AsRef<[u8]> for HyperLogLog<T>
180where
181    T: Hash + ?Sized,
182{
183    fn as_ref(&self) -> &[u8] {
184        &self.registers
185    }
186}
187
188impl<T> Extend<T> for HyperLogLog<T>
189where
190    T: Hash,
191{
192    fn extend<S: IntoIterator<Item = T>>(&mut self, iter: S) {
193        for elem in iter {
194            self.add(&elem);
195        }
196    }
197}
198
199impl<'a, T> Extend<&'a T> for HyperLogLog<T>
200where
201    T: 'a + Hash + ?Sized,
202{
203    fn extend<S: IntoIterator<Item = &'a T>>(&mut self, iter: S) {
204        for elem in iter {
205            self.add(elem);
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::{HyperLogLog, NUM_REGISTERS};

    /// Asserts that the estimate `got` is within the allowed relative error of `expected`.
    fn compare_with_delta(got: usize, expected: usize) {
        let expected = expected as f64;
        let relative_error = ((got as f64) - expected).abs() / expected;
        // Six times the standard error (1.04 / sqrt(m)) so that the tests stay stable;
        // this generous margin is adopted from Redis's HyperLogLog unit tests.
        let margin = 1.04 / ((NUM_REGISTERS as f64).sqrt()) * 6.0;
        assert!(
            relative_error <= margin,
            "{} is not near {} percent of {} which is ({}, {})",
            got,
            margin,
            expected,
            expected * (1.0 - margin),
            expected * (1.0 + margin)
        );
    }

    /// Adds `0..$size` to a fresh sketch over `$t` and checks the estimate.
    macro_rules! sized_number_test {
        ($size:expr, $t:ty) => {{
            let mut sketch = HyperLogLog::<$t>::new();
            for value in 0..$size {
                sketch.add(&value);
            }
            compare_with_delta(sketch.count(), $size);
        }};
    }

    /// Runs `sized_number_test!` for the 64- and 128-bit integer types.
    macro_rules! typed_large_number_test {
        ($size:expr) => {{
            sized_number_test!($size, u64);
            sized_number_test!($size, u128);
            sized_number_test!($size, i64);
            sized_number_test!($size, i128);
        }};
    }

    /// Runs `sized_number_test!` for the 16- and 32-bit types, then the large ones.
    macro_rules! typed_number_test {
        ($size:expr) => {{
            sized_number_test!($size, u16);
            sized_number_test!($size, u32);
            sized_number_test!($size, i16);
            sized_number_test!($size, i32);
            typed_large_number_test!($size);
        }};
    }

    #[test]
    fn test_empty() {
        assert_eq!(HyperLogLog::<u64>::new().count(), 0);
    }

    #[test]
    fn test_one() {
        let mut hll = HyperLogLog::<u64>::new();
        hll.add(&1);
        assert_eq!(hll.count(), 1);
    }

    #[test]
    fn test_number_100() {
        typed_number_test!(100);
    }

    #[test]
    fn test_number_1k() {
        typed_number_test!(1_000);
    }

    #[test]
    fn test_number_10k() {
        typed_number_test!(10_000);
    }

    #[test]
    fn test_number_100k() {
        typed_large_number_test!(100_000);
    }

    #[test]
    fn test_number_1m() {
        typed_large_number_test!(1_000_000);
    }

    #[test]
    fn test_u8() {
        let mut hll = HyperLogLog::<[u8]>::new();
        for text in (0..1000).map(|i| i.to_string()) {
            hll.add(text.as_bytes());
        }
        compare_with_delta(hll.count(), 1000);
    }

    #[test]
    fn test_string() {
        let mut hll = HyperLogLog::<String>::new();
        hll.extend((0..1000).map(|i| i.to_string()));
        compare_with_delta(hll.count(), 1000);
    }

    #[test]
    fn test_empty_merge() {
        let mut hll = HyperLogLog::<u64>::new();
        let other = HyperLogLog::<u64>::new();
        hll.merge(&other);
        assert_eq!(hll.count(), 0);
    }

    #[test]
    fn test_merge_overlapped() {
        let build = || {
            let mut sketch = HyperLogLog::<String>::new();
            sketch.extend((0..1000).map(|i| i.to_string()));
            sketch
        };
        let mut hll = build();
        let other = build();
        hll.merge(&other);
        compare_with_delta(hll.count(), 1000);
    }

    #[test]
    fn test_repetition() {
        let mut hll = HyperLogLog::<u32>::new();
        (0..1_000_000_u32).for_each(|i| hll.add(&(i % 1000)));
        compare_with_delta(hll.count(), 1000);
    }
}
345}