pg_bigdecimal/
lib.rs

1use bytes::{BufMut, BytesMut};
2use postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
3use std::convert::TryInto;
4
5pub use bigdecimal::BigDecimal;
6pub use num::{BigInt, BigUint, Integer};
7#[cfg(feature = "serde")]
8use std::str::FromStr;
9
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12
13/// A rust variant of the Postgres Numeric type. The full spectrum of Postgres'
14/// Numeric value range is supported.
15///
16/// Represented as an Optional BigDecimal. None for 'NaN', Some(bigdecimal) for
17/// all other values.
18#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Clone)]
19pub struct PgNumeric {
20    pub n: Option<BigDecimal>,
21}
22
23impl PgNumeric {
24    /// Construct a new PgNumeric value from an optional BigDecimal
25    /// (None for NaN values).
26    pub fn new(n: Option<BigDecimal>) -> Self {
27        Self { n }
28    }
29
30    /// Returns true if this PgNumeric value represents a NaN value.
31    /// Otherwise returns false.
32    pub fn is_nan(&self) -> bool {
33        self.n.is_none()
34    }
35}
36
37use byteorder::{BigEndian, ReadBytesExt};
38use std::io::Cursor;
39
40impl<'a> FromSql<'a> for PgNumeric {
41    fn from_sql(
42        _: &Type,
43        raw: &'a [u8],
44    ) -> Result<Self, Box<dyn std::error::Error + 'static + Sync + Send>> {
45        let mut rdr = Cursor::new(raw);
46
47        let n_digits = rdr.read_u16::<BigEndian>()?;
48        let weight = rdr.read_i16::<BigEndian>()?;
49        let sign = match rdr.read_u16::<BigEndian>()? {
50            0x4000 => num::bigint::Sign::Minus,
51            0x0000 => num::bigint::Sign::Plus,
52            0xC000 => return Ok(Self { n: None }),
53            _ => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "").into()),
54        };
55        let scale = rdr.read_u16::<BigEndian>()?;
56
57        let mut biguint = BigUint::from(0u32);
58        for n in (0..n_digits).rev() {
59            let digit = rdr.read_u16::<BigEndian>()?;
60            biguint += BigUint::from(digit) * BigUint::from(10_000u32).pow(n as u32);
61        }
62
63        // First digit in unsigned now has factor 10_000^(digits.len() - 1),
64        // but should have 10_000^weight
65        //
66        // Credits: this logic has been copied from rust Diesel's related code
67        // that provides the same translation from Postgres numeric into their
68        // related rust type.
69        let correction_exp = 4 * (i64::from(weight) - i64::from(n_digits) + 1);
70        let res = BigDecimal::new(BigInt::from_biguint(sign, biguint), -correction_exp)
71            .with_scale(i64::from(scale));
72
73        Ok(Self { n: Some(res) })
74    }
75
76    fn accepts(ty: &Type) -> bool {
77        matches!(*ty, Type::NUMERIC)
78    }
79}
80
81impl ToSql for PgNumeric {
82    fn to_sql(
83        &self,
84        _: &Type,
85        out: &mut BytesMut,
86    ) -> Result<IsNull, Box<dyn std::error::Error + 'static + Sync + Send>> {
87        fn write_header(out: &mut BytesMut, n_digits: u16, weight: i16, sign: u16, scale: u16) {
88            out.put_u16(n_digits);
89            out.put_i16(weight);
90            out.put_u16(sign);
91            out.put_u16(scale);
92        }
93        fn write_body(out: &mut BytesMut, digits: &[i16]) {
94            // write the body
95            for digit in digits {
96                out.put_i16(*digit);
97            }
98        }
99        fn write_nan(out: &mut BytesMut) {
100            // 8 bytes for the header (4 * 2byte numbers)
101            out.reserve(8);
102            write_header(out, 0, 0, 0xC000, 0);
103            // no body for nan
104        }
105
106        match &self.n {
107            None => {
108                write_nan(out);
109                Ok(IsNull::No)
110            }
111            Some(n) => {
112                let (bigint, exponent) = n.as_bigint_and_exponent();
113                let (sign, biguint) = bigint.into_parts();
114                let neg = sign == num::bigint::Sign::Minus;
115                let scale: i16 = exponent.try_into()?;
116
117                let (integer, decimal) = biguint.div_rem(&BigUint::from(10u32).pow(scale as u32));
118                let integer_digits: Vec<i16> = base10000(integer)?;
119                let mut weight = integer_digits.len().try_into().map(|len: i16| len - 1)?;
120
121                // must shift decimal part to align the decimal point between
122                // two 10000 based digits.
123                // note: shifted modulo by 1
124                //       (resulting in 1..4 instead of 0..3 ranges)
125                let decimal =
126                    decimal * BigUint::from(10_u32).pow((4 - ((scale - 1) % 4 + 1)) as u32);
127                let decimal_digits: Vec<i16> = base10000(decimal)?;
128
129                let have_decimals_weight: i16 = decimal_digits.len().try_into()?;
130                // the /4 is shifted by -1 to shift increments to
131                // <multiples of 4 + 1>
132                let want_decimals_weight = 1 + (scale - 1) / 4;
133                let correction_weight = want_decimals_weight - have_decimals_weight;
134                let mut decimal_zeroes_prefix: Vec<i16> = vec![];
135                if integer_digits.is_empty() {
136                    // if we have no integer part, can simply set weight to -
137                    weight -= correction_weight;
138                } else {
139                    // if we do have an integer part, cannot save space.
140                    //  we'll have to prefix the decimal part with 0 digits
141                    decimal_zeroes_prefix = std::iter::repeat(0_i16)
142                        .take(correction_weight.try_into()?)
143                        .collect();
144                }
145
146                let mut digits: Vec<i16> = vec![];
147                digits.extend(integer_digits);
148                digits.extend(decimal_zeroes_prefix);
149                digits.extend(decimal_digits);
150                strip_trailing_zeroes(&mut digits);
151                let n_digits = digits.len();
152
153                // 8 bytes for the header (4 * 2byte numbers)
154                // + 2 bytes per digit
155                out.reserve(8 + n_digits * 2);
156
157                write_header(
158                    out,
159                    n_digits.try_into()?,
160                    weight,
161                    if neg { 0x4000 } else { 0x0000 },
162                    scale.try_into()?,
163                );
164
165                write_body(out, &digits);
166
167                Ok(IsNull::No)
168            }
169        }
170    }
171
172    fn accepts(ty: &Type) -> bool {
173        matches!(*ty, Type::NUMERIC)
174    }
175
176    to_sql_checked!();
177}
178
179fn base10000(
180    mut n: BigUint,
181) -> Result<Vec<i16>, Box<dyn std::error::Error + 'static + Sync + Send>> {
182    let mut res: Vec<i16> = vec![];
183
184    while n != BigUint::from(0_u32) {
185        let (remainder, digit) = n.div_rem(&BigUint::from(10_000u32));
186        res.push(digit.try_into()?);
187        n = remainder;
188    }
189
190    res.reverse();
191    Ok(res)
192}
193
/// Removes trailing zero digits in place.
///
/// Interior and leading zeroes are kept. An empty or all-zero vector ends up
/// empty.
fn strip_trailing_zeroes(digits: &mut Vec<i16>) {
    let keep = digits
        .iter()
        .rposition(|&d| d != 0)
        .map_or(0, |last_nonzero| last_nonzero + 1);
    digits.truncate(keep);
}
204
#[test]
fn strip_trailing_zeroes_tests() {
    // (input, expected output after stripping)
    let cases: Vec<(Vec<i16>, Vec<i16>)> = vec![
        (vec![], vec![]),
        (vec![10, 5, 105], vec![10, 5, 105]),
        (vec![10, 5, 105, 0, 0, 0], vec![10, 5, 105]),
        (
            vec![0, 10, 0, 0, 5, 0, 105, 0, 0, 0],
            vec![0, 10, 0, 0, 5, 0, 105],
        ),
        (vec![0], vec![]),
    ];

    for (mut digits, expected) in cases {
        strip_trailing_zeroes(&mut digits);
        assert_eq!(expected, digits);
    }
}
240
#[test]
fn base10000_tests() {
    // (decimal input, expected base-10000 digits, most significant first)
    let cases: Vec<(&str, Vec<i16>)> = vec![
        ("0", vec![]),
        ("1", vec![1]),
        ("10", vec![10]),
        ("100", vec![100]),
        ("1000", vec![1000]),
        ("9999", vec![9999]),
        ("10000", vec![1, 0]),
        ("100000000", vec![1, 0, 0]),
        ("900087000", vec![9, 8, 7000]),
    ];

    for (decimal, expected) in cases {
        let input = BigUint::parse_bytes(decimal.as_bytes(), 10).unwrap();
        assert_eq!(expected, base10000(input).unwrap());
    }
}
290
#[test]
fn integration_tests() {
    use postgres::{Client, NoTls};
    use std::str::FromStr;

    /// Stores `expected` in the scratch table, then reads it back twice. The
    /// first read decodes through `PgNumeric`'s `FromSql`. The second parses
    /// Postgres' own text rendering. Both reads must equal `expected`.
    fn assert_roundtrip(client: &mut Client, expected: PgNumeric) {
        client.execute("DELETE FROM foobar;", &[]).unwrap();
        client
            .execute("INSERT INTO foobar VALUES ($1)", &[&expected])
            .unwrap();

        let decoded: PgNumeric = client
            .query_one("SELECT n FROM foobar", &[])
            .unwrap()
            .get::<usize, Option<PgNumeric>>(0)
            .unwrap();
        assert_eq!(expected, decoded);

        let text: String = client
            .query_one("SELECT n::text FROM foobar", &[])
            .unwrap()
            .get::<usize, Option<String>>(0)
            .unwrap();
        let from_text = if text.as_str() == "NaN" {
            PgNumeric { n: None }
        } else {
            PgNumeric {
                n: Some(BigDecimal::from_str(&text).unwrap()),
            }
        };
        assert_eq!(expected, from_text);
    }

    let mut client = Client::connect(
        "host=localhost port=15432 user=test password=test dbname=test",
        NoTls,
    )
    .unwrap();

    client
        .execute("CREATE TABLE IF NOT EXISTS foobar (n numeric)", &[])
        .unwrap();

    let literals = [
        "10",
        "100",
        "1000",
        "10000",
        "10100",
        "30109",
        "0.1",
        "0.01",
        "0.001",
        "0.0001",
        "0.00001",
        "0.0000001",
        "1.1",
        "1.001",
        "1.00001",
        "3.14159265",
        "98756756756756756756756757657657656756756756756757656745644534534535435434567567656756757658787687676855674456345345364564.5675675675765765765765765756",
        "204093200000000000000000000000000000000",
        "nan",
    ];

    // Every literal as-is ("nan" maps to NaN).
    for &literal in &literals {
        let value = if literal == "nan" {
            PgNumeric { n: None }
        } else {
            PgNumeric {
                n: Some(BigDecimal::from_str(literal).unwrap()),
            }
        };
        assert_roundtrip(&mut client, value);
    }

    // Every numeric literal negated (NaN has no sign to flip).
    for &literal in literals.iter().filter(|s| **s != "nan") {
        let negated = BigDecimal::from_str(literal).unwrap() * BigDecimal::from(-1);
        assert_roundtrip(&mut client, PgNumeric { n: Some(negated) });
    }
}
376
#[cfg(feature = "serde")]
impl Serialize for PgNumeric {
    /// Serializes `NaN` as an absent optional (`none`). Every other value is
    /// serialized as `some(string)`, using `BigDecimal`'s `Display` output so
    /// no precision is lost.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(value) = &self.n {
            let text = value.to_string();
            serializer.serialize_some(text.as_str())
        } else {
            serializer.serialize_none()
        }
    }
}
389
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for PgNumeric {
    /// Deserializes an optional string: `none` becomes `NaN` and `some(string)`
    /// is parsed as a `BigDecimal`.
    ///
    /// # Errors
    ///
    /// Returns a deserialization error, rather than panicking, when the string
    /// is not a valid decimal number or the input is not an optional string.
    fn deserialize<D>(deserializer: D) -> Result<PgNumeric, D::Error>
    where
        D: Deserializer<'a>,
    {
        struct BigDecimalVisitor {}
        impl<'de> serde::de::Visitor<'de> for BigDecimalVisitor {
            type Value = Option<BigDecimal>;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                write!(formatter, "a string that is parseable as a bigdecimal",)
            }

            fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // The input may be untrusted: report parse failures through
                // serde instead of panicking.
                BigDecimal::from_str(s)
                    .map(Some)
                    .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(s), &self))
            }

            fn visit_some<D>(self, d: D) -> Result<Self::Value, D::Error>
            where
                D: Deserializer<'de>,
            {
                d.deserialize_str(BigDecimalVisitor {})
            }

            fn visit_none<E>(self) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(None)
            }
        }

        let n = deserializer.deserialize_option(BigDecimalVisitor {})?;
        Ok(PgNumeric { n })
    }
}
429}