1use bytes::{BufMut, BytesMut};
2use postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
3use std::convert::TryInto;
4
5pub use bigdecimal::BigDecimal;
6pub use num::{BigInt, BigUint, Integer};
7#[cfg(feature = "serde")]
8use std::str::FromStr;
9
10#[cfg(feature = "serde")]
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12
13#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Clone)]
19pub struct PgNumeric {
20 pub n: Option<BigDecimal>,
21}
22
23impl PgNumeric {
24 pub fn new(n: Option<BigDecimal>) -> Self {
27 Self { n }
28 }
29
30 pub fn is_nan(&self) -> bool {
33 self.n.is_none()
34 }
35}
36
37use byteorder::{BigEndian, ReadBytesExt};
38use std::io::Cursor;
39
40impl<'a> FromSql<'a> for PgNumeric {
41 fn from_sql(
42 _: &Type,
43 raw: &'a [u8],
44 ) -> Result<Self, Box<dyn std::error::Error + 'static + Sync + Send>> {
45 let mut rdr = Cursor::new(raw);
46
47 let n_digits = rdr.read_u16::<BigEndian>()?;
48 let weight = rdr.read_i16::<BigEndian>()?;
49 let sign = match rdr.read_u16::<BigEndian>()? {
50 0x4000 => num::bigint::Sign::Minus,
51 0x0000 => num::bigint::Sign::Plus,
52 0xC000 => return Ok(Self { n: None }),
53 _ => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "").into()),
54 };
55 let scale = rdr.read_u16::<BigEndian>()?;
56
57 let mut biguint = BigUint::from(0u32);
58 for n in (0..n_digits).rev() {
59 let digit = rdr.read_u16::<BigEndian>()?;
60 biguint += BigUint::from(digit) * BigUint::from(10_000u32).pow(n as u32);
61 }
62
63 let correction_exp = 4 * (i64::from(weight) - i64::from(n_digits) + 1);
70 let res = BigDecimal::new(BigInt::from_biguint(sign, biguint), -correction_exp)
71 .with_scale(i64::from(scale));
72
73 Ok(Self { n: Some(res) })
74 }
75
76 fn accepts(ty: &Type) -> bool {
77 matches!(*ty, Type::NUMERIC)
78 }
79}
80
81impl ToSql for PgNumeric {
82 fn to_sql(
83 &self,
84 _: &Type,
85 out: &mut BytesMut,
86 ) -> Result<IsNull, Box<dyn std::error::Error + 'static + Sync + Send>> {
87 fn write_header(out: &mut BytesMut, n_digits: u16, weight: i16, sign: u16, scale: u16) {
88 out.put_u16(n_digits);
89 out.put_i16(weight);
90 out.put_u16(sign);
91 out.put_u16(scale);
92 }
93 fn write_body(out: &mut BytesMut, digits: &[i16]) {
94 for digit in digits {
96 out.put_i16(*digit);
97 }
98 }
99 fn write_nan(out: &mut BytesMut) {
100 out.reserve(8);
102 write_header(out, 0, 0, 0xC000, 0);
103 }
105
106 match &self.n {
107 None => {
108 write_nan(out);
109 Ok(IsNull::No)
110 }
111 Some(n) => {
112 let (bigint, exponent) = n.as_bigint_and_exponent();
113 let (sign, biguint) = bigint.into_parts();
114 let neg = sign == num::bigint::Sign::Minus;
115 let scale: i16 = exponent.try_into()?;
116
117 let (integer, decimal) = biguint.div_rem(&BigUint::from(10u32).pow(scale as u32));
118 let integer_digits: Vec<i16> = base10000(integer)?;
119 let mut weight = integer_digits.len().try_into().map(|len: i16| len - 1)?;
120
121 let decimal =
126 decimal * BigUint::from(10_u32).pow((4 - ((scale - 1) % 4 + 1)) as u32);
127 let decimal_digits: Vec<i16> = base10000(decimal)?;
128
129 let have_decimals_weight: i16 = decimal_digits.len().try_into()?;
130 let want_decimals_weight = 1 + (scale - 1) / 4;
133 let correction_weight = want_decimals_weight - have_decimals_weight;
134 let mut decimal_zeroes_prefix: Vec<i16> = vec![];
135 if integer_digits.is_empty() {
136 weight -= correction_weight;
138 } else {
139 decimal_zeroes_prefix = std::iter::repeat(0_i16)
142 .take(correction_weight.try_into()?)
143 .collect();
144 }
145
146 let mut digits: Vec<i16> = vec![];
147 digits.extend(integer_digits);
148 digits.extend(decimal_zeroes_prefix);
149 digits.extend(decimal_digits);
150 strip_trailing_zeroes(&mut digits);
151 let n_digits = digits.len();
152
153 out.reserve(8 + n_digits * 2);
156
157 write_header(
158 out,
159 n_digits.try_into()?,
160 weight,
161 if neg { 0x4000 } else { 0x0000 },
162 scale.try_into()?,
163 );
164
165 write_body(out, &digits);
166
167 Ok(IsNull::No)
168 }
169 }
170 }
171
172 fn accepts(ty: &Type) -> bool {
173 matches!(*ty, Type::NUMERIC)
174 }
175
176 to_sql_checked!();
177}
178
179fn base10000(
180 mut n: BigUint,
181) -> Result<Vec<i16>, Box<dyn std::error::Error + 'static + Sync + Send>> {
182 let mut res: Vec<i16> = vec![];
183
184 while n != BigUint::from(0_u32) {
185 let (remainder, digit) = n.div_rem(&BigUint::from(10_000u32));
186 res.push(digit.try_into()?);
187 n = remainder;
188 }
189
190 res.reverse();
191 Ok(res)
192}
193
/// Removes trailing zero digits in place; a vector of only zeros becomes empty.
fn strip_trailing_zeroes(digits: &mut Vec<i16>) {
    while let Some(&0) = digits.last() {
        digits.pop();
    }
}
204
/// Table-driven check that only trailing zeros are removed (interior and leading
/// zeros survive, an all-zero input becomes empty).
#[test]
fn strip_trailing_zeroes_tests() {
    let cases: Vec<(Vec<i16>, Vec<i16>)> = vec![
        (vec![], vec![]),
        (vec![10, 5, 105], vec![10, 5, 105]),
        (vec![10, 5, 105, 0, 0, 0], vec![10, 5, 105]),
        (
            vec![0, 10, 0, 0, 5, 0, 105, 0, 0, 0],
            vec![0, 10, 0, 0, 5, 0, 105],
        ),
        (vec![0], vec![]),
    ];

    for (input, expected) in cases {
        let mut digits = input.clone();
        strip_trailing_zeroes(&mut digits);
        assert_eq!(expected, digits, "input: {:?}", input);
    }
}
240
/// Table-driven check of `base10000` on decimal inputs, including zero, the
/// 9999/10000 digit boundary and values with interior zero digits.
#[test]
fn base10000_tests() {
    let cases: Vec<(&str, Vec<i16>)> = vec![
        ("0", vec![]),
        ("1", vec![1]),
        ("10", vec![10]),
        ("100", vec![100]),
        ("1000", vec![1000]),
        ("9999", vec![9999]),
        ("10000", vec![1, 0]),
        ("100000000", vec![1, 0, 0]),
        ("900087000", vec![9, 8, 7000]),
    ];

    for (decimal, expected) in cases {
        let input = BigUint::parse_bytes(decimal.as_bytes(), 10).unwrap();
        assert_eq!(expected, base10000(input).unwrap(), "input: {}", decimal);
    }
}
290
/// Round-trips a list of values through a `numeric` column of a live PostgreSQL
/// server, checking both the binary decoding and the server's text rendering.
///
/// NOTE(review): this connects to `localhost:15432` (user/password/db `test`) and
/// panics via `unwrap()` when no such server is reachable. It is not `#[ignore]`d,
/// so a plain `cargo test` run without that database fails — confirm this is intended.
#[test]
fn integration_tests() {
    use postgres::{Client, NoTls};
    use std::str::FromStr;

    let mut dbconn = Client::connect(
        "host=localhost port=15432 user=test password=test dbname=test",
        NoTls,
    )
    .unwrap();

    dbconn
        .execute("CREATE TABLE IF NOT EXISTS foobar (n numeric)", &[])
        .unwrap();

    // Makes `pgnumeric` the only row of `foobar`, then reads it back twice.
    let mut test_for_pgnumeric = |pgnumeric| {
        dbconn.execute("DELETE FROM foobar;", &[]).unwrap();
        dbconn
            .execute("INSERT INTO foobar VALUES ($1)", &[&pgnumeric])
            .unwrap();

        // Binary round trip: ToSql on the way in, FromSql on the way out.
        let got: PgNumeric = dbconn
            .query_one("SELECT n FROM foobar", &[])
            .unwrap()
            .get::<usize, Option<PgNumeric>>(0)
            .unwrap();
        assert_eq!(pgnumeric, got);

        // Text round trip: bypasses FromSql. This checks the value ToSql stored,
        // as the server itself renders it.
        let got_as_str: String = dbconn
            .query_one("SELECT n::text FROM foobar", &[])
            .unwrap()
            .get::<usize, Option<String>>(0)
            .unwrap();
        let got = match got_as_str.as_str() {
            "NaN" => PgNumeric { n: None },
            s => PgNumeric {
                n: Some(BigDecimal::from_str(s).unwrap()),
            },
        };
        assert_eq!(pgnumeric, got);
    };

    // Decimal strings to round-trip; "nan" stands for PgNumeric's NaN (`n: None`).
    let tests = &[
        "10",
        "100",
        "1000",
        "10000",
        "10100",
        "30109",
        "0.1",
        "0.01",
        "0.001",
        "0.0001",
        "0.00001",
        "0.0000001",
        "1.1",
        "1.001",
        "1.00001",
        "3.14159265",
        "98756756756756756756756757657657656756756756756757656745644534534535435434567567656756757658787687676855674456345345364564.5675675675765765765765765756",
        "204093200000000000000000000000000000000",
        "nan"
    ];
    // First pass: every value as written.
    for n in tests {
        let n = match n {
            &"nan" => PgNumeric { n: None },
            _ => PgNumeric {
                n: Some(BigDecimal::from_str(n).unwrap()),
            },
        };

        test_for_pgnumeric(n);
    }

    // Second pass: the negation of every non-NaN value.
    for n in tests {
        if n == &"nan" {
            continue;
        }

        let n = PgNumeric {
            n: Some(BigDecimal::from_str(n).unwrap() * BigDecimal::from(-1)),
        };
        test_for_pgnumeric(n);
    }
}
376
#[cfg(feature = "serde")]
impl Serialize for PgNumeric {
    /// Serializes as an optional string. `NaN` becomes the format's "none" value.
    /// A number becomes "some" of its decimal string form.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        if let Some(value) = self.n.as_ref() {
            let text = value.to_string();
            serializer.serialize_some(&text.as_str())
        } else {
            serializer.serialize_none()
        }
    }
}
389
#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for PgNumeric {
    /// Deserializes the optional-string form written by `Serialize`.
    ///
    /// "none" (or unit) yields `NaN` (`n: None`). A string is parsed as a `BigDecimal`.
    ///
    /// # Errors
    /// Returns the deserializer's error, rather than panicking, when the string is
    /// not a valid decimal or the input is not an optional string.
    fn deserialize<D>(deserializer: D) -> Result<PgNumeric, D::Error>
    where
        D: Deserializer<'a>,
    {
        struct BigDecimalVisitor {}
        impl<'de> serde::de::Visitor<'de> for BigDecimalVisitor {
            type Value = Option<BigDecimal>;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                write!(formatter, "a string that is parseable as a bigdecimal")
            }

            fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // Input may be untrusted: report malformed decimals as an error
                // instead of panicking.
                BigDecimal::from_str(s)
                    .map(Some)
                    .map_err(|err| E::custom(format!("invalid bigdecimal {:?}: {}", s, err)))
            }

            fn visit_some<D>(self, d: D) -> Result<Self::Value, D::Error>
            where
                D: Deserializer<'de>,
            {
                d.deserialize_str(BigDecimalVisitor {})
            }

            fn visit_none<E>(self) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(None)
            }

            // Some formats report an absent optional as unit rather than none.
            fn visit_unit<E>(self) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(None)
            }
        }

        let n = deserializer.deserialize_option(BigDecimalVisitor {})?;
        Ok(PgNumeric { n })
    }
}