aho_corasick/packed/teddy/
generic.rs

1use core::fmt::Debug;
2
3use alloc::{
4    boxed::Box, collections::BTreeMap, format, sync::Arc, vec, vec::Vec,
5};
6
7use crate::{
8    packed::{
9        ext::Pointer,
10        pattern::Patterns,
11        vector::{FatVector, Vector},
12    },
13    util::int::U32,
14    PatternID,
15};
16
/// A match type specialized to the Teddy implementations below.
///
/// Essentially, instead of representing a match at byte offsets, we use
/// raw pointers. This is because the implementations below operate on raw
/// pointers, and so this is a more natural return type based on how the
/// implementation works.
///
/// NOTE(review): this comment previously stated that the `PatternID` used
/// here is a `u16`. The definition of `PatternID` is not visible in this
/// file, so confirm its width before relying on that statement.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Match {
    /// The ID of the pattern that matched.
    pid: PatternID,
    /// Pointer into the haystack at which the match starts.
    start: *const u8,
    /// Pointer into the haystack at which the match ends. Verification sets
    /// this to `start + pattern length`, i.e., one past the last matched
    /// byte.
    end: *const u8,
}
31
32impl Match {
33    /// Returns the ID of the pattern that matched.
34    pub(crate) fn pattern(&self) -> PatternID {
35        self.pid
36    }
37
38    /// Returns a pointer into the haystack at which the match starts.
39    pub(crate) fn start(&self) -> *const u8 {
40        self.start
41    }
42
43    /// Returns a pointer into the haystack at which the match ends.
44    pub(crate) fn end(&self) -> *const u8 {
45        self.end
46    }
47}
48
/// A "slim" Teddy implementation that is generic over both the vector type
/// and the minimum length of the patterns being searched for.
///
/// `V` is the vector type and `BYTES` is the number of masks used to
/// generate candidates. Only 1, 2, 3 and 4 bytes are supported as minimum
/// lengths; the constructor panics for any other value of `BYTES`.
#[derive(Clone, Debug)]
pub(crate) struct Slim<V, const BYTES: usize> {
    /// A generic data structure for doing "slim" Teddy verification. The
    /// slim variant always verifies against 8 buckets.
    teddy: Teddy<8>,
    /// The masks used as inputs to the shuffle operation to generate
    /// candidates (which are fed into the verification routines). There are
    /// exactly `BYTES` of them.
    masks: [Mask<V>; BYTES],
}
61
62impl<V: Vector, const BYTES: usize> Slim<V, BYTES> {
63    /// Create a new "slim" Teddy searcher for the given patterns.
64    ///
65    /// # Panics
66    ///
67    /// This panics when `BYTES` is any value other than 1, 2, 3 or 4.
68    ///
69    /// # Safety
70    ///
71    /// Callers must ensure that this is okay to call in the current target for
72    /// the current CPU.
73    #[inline(always)]
74    pub(crate) unsafe fn new(patterns: Arc<Patterns>) -> Slim<V, BYTES> {
75        assert!(
76            1 <= BYTES && BYTES <= 4,
77            "only 1, 2, 3 or 4 bytes are supported"
78        );
79        let teddy = Teddy::new(patterns);
80        let masks = SlimMaskBuilder::from_teddy(&teddy);
81        Slim { teddy, masks }
82    }
83
84    /// Returns the approximate total amount of heap used by this type, in
85    /// units of bytes.
86    #[inline(always)]
87    pub(crate) fn memory_usage(&self) -> usize {
88        self.teddy.memory_usage()
89    }
90
91    /// Returns the minimum length, in bytes, that a haystack must be in order
92    /// to use it with this searcher.
93    #[inline(always)]
94    pub(crate) fn minimum_len(&self) -> usize {
95        V::BYTES + (BYTES - 1)
96    }
97}
98
99impl<V: Vector> Slim<V, 1> {
100    /// Look for an occurrences of the patterns in this finder in the haystack
101    /// given by the `start` and `end` pointers.
102    ///
103    /// If no match could be found, then `None` is returned.
104    ///
105    /// # Safety
106    ///
107    /// The given pointers representing the haystack must be valid to read
108    /// from. They must also point to a region of memory that is at least the
109    /// minimum length required by this searcher.
110    ///
111    /// Callers must ensure that this is okay to call in the current target for
112    /// the current CPU.
113    #[inline(always)]
114    pub(crate) unsafe fn find(
115        &self,
116        start: *const u8,
117        end: *const u8,
118    ) -> Option<Match> {
119        let len = end.distance(start);
120        debug_assert!(len >= self.minimum_len());
121        let mut cur = start;
122        while cur <= end.sub(V::BYTES) {
123            if let Some(m) = self.find_one(cur, end) {
124                return Some(m);
125            }
126            cur = cur.add(V::BYTES);
127        }
128        if cur < end {
129            cur = end.sub(V::BYTES);
130            if let Some(m) = self.find_one(cur, end) {
131                return Some(m);
132            }
133        }
134        None
135    }
136
137    /// Look for a match starting at the `V::BYTES` at and after `cur`. If
138    /// there isn't one, then `None` is returned.
139    ///
140    /// # Safety
141    ///
142    /// The given pointers representing the haystack must be valid to read
143    /// from. They must also point to a region of memory that is at least the
144    /// minimum length required by this searcher.
145    ///
146    /// Callers must ensure that this is okay to call in the current target for
147    /// the current CPU.
148    #[inline(always)]
149    unsafe fn find_one(
150        &self,
151        cur: *const u8,
152        end: *const u8,
153    ) -> Option<Match> {
154        let c = self.candidate(cur);
155        if !c.is_zero() {
156            if let Some(m) = self.teddy.verify(cur, end, c) {
157                return Some(m);
158            }
159        }
160        None
161    }
162
163    /// Look for a candidate match (represented as a vector) starting at the
164    /// `V::BYTES` at and after `cur`. If there isn't one, then a vector with
165    /// all bits set to zero is returned.
166    ///
167    /// # Safety
168    ///
169    /// The given pointer representing the haystack must be valid to read
170    /// from.
171    ///
172    /// Callers must ensure that this is okay to call in the current target for
173    /// the current CPU.
174    #[inline(always)]
175    unsafe fn candidate(&self, cur: *const u8) -> V {
176        let chunk = V::load_unaligned(cur);
177        Mask::members1(chunk, self.masks)
178    }
179}
180
181impl<V: Vector> Slim<V, 2> {
182    /// See Slim<V, 1>::find.
183    #[inline(always)]
184    pub(crate) unsafe fn find(
185        &self,
186        start: *const u8,
187        end: *const u8,
188    ) -> Option<Match> {
189        let len = end.distance(start);
190        debug_assert!(len >= self.minimum_len());
191        let mut cur = start.add(1);
192        let mut prev0 = V::splat(0xFF);
193        while cur <= end.sub(V::BYTES) {
194            if let Some(m) = self.find_one(cur, end, &mut prev0) {
195                return Some(m);
196            }
197            cur = cur.add(V::BYTES);
198        }
199        if cur < end {
200            cur = end.sub(V::BYTES);
201            prev0 = V::splat(0xFF);
202            if let Some(m) = self.find_one(cur, end, &mut prev0) {
203                return Some(m);
204            }
205        }
206        None
207    }
208
209    /// See Slim<V, 1>::find_one.
210    #[inline(always)]
211    unsafe fn find_one(
212        &self,
213        cur: *const u8,
214        end: *const u8,
215        prev0: &mut V,
216    ) -> Option<Match> {
217        let c = self.candidate(cur, prev0);
218        if !c.is_zero() {
219            if let Some(m) = self.teddy.verify(cur.sub(1), end, c) {
220                return Some(m);
221            }
222        }
223        None
224    }
225
226    /// See Slim<V, 1>::candidate.
227    #[inline(always)]
228    unsafe fn candidate(&self, cur: *const u8, prev0: &mut V) -> V {
229        let chunk = V::load_unaligned(cur);
230        let (res0, res1) = Mask::members2(chunk, self.masks);
231        let res0prev0 = res0.shift_in_one_byte(*prev0);
232        let res = res0prev0.and(res1);
233        *prev0 = res0;
234        res
235    }
236}
237
238impl<V: Vector> Slim<V, 3> {
239    /// See Slim<V, 1>::find.
240    #[inline(always)]
241    pub(crate) unsafe fn find(
242        &self,
243        start: *const u8,
244        end: *const u8,
245    ) -> Option<Match> {
246        let len = end.distance(start);
247        debug_assert!(len >= self.minimum_len());
248        let mut cur = start.add(2);
249        let mut prev0 = V::splat(0xFF);
250        let mut prev1 = V::splat(0xFF);
251        while cur <= end.sub(V::BYTES) {
252            if let Some(m) = self.find_one(cur, end, &mut prev0, &mut prev1) {
253                return Some(m);
254            }
255            cur = cur.add(V::BYTES);
256        }
257        if cur < end {
258            cur = end.sub(V::BYTES);
259            prev0 = V::splat(0xFF);
260            prev1 = V::splat(0xFF);
261            if let Some(m) = self.find_one(cur, end, &mut prev0, &mut prev1) {
262                return Some(m);
263            }
264        }
265        None
266    }
267
268    /// See Slim<V, 1>::find_one.
269    #[inline(always)]
270    unsafe fn find_one(
271        &self,
272        cur: *const u8,
273        end: *const u8,
274        prev0: &mut V,
275        prev1: &mut V,
276    ) -> Option<Match> {
277        let c = self.candidate(cur, prev0, prev1);
278        if !c.is_zero() {
279            if let Some(m) = self.teddy.verify(cur.sub(2), end, c) {
280                return Some(m);
281            }
282        }
283        None
284    }
285
286    /// See Slim<V, 1>::candidate.
287    #[inline(always)]
288    unsafe fn candidate(
289        &self,
290        cur: *const u8,
291        prev0: &mut V,
292        prev1: &mut V,
293    ) -> V {
294        let chunk = V::load_unaligned(cur);
295        let (res0, res1, res2) = Mask::members3(chunk, self.masks);
296        let res0prev0 = res0.shift_in_two_bytes(*prev0);
297        let res1prev1 = res1.shift_in_one_byte(*prev1);
298        let res = res0prev0.and(res1prev1).and(res2);
299        *prev0 = res0;
300        *prev1 = res1;
301        res
302    }
303}
304
305impl<V: Vector> Slim<V, 4> {
306    /// See Slim<V, 1>::find.
307    #[inline(always)]
308    pub(crate) unsafe fn find(
309        &self,
310        start: *const u8,
311        end: *const u8,
312    ) -> Option<Match> {
313        let len = end.distance(start);
314        debug_assert!(len >= self.minimum_len());
315        let mut cur = start.add(3);
316        let mut prev0 = V::splat(0xFF);
317        let mut prev1 = V::splat(0xFF);
318        let mut prev2 = V::splat(0xFF);
319        while cur <= end.sub(V::BYTES) {
320            if let Some(m) =
321                self.find_one(cur, end, &mut prev0, &mut prev1, &mut prev2)
322            {
323                return Some(m);
324            }
325            cur = cur.add(V::BYTES);
326        }
327        if cur < end {
328            cur = end.sub(V::BYTES);
329            prev0 = V::splat(0xFF);
330            prev1 = V::splat(0xFF);
331            prev2 = V::splat(0xFF);
332            if let Some(m) =
333                self.find_one(cur, end, &mut prev0, &mut prev1, &mut prev2)
334            {
335                return Some(m);
336            }
337        }
338        None
339    }
340
341    /// See Slim<V, 1>::find_one.
342    #[inline(always)]
343    unsafe fn find_one(
344        &self,
345        cur: *const u8,
346        end: *const u8,
347        prev0: &mut V,
348        prev1: &mut V,
349        prev2: &mut V,
350    ) -> Option<Match> {
351        let c = self.candidate(cur, prev0, prev1, prev2);
352        if !c.is_zero() {
353            if let Some(m) = self.teddy.verify(cur.sub(3), end, c) {
354                return Some(m);
355            }
356        }
357        None
358    }
359
360    /// See Slim<V, 1>::candidate.
361    #[inline(always)]
362    unsafe fn candidate(
363        &self,
364        cur: *const u8,
365        prev0: &mut V,
366        prev1: &mut V,
367        prev2: &mut V,
368    ) -> V {
369        let chunk = V::load_unaligned(cur);
370        let (res0, res1, res2, res3) = Mask::members4(chunk, self.masks);
371        let res0prev0 = res0.shift_in_three_bytes(*prev0);
372        let res1prev1 = res1.shift_in_two_bytes(*prev1);
373        let res2prev2 = res2.shift_in_one_byte(*prev2);
374        let res = res0prev0.and(res1prev1).and(res2prev2).and(res3);
375        *prev0 = res0;
376        *prev1 = res1;
377        *prev2 = res2;
378        res
379    }
380}
381
/// A "fat" Teddy implementation that is generic over both the vector type
/// and the minimum length of the patterns being searched for.
///
/// `V` is the vector type and `BYTES` is the number of masks used to
/// generate candidates. Only 1, 2, 3 and 4 bytes are supported as minimum
/// lengths; the constructor panics for any other value of `BYTES`.
#[derive(Clone, Debug)]
pub(crate) struct Fat<V, const BYTES: usize> {
    /// A generic data structure for doing "fat" Teddy verification. The fat
    /// variant always verifies against 16 buckets.
    teddy: Teddy<16>,
    /// The masks used as inputs to the shuffle operation to generate
    /// candidates (which are fed into the verification routines). There are
    /// exactly `BYTES` of them.
    masks: [Mask<V>; BYTES],
}
394
395impl<V: FatVector, const BYTES: usize> Fat<V, BYTES> {
396    /// Create a new "fat" Teddy searcher for the given patterns.
397    ///
398    /// # Panics
399    ///
400    /// This panics when `BYTES` is any value other than 1, 2, 3 or 4.
401    ///
402    /// # Safety
403    ///
404    /// Callers must ensure that this is okay to call in the current target for
405    /// the current CPU.
406    #[inline(always)]
407    pub(crate) unsafe fn new(patterns: Arc<Patterns>) -> Fat<V, BYTES> {
408        assert!(
409            1 <= BYTES && BYTES <= 4,
410            "only 1, 2, 3 or 4 bytes are supported"
411        );
412        let teddy = Teddy::new(patterns);
413        let masks = FatMaskBuilder::from_teddy(&teddy);
414        Fat { teddy, masks }
415    }
416
417    /// Returns the approximate total amount of heap used by this type, in
418    /// units of bytes.
419    #[inline(always)]
420    pub(crate) fn memory_usage(&self) -> usize {
421        self.teddy.memory_usage()
422    }
423
424    /// Returns the minimum length, in bytes, that a haystack must be in order
425    /// to use it with this searcher.
426    #[inline(always)]
427    pub(crate) fn minimum_len(&self) -> usize {
428        V::Half::BYTES + (BYTES - 1)
429    }
430}
431
432impl<V: FatVector> Fat<V, 1> {
433    /// Look for an occurrences of the patterns in this finder in the haystack
434    /// given by the `start` and `end` pointers.
435    ///
436    /// If no match could be found, then `None` is returned.
437    ///
438    /// # Safety
439    ///
440    /// The given pointers representing the haystack must be valid to read
441    /// from. They must also point to a region of memory that is at least the
442    /// minimum length required by this searcher.
443    ///
444    /// Callers must ensure that this is okay to call in the current target for
445    /// the current CPU.
446    #[inline(always)]
447    pub(crate) unsafe fn find(
448        &self,
449        start: *const u8,
450        end: *const u8,
451    ) -> Option<Match> {
452        let len = end.distance(start);
453        debug_assert!(len >= self.minimum_len());
454        let mut cur = start;
455        while cur <= end.sub(V::Half::BYTES) {
456            if let Some(m) = self.find_one(cur, end) {
457                return Some(m);
458            }
459            cur = cur.add(V::Half::BYTES);
460        }
461        if cur < end {
462            cur = end.sub(V::Half::BYTES);
463            if let Some(m) = self.find_one(cur, end) {
464                return Some(m);
465            }
466        }
467        None
468    }
469
470    /// Look for a match starting at the `V::BYTES` at and after `cur`. If
471    /// there isn't one, then `None` is returned.
472    ///
473    /// # Safety
474    ///
475    /// The given pointers representing the haystack must be valid to read
476    /// from. They must also point to a region of memory that is at least the
477    /// minimum length required by this searcher.
478    ///
479    /// Callers must ensure that this is okay to call in the current target for
480    /// the current CPU.
481    #[inline(always)]
482    unsafe fn find_one(
483        &self,
484        cur: *const u8,
485        end: *const u8,
486    ) -> Option<Match> {
487        let c = self.candidate(cur);
488        if !c.is_zero() {
489            if let Some(m) = self.teddy.verify(cur, end, c) {
490                return Some(m);
491            }
492        }
493        None
494    }
495
496    /// Look for a candidate match (represented as a vector) starting at the
497    /// `V::BYTES` at and after `cur`. If there isn't one, then a vector with
498    /// all bits set to zero is returned.
499    ///
500    /// # Safety
501    ///
502    /// The given pointer representing the haystack must be valid to read
503    /// from.
504    ///
505    /// Callers must ensure that this is okay to call in the current target for
506    /// the current CPU.
507    #[inline(always)]
508    unsafe fn candidate(&self, cur: *const u8) -> V {
509        let chunk = V::load_half_unaligned(cur);
510        Mask::members1(chunk, self.masks)
511    }
512}
513
514impl<V: FatVector> Fat<V, 2> {
515    /// See `Fat<V, 1>::find`.
516    #[inline(always)]
517    pub(crate) unsafe fn find(
518        &self,
519        start: *const u8,
520        end: *const u8,
521    ) -> Option<Match> {
522        let len = end.distance(start);
523        debug_assert!(len >= self.minimum_len());
524        let mut cur = start.add(1);
525        let mut prev0 = V::splat(0xFF);
526        while cur <= end.sub(V::Half::BYTES) {
527            if let Some(m) = self.find_one(cur, end, &mut prev0) {
528                return Some(m);
529            }
530            cur = cur.add(V::Half::BYTES);
531        }
532        if cur < end {
533            cur = end.sub(V::Half::BYTES);
534            prev0 = V::splat(0xFF);
535            if let Some(m) = self.find_one(cur, end, &mut prev0) {
536                return Some(m);
537            }
538        }
539        None
540    }
541
542    /// See `Fat<V, 1>::find_one`.
543    #[inline(always)]
544    unsafe fn find_one(
545        &self,
546        cur: *const u8,
547        end: *const u8,
548        prev0: &mut V,
549    ) -> Option<Match> {
550        let c = self.candidate(cur, prev0);
551        if !c.is_zero() {
552            if let Some(m) = self.teddy.verify(cur.sub(1), end, c) {
553                return Some(m);
554            }
555        }
556        None
557    }
558
559    /// See `Fat<V, 1>::candidate`.
560    #[inline(always)]
561    unsafe fn candidate(&self, cur: *const u8, prev0: &mut V) -> V {
562        let chunk = V::load_half_unaligned(cur);
563        let (res0, res1) = Mask::members2(chunk, self.masks);
564        let res0prev0 = res0.half_shift_in_one_byte(*prev0);
565        let res = res0prev0.and(res1);
566        *prev0 = res0;
567        res
568    }
569}
570
571impl<V: FatVector> Fat<V, 3> {
572    /// See `Fat<V, 1>::find`.
573    #[inline(always)]
574    pub(crate) unsafe fn find(
575        &self,
576        start: *const u8,
577        end: *const u8,
578    ) -> Option<Match> {
579        let len = end.distance(start);
580        debug_assert!(len >= self.minimum_len());
581        let mut cur = start.add(2);
582        let mut prev0 = V::splat(0xFF);
583        let mut prev1 = V::splat(0xFF);
584        while cur <= end.sub(V::Half::BYTES) {
585            if let Some(m) = self.find_one(cur, end, &mut prev0, &mut prev1) {
586                return Some(m);
587            }
588            cur = cur.add(V::Half::BYTES);
589        }
590        if cur < end {
591            cur = end.sub(V::Half::BYTES);
592            prev0 = V::splat(0xFF);
593            prev1 = V::splat(0xFF);
594            if let Some(m) = self.find_one(cur, end, &mut prev0, &mut prev1) {
595                return Some(m);
596            }
597        }
598        None
599    }
600
601    /// See `Fat<V, 1>::find_one`.
602    #[inline(always)]
603    unsafe fn find_one(
604        &self,
605        cur: *const u8,
606        end: *const u8,
607        prev0: &mut V,
608        prev1: &mut V,
609    ) -> Option<Match> {
610        let c = self.candidate(cur, prev0, prev1);
611        if !c.is_zero() {
612            if let Some(m) = self.teddy.verify(cur.sub(2), end, c) {
613                return Some(m);
614            }
615        }
616        None
617    }
618
619    /// See `Fat<V, 1>::candidate`.
620    #[inline(always)]
621    unsafe fn candidate(
622        &self,
623        cur: *const u8,
624        prev0: &mut V,
625        prev1: &mut V,
626    ) -> V {
627        let chunk = V::load_half_unaligned(cur);
628        let (res0, res1, res2) = Mask::members3(chunk, self.masks);
629        let res0prev0 = res0.half_shift_in_two_bytes(*prev0);
630        let res1prev1 = res1.half_shift_in_one_byte(*prev1);
631        let res = res0prev0.and(res1prev1).and(res2);
632        *prev0 = res0;
633        *prev1 = res1;
634        res
635    }
636}
637
638impl<V: FatVector> Fat<V, 4> {
639    /// See `Fat<V, 1>::find`.
640    #[inline(always)]
641    pub(crate) unsafe fn find(
642        &self,
643        start: *const u8,
644        end: *const u8,
645    ) -> Option<Match> {
646        let len = end.distance(start);
647        debug_assert!(len >= self.minimum_len());
648        let mut cur = start.add(3);
649        let mut prev0 = V::splat(0xFF);
650        let mut prev1 = V::splat(0xFF);
651        let mut prev2 = V::splat(0xFF);
652        while cur <= end.sub(V::Half::BYTES) {
653            if let Some(m) =
654                self.find_one(cur, end, &mut prev0, &mut prev1, &mut prev2)
655            {
656                return Some(m);
657            }
658            cur = cur.add(V::Half::BYTES);
659        }
660        if cur < end {
661            cur = end.sub(V::Half::BYTES);
662            prev0 = V::splat(0xFF);
663            prev1 = V::splat(0xFF);
664            prev2 = V::splat(0xFF);
665            if let Some(m) =
666                self.find_one(cur, end, &mut prev0, &mut prev1, &mut prev2)
667            {
668                return Some(m);
669            }
670        }
671        None
672    }
673
674    /// See `Fat<V, 1>::find_one`.
675    #[inline(always)]
676    unsafe fn find_one(
677        &self,
678        cur: *const u8,
679        end: *const u8,
680        prev0: &mut V,
681        prev1: &mut V,
682        prev2: &mut V,
683    ) -> Option<Match> {
684        let c = self.candidate(cur, prev0, prev1, prev2);
685        if !c.is_zero() {
686            if let Some(m) = self.teddy.verify(cur.sub(3), end, c) {
687                return Some(m);
688            }
689        }
690        None
691    }
692
693    /// See `Fat<V, 1>::candidate`.
694    #[inline(always)]
695    unsafe fn candidate(
696        &self,
697        cur: *const u8,
698        prev0: &mut V,
699        prev1: &mut V,
700        prev2: &mut V,
701    ) -> V {
702        let chunk = V::load_half_unaligned(cur);
703        let (res0, res1, res2, res3) = Mask::members4(chunk, self.masks);
704        let res0prev0 = res0.half_shift_in_three_bytes(*prev0);
705        let res1prev1 = res1.half_shift_in_two_bytes(*prev1);
706        let res2prev2 = res2.half_shift_in_one_byte(*prev2);
707        let res = res0prev0.and(res1prev1).and(res2prev2).and(res3);
708        *prev0 = res0;
709        *prev1 = res1;
710        *prev2 = res2;
711        res
712    }
713}
714
/// The common elements of all "slim" and "fat" Teddy search implementations.
///
/// Essentially, this contains the patterns and the buckets. Namely, it
/// contains enough to implement the verification step after candidates are
/// identified via the shuffle masks.
///
/// It is generic over the number of buckets used. In general, the number of
/// buckets is either 8 (for "slim" Teddy) or 16 (for "fat" Teddy). The generic
/// parameter isn't really meant to be instantiated for any value other than
/// 8 or 16 (and the constructor asserts as much). The main hiccup is that
/// verification divides and takes remainders by `BUCKETS` in its critical
/// part, which could be quite expensive if `BUCKETS` is not a multiple of 2.
///
/// NOTE(review): presumably what matters is that `BUCKETS` is a *power* of
/// two, so that the division and remainder reduce to shifts and masks.
/// Confirm.
#[derive(Clone, Debug)]
struct Teddy<const BUCKETS: usize> {
    /// The patterns we are searching for.
    ///
    /// A pattern string can be found by its `PatternID`.
    patterns: Arc<Patterns>,
    /// The allocation of patterns in buckets. This only contains the IDs of
    /// patterns. In order to do full verification, callers must provide the
    /// actual patterns when using Teddy.
    buckets: [Vec<PatternID>; BUCKETS],
    // N.B. The above representation is very simple, but it definitely results
    // in ping-ponging between different allocations during verification. I've
    // tried experimenting with other representations that flatten the pattern
    // strings into a single allocation, but it doesn't seem to help much.
    // Probably everything is small enough to fit into cache anyway, and so the
    // pointer chasing isn't a big deal?
    //
    // One other avenue I haven't explored is some kind of hashing trick
    // that lets us do another high-confidence check before launching into
    // `memcmp`.
}
748
749impl<const BUCKETS: usize> Teddy<BUCKETS> {
750    /// Create a new generic data structure for Teddy verification.
751    fn new(patterns: Arc<Patterns>) -> Teddy<BUCKETS> {
752        assert_ne!(0, patterns.len(), "Teddy requires at least one pattern");
753        assert_ne!(
754            0,
755            patterns.minimum_len(),
756            "Teddy does not support zero-length patterns"
757        );
758        assert!(
759            BUCKETS == 8 || BUCKETS == 16,
760            "Teddy only supports 8 or 16 buckets"
761        );
762        // MSRV(1.63): Use core::array::from_fn below instead of allocating a
763        // superfluous outer Vec. Not a big deal (especially given the BTreeMap
764        // allocation below), but nice to not do it.
765        let buckets =
766            <[Vec<PatternID>; BUCKETS]>::try_from(vec![vec![]; BUCKETS])
767                .unwrap();
768        let mut t = Teddy { patterns, buckets };
769
770        let mut map: BTreeMap<Box<[u8]>, usize> = BTreeMap::new();
771        for (id, pattern) in t.patterns.iter() {
772            // We try to be slightly clever in how we assign patterns into
773            // buckets. Generally speaking, we want patterns with the same
774            // prefix to be in the same bucket, since it minimizes the amount
775            // of time we spend churning through buckets in the verification
776            // step.
777            //
778            // So we could assign patterns with the same N-prefix (where N is
779            // the size of the mask, which is one of {1, 2, 3}) to the same
780            // bucket. However, case insensitive searches are fairly common, so
781            // we'd for example, ideally want to treat `abc` and `ABC` as if
782            // they shared the same prefix. ASCII has the nice property that
783            // the lower 4 bits of A and a are the same, so we therefore group
784            // patterns with the same low-nybble-N-prefix into the same bucket.
785            //
786            // MOREOVER, this is actually necessary for correctness! In
787            // particular, by grouping patterns with the same prefix into the
788            // same bucket, we ensure that we preserve correct leftmost-first
789            // and leftmost-longest match semantics. In addition to the fact
790            // that `patterns.iter()` iterates in the correct order, this
791            // guarantees that all possible ambiguous matches will occur in
792            // the same bucket. The verification routine could be adjusted to
793            // support correct leftmost match semantics regardless of bucket
794            // allocation, but that results in a performance hit. It's much
795            // nicer to be able to just stop as soon as a match is found.
796            let lonybs = pattern.low_nybbles(t.mask_len());
797            if let Some(&bucket) = map.get(&lonybs) {
798                t.buckets[bucket].push(id);
799            } else {
800                // N.B. We assign buckets in reverse because it shouldn't have
801                // any influence on performance, but it does make it harder to
802                // get leftmost match semantics accidentally correct.
803                let bucket = (BUCKETS - 1) - (id.as_usize() % BUCKETS);
804                t.buckets[bucket].push(id);
805                map.insert(lonybs, bucket);
806            }
807        }
808        t
809    }
810
811    /// Verify whether there are any matches starting at or after `cur` in the
812    /// haystack. The candidate chunk given should correspond to 8-bit bitsets
813    /// for N buckets.
814    ///
815    /// # Safety
816    ///
817    /// The given pointers representing the haystack must be valid to read
818    /// from.
819    #[inline(always)]
820    unsafe fn verify64(
821        &self,
822        cur: *const u8,
823        end: *const u8,
824        mut candidate_chunk: u64,
825    ) -> Option<Match> {
826        while candidate_chunk != 0 {
827            let bit = candidate_chunk.trailing_zeros().as_usize();
828            candidate_chunk &= !(1 << bit);
829
830            let cur = cur.add(bit / BUCKETS);
831            let bucket = bit % BUCKETS;
832            if let Some(m) = self.verify_bucket(cur, end, bucket) {
833                return Some(m);
834            }
835        }
836        None
837    }
838
839    /// Verify whether there are any matches starting at `at` in the given
840    /// `haystack` corresponding only to patterns in the given bucket.
841    ///
842    /// # Safety
843    ///
844    /// The given pointers representing the haystack must be valid to read
845    /// from.
846    ///
847    /// The bucket index must be less than or equal to `self.buckets.len()`.
848    #[inline(always)]
849    unsafe fn verify_bucket(
850        &self,
851        cur: *const u8,
852        end: *const u8,
853        bucket: usize,
854    ) -> Option<Match> {
855        debug_assert!(bucket < self.buckets.len());
856        // SAFETY: The caller must ensure that the bucket index is correct.
857        for pid in self.buckets.get_unchecked(bucket).iter().copied() {
858            // SAFETY: This is safe because we are guaranteed that every
859            // index in a Teddy bucket is a valid index into `pats`, by
860            // construction.
861            debug_assert!(pid.as_usize() < self.patterns.len());
862            let pat = self.patterns.get_unchecked(pid);
863            if pat.is_prefix_raw(cur, end) {
864                let start = cur;
865                let end = start.add(pat.len());
866                return Some(Match { pid, start, end });
867            }
868        }
869        None
870    }
871
872    /// Returns the total number of masks required by the patterns in this
873    /// Teddy searcher.
874    ///
875    /// Basically, the mask length corresponds to the type of Teddy searcher
876    /// to use: a 1-byte, 2-byte, 3-byte or 4-byte searcher. The bigger the
877    /// better, typically, since searching for longer substrings usually
878    /// decreases the rate of false positives. Therefore, the number of masks
879    /// needed is the length of the shortest pattern in this searcher. If the
880    /// length of the shortest pattern (in bytes) is bigger than 4, then the
881    /// mask length is 4 since there are no Teddy searchers for more than 4
882    /// bytes.
883    fn mask_len(&self) -> usize {
884        core::cmp::min(4, self.patterns.minimum_len())
885    }
886
887    /// Returns the approximate total amount of heap used by this type, in
888    /// units of bytes.
889    fn memory_usage(&self) -> usize {
890        // This is an upper bound rather than a precise accounting. No
891        // particular reason, other than it's probably very close to actual
892        // memory usage in practice.
893        self.patterns.len() * core::mem::size_of::<PatternID>()
894    }
895}
896
897impl Teddy<8> {
898    /// Runs the verification routine for "slim" Teddy.
899    ///
900    /// The candidate given should be a collection of 8-bit bitsets (one bitset
901    /// per lane), where the ith bit is set in the jth lane if and only if the
902    /// byte occurring at `at + j` in `cur` is in the bucket `i`.
903    ///
904    /// # Safety
905    ///
906    /// Callers must ensure that this is okay to call in the current target for
907    /// the current CPU.
908    ///
909    /// The given pointers must be valid to read from.
910    #[inline(always)]
911    unsafe fn verify<V: Vector>(
912        &self,
913        mut cur: *const u8,
914        end: *const u8,
915        candidate: V,
916    ) -> Option<Match> {
917        debug_assert!(!candidate.is_zero());
918        // Convert the candidate into 64-bit chunks, and then verify each of
919        // those chunks.
920        candidate.for_each_64bit_lane(
921            #[inline(always)]
922            |_, chunk| {
923                let result = self.verify64(cur, end, chunk);
924                cur = cur.add(8);
925                result
926            },
927        )
928    }
929}
930
931impl Teddy<16> {
932    /// Runs the verification routine for "fat" Teddy.
933    ///
934    /// The candidate given should be a collection of 8-bit bitsets (one bitset
935    /// per lane), where the ith bit is set in the jth lane if and only if the
936    /// byte occurring at `at + (j < 16 ? j : j - 16)` in `cur` is in the
937    /// bucket `j < 16 ? i : i + 8`.
938    ///
939    /// # Safety
940    ///
941    /// Callers must ensure that this is okay to call in the current target for
942    /// the current CPU.
943    ///
944    /// The given pointers must be valid to read from.
945    #[inline(always)]
946    unsafe fn verify<V: FatVector>(
947        &self,
948        mut cur: *const u8,
949        end: *const u8,
950        candidate: V,
951    ) -> Option<Match> {
952        // This is a bit tricky, but we basically want to convert our
953        // candidate, which looks like this (assuming a 256-bit vector):
954        //
955        //     a31 a30 ... a17 a16 a15 a14 ... a01 a00
956        //
957        // where each a(i) is an 8-bit bitset corresponding to the activated
958        // buckets, to this
959        //
960        //     a31 a15 a30 a14 a29 a13 ... a18 a02 a17 a01 a16 a00
961        //
962        // Namely, for Fat Teddy, the high 128-bits of the candidate correspond
963        // to the same bytes in the haystack in the low 128-bits (so we only
964        // scan 16 bytes at a time), but are for buckets 8-15 instead of 0-7.
965        //
966        // The verification routine wants to look at all potentially matching
967        // buckets before moving on to the next lane. So for example, both
968        // a16 and a00 both correspond to the first byte in our window; a00
969        // contains buckets 0-7 and a16 contains buckets 8-15. Specifically,
970        // a16 should be checked before a01. So the transformation shown above
971        // allows us to use our normal verification procedure with one small
972        // change: we treat each bitset as 16 bits instead of 8 bits.
973        debug_assert!(!candidate.is_zero());
974
975        // Swap the 128-bit lanes in the candidate vector.
976        let swapped = candidate.swap_halves();
977        // Interleave the bytes from the low 128-bit lanes, starting with
978        // cand first.
979        let r1 = candidate.interleave_low_8bit_lanes(swapped);
980        // Interleave the bytes from the high 128-bit lanes, starting with
981        // cand first.
982        let r2 = candidate.interleave_high_8bit_lanes(swapped);
983        // Now just take the 2 low 64-bit integers from both r1 and r2. We
984        // can drop the high 64-bit integers because they are a mirror image
985        // of the low 64-bit integers. All we care about are the low 128-bit
986        // lanes of r1 and r2. Combined, they contain all our 16-bit bitsets
987        // laid out in the desired order, as described above.
988        r1.for_each_low_64bit_lane(
989            r2,
990            #[inline(always)]
991            |_, chunk| {
992                let result = self.verify64(cur, end, chunk);
993                cur = cur.add(4);
994                result
995            },
996        )
997    }
998}
999
/// A vector generic mask for the low and high nybbles in a set of patterns.
/// Each 8-bit lane `j` in a vector corresponds to a bitset where the `i`th bit
/// is set if and only if the nybble `j` is in the bucket `i` at a particular
/// position.
///
/// This is slightly tweaked depending on whether Slim or Fat Teddy is being
/// used. For Slim Teddy, the bitsets in the lower half are the same as the
/// bitsets in the higher half, so that we can search `V::BYTES` bytes at a
/// time. (Remember, the nybbles in the haystack are used as indices into these
/// masks, and 256-bit shuffles only operate on 128-bit lanes.)
///
/// For Fat Teddy, the bitsets are not repeated, but instead, the high half
/// bits correspond to an additional 8 buckets. So that a bitset `00100010` has
/// buckets 1 and 5 set if it's in the lower half, but has buckets 9 and 13 set
/// if it's in the higher half.
#[derive(Clone, Copy, Debug)]
struct Mask<V> {
    /// The bitsets that are looked up using the low nybble of each haystack
    /// byte.
    lo: V,
    /// The bitsets that are looked up using the high nybble of each haystack
    /// byte.
    hi: V,
}
1020
1021impl<V: Vector> Mask<V> {
1022    /// Return a candidate for Teddy (fat or slim) that is searching for 1-byte
1023    /// candidates.
1024    ///
1025    /// If a candidate is returned, it will be a collection of 8-bit bitsets
1026    /// (one bitset per lane), where the ith bit is set in the jth lane if and
1027    /// only if the byte occurring at the jth lane in `chunk` is in the bucket
1028    /// `i`. If no candidate is found, then the vector returned will have all
1029    /// lanes set to zero.
1030    ///
1031    /// `chunk` should correspond to a `V::BYTES` window of the haystack (where
1032    /// the least significant byte corresponds to the start of the window). For
1033    /// fat Teddy, the haystack window length should be `V::BYTES / 2`, with
1034    /// the window repeated in each half of the vector.
1035    ///
1036    /// `mask1` should correspond to a low/high mask for the first byte of all
1037    /// patterns that are being searched.
1038    #[inline(always)]
1039    unsafe fn members1(chunk: V, masks: [Mask<V>; 1]) -> V {
1040        let lomask = V::splat(0xF);
1041        let hlo = chunk.and(lomask);
1042        let hhi = chunk.shift_8bit_lane_right::<4>().and(lomask);
1043        let locand = masks[0].lo.shuffle_bytes(hlo);
1044        let hicand = masks[0].hi.shuffle_bytes(hhi);
1045        locand.and(hicand)
1046    }
1047
1048    /// Return a candidate for Teddy (fat or slim) that is searching for 2-byte
1049    /// candidates.
1050    ///
1051    /// If candidates are returned, each will be a collection of 8-bit bitsets
1052    /// (one bitset per lane), where the ith bit is set in the jth lane if and
1053    /// only if the byte occurring at the jth lane in `chunk` is in the bucket
1054    /// `i`. Each candidate returned corresponds to the first and second bytes
1055    /// of the patterns being searched. If no candidate is found, then all of
1056    /// the lanes will be set to zero in at least one of the vectors returned.
1057    ///
1058    /// `chunk` should correspond to a `V::BYTES` window of the haystack (where
1059    /// the least significant byte corresponds to the start of the window). For
1060    /// fat Teddy, the haystack window length should be `V::BYTES / 2`, with
1061    /// the window repeated in each half of the vector.
1062    ///
1063    /// The masks should correspond to the masks computed for the first and
1064    /// second bytes of all patterns that are being searched.
1065    #[inline(always)]
1066    unsafe fn members2(chunk: V, masks: [Mask<V>; 2]) -> (V, V) {
1067        let lomask = V::splat(0xF);
1068        let hlo = chunk.and(lomask);
1069        let hhi = chunk.shift_8bit_lane_right::<4>().and(lomask);
1070
1071        let locand1 = masks[0].lo.shuffle_bytes(hlo);
1072        let hicand1 = masks[0].hi.shuffle_bytes(hhi);
1073        let cand1 = locand1.and(hicand1);
1074
1075        let locand2 = masks[1].lo.shuffle_bytes(hlo);
1076        let hicand2 = masks[1].hi.shuffle_bytes(hhi);
1077        let cand2 = locand2.and(hicand2);
1078
1079        (cand1, cand2)
1080    }
1081
1082    /// Return a candidate for Teddy (fat or slim) that is searching for 3-byte
1083    /// candidates.
1084    ///
1085    /// If candidates are returned, each will be a collection of 8-bit bitsets
1086    /// (one bitset per lane), where the ith bit is set in the jth lane if and
1087    /// only if the byte occurring at the jth lane in `chunk` is in the bucket
1088    /// `i`. Each candidate returned corresponds to the first, second and third
1089    /// bytes of the patterns being searched. If no candidate is found, then
1090    /// all of the lanes will be set to zero in at least one of the vectors
1091    /// returned.
1092    ///
1093    /// `chunk` should correspond to a `V::BYTES` window of the haystack (where
1094    /// the least significant byte corresponds to the start of the window). For
1095    /// fat Teddy, the haystack window length should be `V::BYTES / 2`, with
1096    /// the window repeated in each half of the vector.
1097    ///
1098    /// The masks should correspond to the masks computed for the first, second
1099    /// and third bytes of all patterns that are being searched.
1100    #[inline(always)]
1101    unsafe fn members3(chunk: V, masks: [Mask<V>; 3]) -> (V, V, V) {
1102        let lomask = V::splat(0xF);
1103        let hlo = chunk.and(lomask);
1104        let hhi = chunk.shift_8bit_lane_right::<4>().and(lomask);
1105
1106        let locand1 = masks[0].lo.shuffle_bytes(hlo);
1107        let hicand1 = masks[0].hi.shuffle_bytes(hhi);
1108        let cand1 = locand1.and(hicand1);
1109
1110        let locand2 = masks[1].lo.shuffle_bytes(hlo);
1111        let hicand2 = masks[1].hi.shuffle_bytes(hhi);
1112        let cand2 = locand2.and(hicand2);
1113
1114        let locand3 = masks[2].lo.shuffle_bytes(hlo);
1115        let hicand3 = masks[2].hi.shuffle_bytes(hhi);
1116        let cand3 = locand3.and(hicand3);
1117
1118        (cand1, cand2, cand3)
1119    }
1120
1121    /// Return a candidate for Teddy (fat or slim) that is searching for 4-byte
1122    /// candidates.
1123    ///
1124    /// If candidates are returned, each will be a collection of 8-bit bitsets
1125    /// (one bitset per lane), where the ith bit is set in the jth lane if and
1126    /// only if the byte occurring at the jth lane in `chunk` is in the bucket
1127    /// `i`. Each candidate returned corresponds to the first, second, third
1128    /// and fourth bytes of the patterns being searched. If no candidate is
1129    /// found, then all of the lanes will be set to zero in at least one of the
1130    /// vectors returned.
1131    ///
1132    /// `chunk` should correspond to a `V::BYTES` window of the haystack (where
1133    /// the least significant byte corresponds to the start of the window). For
1134    /// fat Teddy, the haystack window length should be `V::BYTES / 2`, with
1135    /// the window repeated in each half of the vector.
1136    ///
1137    /// The masks should correspond to the masks computed for the first,
1138    /// second, third and fourth bytes of all patterns that are being searched.
1139    #[inline(always)]
1140    unsafe fn members4(chunk: V, masks: [Mask<V>; 4]) -> (V, V, V, V) {
1141        let lomask = V::splat(0xF);
1142        let hlo = chunk.and(lomask);
1143        let hhi = chunk.shift_8bit_lane_right::<4>().and(lomask);
1144
1145        let locand1 = masks[0].lo.shuffle_bytes(hlo);
1146        let hicand1 = masks[0].hi.shuffle_bytes(hhi);
1147        let cand1 = locand1.and(hicand1);
1148
1149        let locand2 = masks[1].lo.shuffle_bytes(hlo);
1150        let hicand2 = masks[1].hi.shuffle_bytes(hhi);
1151        let cand2 = locand2.and(hicand2);
1152
1153        let locand3 = masks[2].lo.shuffle_bytes(hlo);
1154        let hicand3 = masks[2].hi.shuffle_bytes(hhi);
1155        let cand3 = locand3.and(hicand3);
1156
1157        let locand4 = masks[3].lo.shuffle_bytes(hlo);
1158        let hicand4 = masks[3].hi.shuffle_bytes(hhi);
1159        let cand4 = locand4.and(hicand4);
1160
1161        (cand1, cand2, cand3, cand4)
1162    }
1163}
1164
/// Represents the low and high nybble masks that will be used during
/// search. Each mask is 32 bytes wide, although only the first 16 bytes are
/// used for 128-bit vectors.
///
/// Each byte in the mask corresponds to an 8-bit bitset, where bit `i` is set
/// if and only if the corresponding nybble is in the ith bucket. The index of
/// the byte (0-15, inclusive) corresponds to the nybble.
///
/// Each mask is used as the target of a shuffle, where the indices for the
/// shuffle are taken from the haystack. AND'ing the shuffles for both the
/// low and high masks together also results in 8-bit bitsets, but where bit
/// `i` is set if and only if the corresponding *byte* is in the ith bucket.
#[derive(Clone, Default)]
struct SlimMaskBuilder {
    /// Bitsets indexed by the low nybble of a pattern byte. Entries 16-31
    /// repeat entries 0-15.
    lo: [u8; 32],
    /// Bitsets indexed by the high nybble of a pattern byte. Entries 16-31
    /// repeat entries 0-15.
    hi: [u8; 32],
}
1182
1183impl SlimMaskBuilder {
1184    /// Update this mask by adding the given byte to the given bucket. The
1185    /// given bucket must be in the range 0-7.
1186    ///
1187    /// # Panics
1188    ///
1189    /// When `bucket >= 8`.
1190    fn add(&mut self, bucket: usize, byte: u8) {
1191        assert!(bucket < 8);
1192
1193        let bucket = u8::try_from(bucket).unwrap();
1194        let byte_lo = usize::from(byte & 0xF);
1195        let byte_hi = usize::from((byte >> 4) & 0xF);
1196        // When using 256-bit vectors, we need to set this bucket assignment in
1197        // the low and high 128-bit portions of the mask. This allows us to
1198        // process 32 bytes at a time. Namely, AVX2 shuffles operate on each
1199        // of the 128-bit lanes, rather than the full 256-bit vector at once.
1200        self.lo[byte_lo] |= 1 << bucket;
1201        self.lo[byte_lo + 16] |= 1 << bucket;
1202        self.hi[byte_hi] |= 1 << bucket;
1203        self.hi[byte_hi + 16] |= 1 << bucket;
1204    }
1205
1206    /// Turn this builder into a vector mask.
1207    ///
1208    /// # Panics
1209    ///
1210    /// When `V` represents a vector bigger than what `MaskBytes` can contain.
1211    ///
1212    /// # Safety
1213    ///
1214    /// Callers must ensure that this is okay to call in the current target for
1215    /// the current CPU.
1216    #[inline(always)]
1217    unsafe fn build<V: Vector>(&self) -> Mask<V> {
1218        assert!(V::BYTES <= self.lo.len());
1219        assert!(V::BYTES <= self.hi.len());
1220        Mask {
1221            lo: V::load_unaligned(self.lo[..].as_ptr()),
1222            hi: V::load_unaligned(self.hi[..].as_ptr()),
1223        }
1224    }
1225
1226    /// A convenience function for building `N` vector masks from a slim
1227    /// `Teddy` value.
1228    ///
1229    /// # Panics
1230    ///
1231    /// When `V` represents a vector bigger than what `MaskBytes` can contain.
1232    ///
1233    /// # Safety
1234    ///
1235    /// Callers must ensure that this is okay to call in the current target for
1236    /// the current CPU.
1237    #[inline(always)]
1238    unsafe fn from_teddy<const BYTES: usize, V: Vector>(
1239        teddy: &Teddy<8>,
1240    ) -> [Mask<V>; BYTES] {
1241        // MSRV(1.63): Use core::array::from_fn to just build the array here
1242        // instead of creating a vector and turning it into an array.
1243        let mut mask_builders = vec![SlimMaskBuilder::default(); BYTES];
1244        for (bucket_index, bucket) in teddy.buckets.iter().enumerate() {
1245            for pid in bucket.iter().copied() {
1246                let pat = teddy.patterns.get(pid);
1247                for (i, builder) in mask_builders.iter_mut().enumerate() {
1248                    builder.add(bucket_index, pat.bytes()[i]);
1249                }
1250            }
1251        }
1252        let array =
1253            <[SlimMaskBuilder; BYTES]>::try_from(mask_builders).unwrap();
1254        array.map(|builder| builder.build())
1255    }
1256}
1257
1258impl Debug for SlimMaskBuilder {
1259    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1260        let (mut parts_lo, mut parts_hi) = (vec![], vec![]);
1261        for i in 0..32 {
1262            parts_lo.push(format!("{:02}: {:08b}", i, self.lo[i]));
1263            parts_hi.push(format!("{:02}: {:08b}", i, self.hi[i]));
1264        }
1265        f.debug_struct("SlimMaskBuilder")
1266            .field("lo", &parts_lo)
1267            .field("hi", &parts_hi)
1268            .finish()
1269    }
1270}
1271
/// Represents the low and high nybble masks that will be used during "fat"
/// Teddy search.
///
/// Each mask is 32 bytes wide, and at the time of writing, only 256-bit vectors
/// support fat Teddy.
///
/// A fat Teddy mask is like a slim Teddy mask, except that instead of
/// repeating the bitsets in the high and low 128-bits in 256-bit vectors, the
/// high and low 128-bit halves each represent distinct buckets. (Bringing the
/// total to 16 instead of 8.) This permits spreading the patterns out a bit
/// more and thus putting less pressure on verification to be fast.
///
/// Each byte in the mask corresponds to an 8-bit bitset, where bit `i` is set
/// if and only if the corresponding nybble is in the ith bucket. The index of
/// the byte (0-15, inclusive) corresponds to the nybble.
#[derive(Clone, Copy, Default)]
struct FatMaskBuilder {
    /// Bitsets indexed by the low nybble of a pattern byte. Entries 0-15 are
    /// for buckets 0-7 and entries 16-31 are for buckets 8-15.
    lo: [u8; 32],
    /// Bitsets indexed by the high nybble of a pattern byte. Entries 0-15 are
    /// for buckets 0-7 and entries 16-31 are for buckets 8-15.
    hi: [u8; 32],
}
1292
1293impl FatMaskBuilder {
1294    /// Update this mask by adding the given byte to the given bucket. The
1295    /// given bucket must be in the range 0-15.
1296    ///
1297    /// # Panics
1298    ///
1299    /// When `bucket >= 16`.
1300    fn add(&mut self, bucket: usize, byte: u8) {
1301        assert!(bucket < 16);
1302
1303        let bucket = u8::try_from(bucket).unwrap();
1304        let byte_lo = usize::from(byte & 0xF);
1305        let byte_hi = usize::from((byte >> 4) & 0xF);
1306        // Unlike slim teddy, fat teddy only works with AVX2. For fat teddy,
1307        // the high 128 bits of our mask correspond to buckets 8-15, while the
1308        // low 128 bits correspond to buckets 0-7.
1309        if bucket < 8 {
1310            self.lo[byte_lo] |= 1 << bucket;
1311            self.hi[byte_hi] |= 1 << bucket;
1312        } else {
1313            self.lo[byte_lo + 16] |= 1 << (bucket % 8);
1314            self.hi[byte_hi + 16] |= 1 << (bucket % 8);
1315        }
1316    }
1317
1318    /// Turn this builder into a vector mask.
1319    ///
1320    /// # Panics
1321    ///
1322    /// When `V` represents a vector bigger than what `MaskBytes` can contain.
1323    ///
1324    /// # Safety
1325    ///
1326    /// Callers must ensure that this is okay to call in the current target for
1327    /// the current CPU.
1328    #[inline(always)]
1329    unsafe fn build<V: Vector>(&self) -> Mask<V> {
1330        assert!(V::BYTES <= self.lo.len());
1331        assert!(V::BYTES <= self.hi.len());
1332        Mask {
1333            lo: V::load_unaligned(self.lo[..].as_ptr()),
1334            hi: V::load_unaligned(self.hi[..].as_ptr()),
1335        }
1336    }
1337
1338    /// A convenience function for building `N` vector masks from a fat
1339    /// `Teddy` value.
1340    ///
1341    /// # Panics
1342    ///
1343    /// When `V` represents a vector bigger than what `MaskBytes` can contain.
1344    ///
1345    /// # Safety
1346    ///
1347    /// Callers must ensure that this is okay to call in the current target for
1348    /// the current CPU.
1349    #[inline(always)]
1350    unsafe fn from_teddy<const BYTES: usize, V: Vector>(
1351        teddy: &Teddy<16>,
1352    ) -> [Mask<V>; BYTES] {
1353        // MSRV(1.63): Use core::array::from_fn to just build the array here
1354        // instead of creating a vector and turning it into an array.
1355        let mut mask_builders = vec![FatMaskBuilder::default(); BYTES];
1356        for (bucket_index, bucket) in teddy.buckets.iter().enumerate() {
1357            for pid in bucket.iter().copied() {
1358                let pat = teddy.patterns.get(pid);
1359                for (i, builder) in mask_builders.iter_mut().enumerate() {
1360                    builder.add(bucket_index, pat.bytes()[i]);
1361                }
1362            }
1363        }
1364        let array =
1365            <[FatMaskBuilder; BYTES]>::try_from(mask_builders).unwrap();
1366        array.map(|builder| builder.build())
1367    }
1368}
1369
1370impl Debug for FatMaskBuilder {
1371    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1372        let (mut parts_lo, mut parts_hi) = (vec![], vec![]);
1373        for i in 0..32 {
1374            parts_lo.push(format!("{:02}: {:08b}", i, self.lo[i]));
1375            parts_hi.push(format!("{:02}: {:08b}", i, self.hi[i]));
1376        }
1377        f.debug_struct("FatMaskBuilder")
1378            .field("lo", &parts_lo)
1379            .field("hi", &parts_hi)
1380            .finish()
1381    }
1382}