// poly1305/backend/avx2.rs

1//! AVX2 implementation of the Poly1305 state machine.
2
3// The State struct and its logic was originally derived from Goll and Gueron's AVX2 C
4// code:
5//     [Vectorization of Poly1305 message authentication code](https://ieeexplore.ieee.org/document/7113463)
6//
7// which was sourced from Bhattacharyya and Sarkar's modified variant:
8//     [Improved SIMD Implementation of Poly1305](https://eprint.iacr.org/2019/842)
9//     https://github.com/Sreyosi/Improved-SIMD-Implementation-of-Poly1305
10//
11// The logic has been extensively rewritten and documented, and several bugs in the
12// original C code were fixed.
13//
14// Note that State only implements the original Goll-Gueron algorithm, not the
15// optimisations provided by Bhattacharyya and Sarkar. The latter require the message
16// length to be known, which is incompatible with the streaming API of UniversalHash.
17
18use universal_hash::{
19    consts::{U16, U4},
20    crypto_common::{BlockSizeUser, ParBlocksSizeUser},
21    generic_array::GenericArray,
22    UhfBackend,
23};
24
25use crate::{Block, Key, Tag};
26
27mod helpers;
28use self::helpers::*;
29
/// Four Poly1305 blocks (64 bytes), the unit this backend processes per
/// AVX2 batch.
type ParBlocks = universal_hash::ParBlocks<State>;
32
/// SIMD accumulator state that only exists once at least one batch of four
/// blocks has been ingested (see [`State::process_blocks`]).
#[derive(Copy, Clone)]
struct Initialized {
    /// The running four-lane polynomial accumulator P.
    p: Aligned4x130,
    /// Multiplier used to merge the four lanes down to a single value
    /// during finalization.
    m: SpacedMultiplier4x130,
    /// R^4, used to advance the accumulator by four blocks per iteration
    /// (P <-- R^4 * P + blocks).
    r4: PrecomputedMultiplier,
}
39
#[derive(Clone)]
pub(crate) struct State {
    /// Addition key, added to the evaluated polynomial (mod 2^128) to form
    /// the final tag.
    k: AdditionKey,
    /// Polynomial key R.
    r1: PrecomputedMultiplier,
    /// R^2, precomputed in [`State::new`].
    r2: PrecomputedMultiplier,
    /// Lazily-created SIMD accumulator; `None` until the first batch of four
    /// blocks has been processed.
    initialized: Option<Initialized>,
    /// Buffer of full blocks awaiting batch processing.
    cached_blocks: [Block; 4],
    /// Number of valid blocks in `cached_blocks`; always in `0..=3` between
    /// calls (the fourth block triggers processing and resets this to 0).
    num_cached_blocks: usize,
    /// An optional trailing partial block; always the final piece of the
    /// message, consumed in [`State::finalize`].
    partial_block: Option<Block>,
}
50
impl State {
    /// Initialize Poly1305 [`State`] with the given key.
    pub(crate) fn new(key: &Key) -> Self {
        // Prepare addition key and polynomial key.
        //
        // SAFETY: `prepare_keys` lives in the (not shown here) `helpers`
        // module; presumably it requires AVX2, which backend selection
        // upstream must guarantee before constructing this State —
        // TODO(review): confirm against `helpers`.
        let (k, r1) = unsafe { prepare_keys(key) };

        // Precompute R^2 (used for the two-block merge step in `finalize`
        // and for deriving the spaced multiplier / R^4).
        let r2 = (r1 * r1).reduce();

        State {
            k,
            r1,
            r2: r2.into(),
            initialized: None,
            cached_blocks: [Block::default(); 4],
            num_cached_blocks: 0,
            partial_block: None,
        }
    }

    /// Process four Poly1305 blocks at once.
    ///
    /// # Safety
    ///
    /// The AVX2 CPU feature must be available.
    #[target_feature(enable = "avx2")]
    pub(crate) unsafe fn compute_par_blocks(&mut self, blocks: &ParBlocks) {
        // This fast path is only valid when we are block-aligned: any
        // buffered full or partial block would otherwise be reordered
        // after these four blocks.
        assert!(self.partial_block.is_none());
        assert_eq!(self.num_cached_blocks, 0);

        self.process_blocks(Aligned4x130::from_par_blocks(blocks));
    }

    /// Compute a Poly1305 block.
    ///
    /// Full blocks are buffered and folded into the accumulator four at a
    /// time; a partial block is stashed for [`State::finalize`].
    ///
    /// # Safety
    ///
    /// The AVX2 CPU feature must be available.
    #[target_feature(enable = "avx2")]
    pub(crate) unsafe fn compute_block(&mut self, block: &Block, partial: bool) {
        // We can cache a single partial block. It is necessarily the last
        // block of the message, so it is held until finalization.
        if partial {
            assert!(self.partial_block.is_none());
            self.partial_block = Some(*block);
            return;
        }

        // Buffer full blocks until four have accumulated, then process the
        // whole batch. Note the counter resets to 0 *before* processing, so
        // `num_cached_blocks` is always in 0..=3 between calls.
        self.cached_blocks[self.num_cached_blocks].copy_from_slice(block);
        if self.num_cached_blocks < 3 {
            self.num_cached_blocks += 1;
            return;
        } else {
            self.num_cached_blocks = 0;
        }

        self.process_blocks(Aligned4x130::from_blocks(&self.cached_blocks));
    }

    /// Fold a batch of four Poly1305 blocks into the running polynomial,
    /// initializing the SIMD accumulator state on the first batch.
    ///
    /// # Safety
    ///
    /// The AVX2 CPU feature must be available.
    #[target_feature(enable = "avx2")]
    unsafe fn process_blocks(&mut self, blocks: Aligned4x130) {
        if let Some(inner) = &mut self.initialized {
            // P <-- R^4 * P + blocks
            inner.p = (&inner.p * inner.r4).reduce() + blocks;
        } else {
            // First batch: initialize the polynomial directly from the blocks.
            let p = blocks;

            // Initialize the multiplier (used to merge down the polynomial during
            // finalization) along with R^4 for the per-batch update above.
            let (m, r4) = SpacedMultiplier4x130::new(self.r1, self.r2);

            self.initialized = Some(Initialized { p, m, r4 })
        }
    }

    /// Finalize output producing a [`Tag`].
    ///
    /// Merges the four accumulator lanes into one value, folds in any
    /// remaining buffered data (0–3 full blocks plus an optional partial
    /// block), then adds the addition key mod 2^128.
    ///
    /// # Safety
    ///
    /// The AVX2 CPU feature must be available.
    #[target_feature(enable = "avx2")]
    pub(crate) unsafe fn finalize(&mut self) -> Tag {
        // `compute_block` processes on the fourth buffered block, so at most
        // three full blocks can remain here.
        assert!(self.num_cached_blocks < 4);
        let mut data = &self.cached_blocks[..];

        // T ← R◦T
        // P = T_0 + T_1 + T_2 + T_3
        // (Merge the four lanes into a single 130-bit value; `None` if fewer
        // than four blocks were ever processed.)
        let mut p = self
            .initialized
            .take()
            .map(|inner| (inner.p * inner.m).sum().reduce());

        if self.num_cached_blocks >= 2 {
            // Compute 32 byte block (remaining data < 64 bytes)
            let mut c = Aligned2x130::from_blocks(data[..2].try_into().unwrap());
            if let Some(p) = p {
                c = c + p;
            }
            p = Some(c.mul_and_sum(self.r1, self.r2).reduce());
            data = &data[2..];
            self.num_cached_blocks -= 2;
        }

        if self.num_cached_blocks == 1 {
            // Compute 16 byte block (remaining data < 32 bytes)
            let mut c = Aligned130::from_block(&data[0]);
            if let Some(p) = p {
                c = c + p;
            }
            p = Some((c * self.r1).reduce());
            self.num_cached_blocks -= 1;
        }

        if let Some(block) = &self.partial_block {
            // Compute last block (remaining data < 16 bytes)
            let mut c = Aligned130::from_partial_block(block);
            if let Some(p) = p {
                c = c + p;
            }
            p = Some((c * self.r1).reduce());
        }

        // Compute tag: p + k mod 2^128. An empty message (p == None) yields
        // the addition key alone.
        let mut tag = GenericArray::<u8, _>::default();
        let tag_int = if let Some(p) = p {
            self.k + p
        } else {
            self.k.into()
        };
        tag_int.write(tag.as_mut_slice());

        tag
    }
}
174
// Poly1305 operates on 16-byte message blocks.
impl BlockSizeUser for State {
    type BlockSize = U16;
}
178
// The AVX2 backend consumes four blocks (64 bytes) per parallel batch.
impl ParBlocksSizeUser for State {
    type ParBlocksSize = U4;
}
182
183impl UhfBackend for State {
184    fn proc_block(&mut self, block: &Block) {
185        unsafe { self.compute_block(block, false) };
186    }
187
188    fn proc_par_blocks(&mut self, blocks: &ParBlocks) {
189        if self.num_cached_blocks == 0 {
190            // Fast path.
191            unsafe { self.compute_par_blocks(blocks) };
192        } else {
193            // We are unaligned; use the slow fallback.
194            for block in blocks {
195                self.proc_block(block);
196            }
197        }
198    }
199
200    fn blocks_needed_to_align(&self) -> usize {
201        if self.num_cached_blocks == 0 {
202            // There are no cached blocks; fast path is available.
203            0
204        } else {
205            // There are cached blocks; report how many more we need.
206            self.cached_blocks.len() - self.num_cached_blocks
207        }
208    }
209}