aho_corasick/util/
alphabet.rs

1use crate::util::int::Usize;
2
/// A table assigning each of the 256 possible byte values to an equivalence
/// class.
///
/// Finite state machines use this to shrink their transition tables: they
/// only need one transition per class instead of one per byte. Besides the
/// smaller total size of the FSM, this also tends to speed up FSM
/// construction, since fewer transitions have to be visited and set.
#[derive(Copy, Clone)]
pub(crate) struct ByteClasses([u8; 256]);
11
12impl ByteClasses {
13    /// Creates a new set of equivalence classes where all bytes are mapped to
14    /// the same class.
15    pub(crate) fn empty() -> ByteClasses {
16        ByteClasses([0; 256])
17    }
18
19    /// Creates a new set of equivalence classes where each byte belongs to
20    /// its own equivalence class.
21    pub(crate) fn singletons() -> ByteClasses {
22        let mut classes = ByteClasses::empty();
23        for b in 0..=255 {
24            classes.set(b, b);
25        }
26        classes
27    }
28
29    /// Set the equivalence class for the given byte.
30    #[inline]
31    pub(crate) fn set(&mut self, byte: u8, class: u8) {
32        self.0[usize::from(byte)] = class;
33    }
34
35    /// Get the equivalence class for the given byte.
36    #[inline]
37    pub(crate) fn get(&self, byte: u8) -> u8 {
38        self.0[usize::from(byte)]
39    }
40
41    /// Return the total number of elements in the alphabet represented by
42    /// these equivalence classes. Equivalently, this returns the total number
43    /// of equivalence classes.
44    #[inline]
45    pub(crate) fn alphabet_len(&self) -> usize {
46        // Add one since the number of equivalence classes is one bigger than
47        // the last one.
48        usize::from(self.0[255]) + 1
49    }
50
51    /// Returns the stride, as a base-2 exponent, required for these
52    /// equivalence classes.
53    ///
54    /// The stride is always the smallest power of 2 that is greater than or
55    /// equal to the alphabet length. This is done so that converting between
56    /// state IDs and indices can be done with shifts alone, which is much
57    /// faster than integer division. The "stride2" is the exponent. i.e.,
58    /// `2^stride2 = stride`.
59    pub(crate) fn stride2(&self) -> usize {
60        let zeros = self.alphabet_len().next_power_of_two().trailing_zeros();
61        usize::try_from(zeros).unwrap()
62    }
63
64    /// Returns the stride for these equivalence classes, which corresponds
65    /// to the smallest power of 2 greater than or equal to the number of
66    /// equivalence classes.
67    pub(crate) fn stride(&self) -> usize {
68        1 << self.stride2()
69    }
70
71    /// Returns true if and only if every byte in this class maps to its own
72    /// equivalence class. Equivalently, there are 257 equivalence classes
73    /// and each class contains exactly one byte (plus the special EOI class).
74    #[inline]
75    pub(crate) fn is_singleton(&self) -> bool {
76        self.alphabet_len() == 256
77    }
78
79    /// Returns an iterator over all equivalence classes in this set.
80    pub(crate) fn iter(&self) -> ByteClassIter {
81        ByteClassIter { it: 0..self.alphabet_len() }
82    }
83
84    /// Returns an iterator of the bytes in the given equivalence class.
85    pub(crate) fn elements(&self, class: u8) -> ByteClassElements {
86        ByteClassElements { classes: self, class, bytes: 0..=255 }
87    }
88
89    /// Returns an iterator of byte ranges in the given equivalence class.
90    ///
91    /// That is, a sequence of contiguous ranges are returned. Typically, every
92    /// class maps to a single contiguous range.
93    fn element_ranges(&self, class: u8) -> ByteClassElementRanges {
94        ByteClassElementRanges { elements: self.elements(class), range: None }
95    }
96}
97
98impl core::fmt::Debug for ByteClasses {
99    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100        if self.is_singleton() {
101            write!(f, "ByteClasses(<one-class-per-byte>)")
102        } else {
103            write!(f, "ByteClasses(")?;
104            for (i, class) in self.iter().enumerate() {
105                if i > 0 {
106                    write!(f, ", ")?;
107                }
108                write!(f, "{:?} => [", class)?;
109                for (start, end) in self.element_ranges(class) {
110                    if start == end {
111                        write!(f, "{:?}", start)?;
112                    } else {
113                        write!(f, "{:?}-{:?}", start, end)?;
114                    }
115                }
116                write!(f, "]")?;
117            }
118            write!(f, ")")
119        }
120    }
121}
122
/// An iterator that yields the identifier of every equivalence class.
#[derive(Debug)]
pub(crate) struct ByteClassIter {
    /// The class identifiers that remain to be yielded, kept as `usize`.
    it: core::ops::Range<usize>,
}
128
129impl Iterator for ByteClassIter {
130    type Item = u8;
131
132    fn next(&mut self) -> Option<u8> {
133        self.it.next().map(|class| class.as_u8())
134    }
135}
136
137/// An iterator over all elements in a specific equivalence class.
138#[derive(Debug)]
139pub(crate) struct ByteClassElements<'a> {
140    classes: &'a ByteClasses,
141    class: u8,
142    bytes: core::ops::RangeInclusive<u8>,
143}
144
145impl<'a> Iterator for ByteClassElements<'a> {
146    type Item = u8;
147
148    fn next(&mut self) -> Option<u8> {
149        while let Some(byte) = self.bytes.next() {
150            if self.class == self.classes.get(byte) {
151                return Some(byte);
152            }
153        }
154        None
155    }
156}
157
158/// An iterator over all elements in an equivalence class expressed as a
159/// sequence of contiguous ranges.
160#[derive(Debug)]
161pub(crate) struct ByteClassElementRanges<'a> {
162    elements: ByteClassElements<'a>,
163    range: Option<(u8, u8)>,
164}
165
166impl<'a> Iterator for ByteClassElementRanges<'a> {
167    type Item = (u8, u8);
168
169    fn next(&mut self) -> Option<(u8, u8)> {
170        loop {
171            let element = match self.elements.next() {
172                None => return self.range.take(),
173                Some(element) => element,
174            };
175            match self.range.take() {
176                None => {
177                    self.range = Some((element, element));
178                }
179                Some((start, end)) => {
180                    if usize::from(end) + 1 != usize::from(element) {
181                        self.range = Some((element, element));
182                        return Some((start, end));
183                    }
184                    self.range = Some((start, element));
185                }
186            }
187        }
188    }
189}
190
191/// A partitioning of bytes into equivalence classes.
192///
193/// A byte class set keeps track of an *approximation* of equivalence classes
194/// of bytes during NFA construction. That is, every byte in an equivalence
195/// class cannot discriminate between a match and a non-match.
196///
197/// Note that this may not compute the minimal set of equivalence classes.
198/// Basically, any byte in a pattern given to the noncontiguous NFA builder
199/// will automatically be treated as its own equivalence class. All other
200/// bytes---any byte not in any pattern---will be treated as their own
201/// equivalence classes. In theory, all bytes not in any pattern should
202/// be part of a single equivalence class, but in practice, we only treat
203/// contiguous ranges of bytes as an equivalence class. So the number of
204/// classes computed may be bigger than necessary. This usually doesn't make
205/// much of a difference, and keeps the implementation simple.
206#[derive(Clone, Debug)]
207pub(crate) struct ByteClassSet(ByteSet);
208
209impl Default for ByteClassSet {
210    fn default() -> ByteClassSet {
211        ByteClassSet::empty()
212    }
213}
214
215impl ByteClassSet {
216    /// Create a new set of byte classes where all bytes are part of the same
217    /// equivalence class.
218    pub(crate) fn empty() -> Self {
219        ByteClassSet(ByteSet::empty())
220    }
221
222    /// Indicate the the range of byte given (inclusive) can discriminate a
223    /// match between it and all other bytes outside of the range.
224    pub(crate) fn set_range(&mut self, start: u8, end: u8) {
225        debug_assert!(start <= end);
226        if start > 0 {
227            self.0.add(start - 1);
228        }
229        self.0.add(end);
230    }
231
232    /// Convert this boolean set to a map that maps all byte values to their
233    /// corresponding equivalence class. The last mapping indicates the largest
234    /// equivalence class identifier (which is never bigger than 255).
235    pub(crate) fn byte_classes(&self) -> ByteClasses {
236        let mut classes = ByteClasses::empty();
237        let mut class = 0u8;
238        let mut b = 0u8;
239        loop {
240            classes.set(b, class);
241            if b == 255 {
242                break;
243            }
244            if self.0.contains(b) {
245                class = class.checked_add(1).unwrap();
246            }
247            b = b.checked_add(1).unwrap();
248        }
249        classes
250    }
251}
252
253/// A simple set of bytes that is reasonably cheap to copy and allocation free.
254#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
255pub(crate) struct ByteSet {
256    bits: BitSet,
257}
258
/// The raw bits backing a `ByteSet`: two 128-bit words, i.e., one bit for
/// each of the 256 byte values.
///
/// This is its own type so that it can have a convenient hand-written
/// Debug impl, while the derived Debug impl of "ByteSet" keeps printing the
/// "ByteSet" name in its output.
#[derive(Copy, Clone, PartialEq, Eq, Default)]
struct BitSet([u128; 2]);
263
264impl ByteSet {
265    /// Create an empty set of bytes.
266    pub(crate) fn empty() -> ByteSet {
267        ByteSet { bits: BitSet([0; 2]) }
268    }
269
270    /// Add a byte to this set.
271    ///
272    /// If the given byte already belongs to this set, then this is a no-op.
273    pub(crate) fn add(&mut self, byte: u8) {
274        let bucket = byte / 128;
275        let bit = byte % 128;
276        self.bits.0[usize::from(bucket)] |= 1 << bit;
277    }
278
279    /// Return true if and only if the given byte is in this set.
280    pub(crate) fn contains(&self, byte: u8) -> bool {
281        let bucket = byte / 128;
282        let bit = byte % 128;
283        self.bits.0[usize::from(bucket)] & (1 << bit) > 0
284    }
285}
286
287impl core::fmt::Debug for BitSet {
288    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
289        let mut fmtd = f.debug_set();
290        for b in 0u8..=255 {
291            if (ByteSet { bits: *self }).contains(b) {
292                fmtd.entry(&b);
293            }
294        }
295        fmtd.finish()
296    }
297}
298
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use super::*;

    /// Checks that every listed byte is mapped to the class paired with it.
    fn assert_classes(classes: &ByteClasses, expected: &[(u8, u8)]) {
        for &(byte, class) in expected {
            assert_eq!(classes.get(byte), class);
        }
    }

    /// Collects all bytes belonging to the given class.
    fn members(classes: &ByteClasses, class: u8) -> Vec<u8> {
        classes.elements(class).collect()
    }

    #[test]
    fn byte_classes() {
        // A single range in the middle of the alphabet yields three
        // classes: before, inside and after the range.
        let mut lower = ByteClassSet::empty();
        lower.set_range(b'a', b'z');
        assert_classes(
            &lower.byte_classes(),
            &[
                (0, 0),
                (1, 0),
                (2, 0),
                (b'a' - 1, 0),
                (b'a', 1),
                (b'm', 1),
                (b'z', 1),
                (b'z' + 1, 2),
                (254, 2),
                (255, 2),
            ],
        );

        // A range starting at byte 0 has no class in front of it.
        let mut pair = ByteClassSet::empty();
        pair.set_range(0, 2);
        pair.set_range(4, 6);
        assert_classes(
            &pair.byte_classes(),
            &[
                (0, 0),
                (1, 0),
                (2, 0),
                (3, 1),
                (4, 2),
                (5, 2),
                (6, 2),
                (7, 3),
                (255, 3),
            ],
        );
    }

    #[test]
    fn full_byte_classes() {
        let mut set = ByteClassSet::empty();
        (0u8..=255).for_each(|b| set.set_range(b, b));
        assert_eq!(set.byte_classes().alphabet_len(), 256);
    }

    #[test]
    fn elements_typical() {
        let mut set = ByteClassSet::empty();
        for &(start, end) in &[(b'b', b'd'), (b'g', b'm'), (b'z', b'z')] {
            set.set_range(start, end);
        }
        let classes = set.byte_classes();
        // class 0: \x00-a
        // class 1: b-d
        // class 2: e-f
        // class 3: g-m
        // class 4: n-y
        // class 5: z-z
        // class 6: \x7B-\xFF
        assert_eq!(classes.alphabet_len(), 7);

        let class0 = members(&classes, 0);
        assert_eq!(class0.len(), 98);
        assert_eq!(class0.first(), Some(&b'\x00'));
        assert_eq!(class0.last(), Some(&b'a'));

        assert_eq!(members(&classes, 1), vec![b'b', b'c', b'd']);
        assert_eq!(members(&classes, 2), vec![b'e', b'f']);
        assert_eq!(
            members(&classes, 3),
            vec![b'g', b'h', b'i', b'j', b'k', b'l', b'm'],
        );

        let class4 = members(&classes, 4);
        assert_eq!(class4.len(), 12);
        assert_eq!(class4.first(), Some(&b'n'));
        assert_eq!(class4.last(), Some(&b'y'));

        assert_eq!(members(&classes, 5), vec![b'z']);

        let class6 = members(&classes, 6);
        assert_eq!(class6.len(), 133);
        assert_eq!(class6.first(), Some(&b'\x7B'));
        assert_eq!(class6.last(), Some(&b'\xFF'));
    }

    #[test]
    fn elements_singletons() {
        let classes = ByteClasses::singletons();
        assert_eq!(classes.alphabet_len(), 256);
        assert_eq!(members(&classes, b'a'), vec![b'a']);
    }

    #[test]
    fn elements_empty() {
        let classes = ByteClasses::empty();
        assert_eq!(classes.alphabet_len(), 1);

        let all = members(&classes, 0);
        assert_eq!(all.len(), 256);
        assert_eq!(all.first(), Some(&b'\x00'));
        assert_eq!(all.last(), Some(&b'\xFF'));
    }
}