diff --git a/Cargo.lock b/Cargo.lock index 6a882762a..e6c71af15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -40,6 +40,15 @@ dependencies = [ "threadpool", ] +[[package]] +name = "addr2line" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a49806b9dadc843c61e7c97e72490ad7f7220ae249012fbda9ad0609457c0543" +dependencies = [ + "gimli", +] + [[package]] name = "ahash" version = "0.3.5" @@ -66,15 +75,15 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.28" +version = "1.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9a60d744a80c30fcb657dfe2c1b22bcb3e814c1a1e3674f32bf5820b570fbff" +checksum = "85bb70cc08ec97ca5450e6eba421deeea5f172c0fc61f78b5357b2a8e8be195f" [[package]] name = "arc-swap" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d663a8e9a99154b5fb793032533f6328da35e23aac63d5c152279aa8ba356825" +checksum = "b585a98a234c46fc563103e9278c9391fde1f4e6850334da895d27edb9580f62" [[package]] name = "arrayref" @@ -164,9 +173,9 @@ checksum = "c17772156ef2829aadc587461c7753af20b7e8db1529bc66855add962a3b35d3" [[package]] name = "async-trait" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da71fef07bc806586090247e971229289f64c210a278ee5ae419314eb386b31d" +checksum = "26c4f3195085c36ea8d24d32b2f828d23296a9370a28aa39d111f6f16bef9f3b" dependencies = [ "proc-macro2", "quote", @@ -201,26 +210,17 @@ checksum = "f8aac770f1885fd7e387acedd76065302551364496e46b3dd00860b2f8359b9d" [[package]] name = "backtrace" -version = "0.3.46" +version = "0.3.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1e692897359247cc6bb902933361652380af0f1b7651ae5c5013407f30e109e" +checksum = "0df2f85c8a2abbe3b7d7e748052fdd9b76a0458fdeb16ad4223f5eca78c7c130" dependencies = [ - "backtrace-sys", + "addr2line", "cfg-if", "libc", + "object", "rustc-demangle", ] -[[package]] -name = "backtrace-sys" -version = "0.1.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7de8aba10a69c8e8d7622c5710229485ec32e9d55fdad160ea559c086fdcd118" -dependencies = [ - "cc", - "libc", -] - [[package]] name = "base-x" version = "0.2.6" @@ -244,15 +244,15 @@ checksum = "b41b7ea54a0c9d92199de89e20e58d49f02f8e699814ef3fdf266f6f748d15c7" [[package]] name = "base64" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d5ca2cd0adc3f48f9e9ea5a6bbdf9ccc0bfade884847e484d452414c7ccffb3" +checksum = "53d1ccbaf7d9ec9537465a97bf19edc1a4e158ecb49fc16178202238c569cc42" [[package]] name = "bigdecimal" -version = "0.1.0" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "460825c9e21708024d67c07057cd5560e5acdccac85de0de624a81d3de51bacb" +checksum = "1374191e2dd25f9ae02e3aa95041ed5d747fc77b3c102b49fe2dd9a8117a6244" dependencies = [ "num-bigint", "num-integer", @@ -299,9 +299,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.2.1" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12ae9db68ad7fac5fe51304d20f016c911539251075a214f8e663babefa35187" +checksum = "5356f1d23ee24a1f785a56d1d1a5f0fd5b0f6a0c0fb2412ce11da71649ab78f6" [[package]] name = "byte-tools" @@ -346,9 +346,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.50" +version = "1.0.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95e28fa049fda1c330bcf9d723be7663a899c4679724b34c81e9f5a326aab8cd" +checksum = "7bbb73db36c1246e9034e307d0fba23f9a2e251faa47ade70c1bd252220c8311" [[package]] name = "cfg-if" @@ -365,14 +365,14 @@ dependencies = [ "num-integer", "num-traits", "serde", - "time 0.1.42", + "time 0.1.43", ] [[package]] name = "clap" -version = "2.33.0" +version = "2.33.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5067f5bb2d80ef5d68b4c87db81601f0b75bca627bc2ef76b141d7b846a3c6d9" +checksum = "bdfa80d47f954d53a35a64987ca1422f495b8d6483c0fe9f7117b36c2a792129" dependencies = [ "ansi_term", "atty", @@ -406,15 +406,16 @@ dependencies = [ [[package]] name = "console" -version = "0.10.0" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6728a28023f207181b193262711102bfbaf47cc9d13bc71d0736607ef8efe88c" +checksum = "2586208b33573b7f76ccfbe5adb076394c88deaf81b84d7213969805b0a952a7" dependencies = [ "clicolors-control", "encode_unicode", "lazy_static", "libc", "regex", + "terminal_size", "termios", "unicode-width", "winapi 0.3.8", @@ -432,7 +433,7 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "888604f00b3db336d2af898ec3c1d5d0ddf5e6d462220f2ededc33a87ac4bbd5" dependencies = [ - "time 0.1.42", + "time 0.1.43", "url 1.7.2", ] @@ -510,9 +511,9 @@ dependencies = [ [[package]] name = "crossbeam-queue" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c695eeca1e7173472a32221542ae469b3e9aac3a4fc81f7696bcad82029493db" +checksum = "ab6bffe714b6bb07e42f201352c34f51fefd355ace793f9e638ebd52d23f98d2" dependencies = [ "cfg-if", "crossbeam-utils 0.7.2", @@ -551,9 +552,9 @@ dependencies = [ [[package]] name = "data-encoding" -version = "2.2.0" +version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11c0346158a19b3627234e15596f5e465c360fcdb97d817bcb255e0510f5a788" +checksum = "72aa14c04dfae8dd7d8a2b1cb7ca2152618cd01336dbfe704b8dcbf8d41dbd69" [[package]] name = "derive_more" @@ -568,9 +569,9 @@ dependencies = [ [[package]] name = "dialoguer" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94616e25d2c04fc97253d145f6ca33ad84a584258dc70c4e621cc79a57f903b6" +checksum = "d8b5eb0fce3c4f955b8d8d864b131fb8863959138da962026c106ba7a2e3bf7a" dependencies = [ "console", "lazy_static", @@ -641,9 +642,9 @@ checksum = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" [[package]] name = "fnv" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foreign-types" @@ -684,9 +685,9 @@ checksum = "1b980f2816d6ee8673b6517b52cb0e808a180efc92e5c19d02cdda79066703ef" [[package]] name = "futures" -version = "0.3.4" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c329ae8753502fb44ae4fc2b622fa2a94652c41e795143765ba0927f92ab780" +checksum = "1e05b85ec287aac0dc34db7d4a569323df697f9c55b99b15d6b4ef8cde49f613" dependencies = [ "futures-channel", "futures-core", @@ -725,9 +726,9 @@ dependencies = [ [[package]] name = "futures-executor" -version = "0.3.4" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f674f3e1bcb15b37284a90cedf55afdba482ab061c407a9c0ebbd0f3109741ba" +checksum = "10d6bb888be1153d3abeb9006b11b02cf5e9b209fda28693c31ae1e4e012e314" dependencies = [ "futures-core", "futures-task", @@ -819,6 +820,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "gimli" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcc8e0c9bce37868955864dbecd2b1ab2bdf967e6f28066d65aaac620444b65c" + [[package]] name = "glob" version = "0.3.0" @@ -877,9 +884,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.1.10" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "725cf19794cf90aa94e65050cb4191ff5d8fa87a498383774c47b332e3af952e" +checksum = "91780f809e750b0a89f5544be56617ff6b1227ee485bcb06ebe10cdf89bd3b71" dependencies = [ "libc", ] @@ -931,7 +938,7 @@ checksum = "9625f605ddfaf894bf78a544a7b8e31f562dc843654723a49892d9c7e75ac708" dependencies = [ "async-std", "bytes 0.4.12", - "futures 0.3.4", + "futures 0.3.5", "http", "pin-project-lite", ] @@ -942,7 +949,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e33d5dae94e0fdb82f9524ea2f2b98458b3d8448526d8cc8beccb3d3fded8aff" dependencies = [ - "futures 0.3.4", + "futures 0.3.5", "http", "http-service", "hyper", @@ -981,7 +988,7 @@ dependencies = [ "log", "net2", "rustc_version", - "time 0.1.42", + "time 0.1.43", "tokio 0.1.22", "tokio-buf", "tokio-executor", @@ -1090,9 +1097,9 @@ dependencies = [ [[package]] name = "kv-log-macro" -version = "1.0.4" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c54d9f465d530a752e6ebdc217e081a7a614b48cb200f6f0aee21ba6bc9aabb" +checksum = "4ff57d6d215f7ca7eb35a9a64d656ba4d9d2bef114d741dc08048e75e2f5d418" dependencies = [ "log", ] @@ -1190,9 +1197,9 @@ checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" [[package]] name = "mio" -version = "0.6.21" +version = "0.6.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "302dec22bcf6bae6dfb69c647187f4b4d0fb6f535521f7bc022430ce8e12008f" +checksum = "fce347092656428bc8eaf6201042cb551b8d67855af7374542a92a0fbfcac430" dependencies = [ "cfg-if", "fuchsia-zircon", @@ -1215,15 +1222,15 @@ checksum = "f5e374eff525ce1c5b7687c4cef63943e7686524a387933ad27ca7ec43779cb3" dependencies = [ "log", "mio", - "miow 0.3.3", + "miow 0.3.4", "winapi 0.3.8", ] [[package]] name = "mio-uds" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "966257a94e196b11bb43aca423754d87429960a768de9414f3691d6957abf125" +checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0" dependencies = [ "iovec", "libc", @@ -1244,9 +1251,9 @@ dependencies = [ [[package]] name = "miow" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "396aa0f2003d7df8395cb93e09871561ccc3e785f0acb369170e8cc74ddf9226" +checksum = "22dfdd1d51b2639a5abd17ed07005c3af05fb7a2a3b1a1d0d7af1000a520c1c7" dependencies = [ "socket2", "winapi 0.3.8", @@ -1272,9 +1279,9 @@ dependencies = [ [[package]] name = "net2" -version = "0.2.33" +version = "0.2.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42550d9fb7b6684a6d404d9fa7250c2eb2646df731d1c06afc06dcee9e1bcf88" +checksum = "2ba7c918ac76704fb42afcbbb43891e72731f3dcca3bef2a19786297baf14af7" dependencies = [ "cfg-if", "libc", @@ -1326,19 +1333,25 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.12.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46203554f085ff89c235cd12f7075f3233af9b11ed7c9e16dfe2560d03313ce6" +checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" dependencies = [ "hermit-abi", "libc", ] [[package]] -name = "once_cell" -version = "1.3.1" +name = "object" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c601810575c99596d4afc46f78a678c80105117c379eb3650cf99b8a21ce5b" +checksum = "9cbca9424c482ee628fa549d9c812e2cd22f1180b9222c9200fdfa6eb31aecb2" + +[[package]] +name = "once_cell" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b631f7e854af39a1739f401cf34a8a013dfe09eac4fa4dba91e9768bd28168d" [[package]] name = "opaque-debug" @@ -1348,9 +1361,9 @@ checksum = "2839e79665f131bdb5782e51f2c6c9599c133c6098982a54c794358bf432529c" [[package]] name = "openssl" -version = "0.10.28" +version = "0.10.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "973293749822d7dd6370d6da1e523b0d1db19f06c459134c658b2a4261378b52" +checksum = "cee6d85f4cb4c4f59a6a85d5b68a233d280c82e29e822913b9c8b129fbf20bdd" dependencies = [ "bitflags", "cfg-if", @@ -1368,9 +1381,9 @@ checksum = "77af24da69f9d9341038eba93a073b1fdaaa1b788221b00a69bce9e762cb32de" [[package]] name = "openssl-sys" -version = "0.9.54" +version = "0.9.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1024c0a59774200a555087a6da3f253a9095a5f344e353b212ac4c8b8e450986" +checksum = "7410fef80af8ac071d4f63755c0ab89ac3df0fd1ea91f1d1f37cf5cec4395990" dependencies = [ "autocfg", "cc", @@ -1431,9 +1444,9 @@ dependencies = [ [[package]] name = "paste" -version = "0.1.9" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "092d791bf7847f70bbd49085489fba25fc2c193571752bff9e36e74e72403932" +checksum = "d53181dcd37421c08d3b69f887784956674d09c3f9a47a04fece2b130a5b346b" dependencies = [ "paste-impl", "proc-macro-hack", @@ -1441,9 +1454,9 @@ dependencies = [ [[package]] name = "paste-impl" -version = "0.1.9" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "406c23fb4c45cc6f68a9bbabb8ec7bd6f8cfcbd17e9e8f72c2460282f8325729" +checksum = "05ca490fa1c034a71412b4d1edcb904ec5a0981a4426c9eb2128c0fda7a68d17" dependencies = [ "proc-macro-hack", "proc-macro2", @@ -1556,9 +1569,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "237844750cfbb86f67afe27eee600dfbbcb6188d734139b534cbfbf4f96792ae" +checksum = "f7505eeebd78492e0f6108f7171c4948dbb120ee8119d9d77d0afa5469bef67f" [[package]] name = "pin-utils" @@ -1586,15 +1599,15 @@ checksum = "05da548ad6865900e60eaba7f589cc0783590a92e940c26953ff81ddbab2d677" [[package]] name = "ppv-lite86" -version = "0.2.6" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" +checksum = "237a5ed80e274dbc66f86bd59c1e25edc039660be53194b5fe0a482e0f2612ea" [[package]] name = "proc-macro-error" -version = "0.4.12" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18f33027081eba0a6d8aba6d1b1c3a3be58cbb12106341c2d5759fcd9b5277e7" +checksum = "98e9e4b82e0ef281812565ea4751049f1bdcdfccda7d3f459f2e138a40c08678" dependencies = [ "proc-macro-error-attr", "proc-macro2", @@ -1605,9 +1618,9 @@ dependencies = [ [[package]] name = "proc-macro-error-attr" -version = "0.4.12" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a5b4b77fdb63c1eca72173d68d24501c54ab1269409f6b672c85deb18af69de" +checksum = "4f5444ead4e9935abd7f27dc51f7e852a0569ac888096d5ec2499470794e2e53" dependencies = [ "proc-macro2", "quote", @@ -1618,9 +1631,9 @@ dependencies = [ [[package]] name = "proc-macro-hack" -version = "0.5.14" +version = "0.5.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcfdefadc3d57ca21cf17990a28ef4c0f7c61383a28cb7604cf4a18e6ede1420" +checksum = "7e0456befd48169b9f13ef0f0ad46d492cf9d2dbb918bcf38e01eed4ce3ec5e4" [[package]] name = "proc-macro-nested" @@ -1645,9 +1658,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.3" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bdc6c187c65bca4260c9011c9e3132efe4909da44726bad24cf7572ae338d7f" +checksum = "54a21852a652ad6f610c9510194f398ff6f8692e334fd1145fed931f7fbe44ea" dependencies = [ "proc-macro2", ] @@ -1711,9 +1724,9 @@ checksum = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84" [[package]] name = "regex" -version = "1.3.6" +version = "1.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f6946991529684867e47d86474e3a6d0c0ab9b82d5821e314b1ede31fa3a4b3" +checksum = "9c3780fcf44b193bc4d09f36d2a3c87b251da4a046c87795a0d35f4f927ad8e6" dependencies = [ "aho-corasick", "memchr", @@ -1723,9 +1736,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.17" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe5bd57d1d7414c6b5ed48563a2c855d995ff777729dcd91c369ec7fea395ae" +checksum = "26412eb97c6b088a6997e05f69403a802a92d520de2f8e63c2b65f9e0f47c4e8" [[package]] name = "remove_dir_all" @@ -1785,15 +1798,15 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "535622e6be132bccd223f4bb2b8ac8d53cda3c7a6394944d3b2b33fb974f9d76" +checksum = "ed3d612bc64430efeb3f7ee6ef26d590dce0c43249217bddc62112540c7941e1" [[package]] name = "schannel" -version = "0.1.18" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "039c25b130bd8c1321ee2d7de7fde2659fa9c2744e4bb29711cfc852ea53cd19" +checksum = "8f05ba609c234e60bee0d547fe94a4c7e9da733d1c962cf6e59efa4cd9c8bc75" dependencies = [ "lazy_static", "winapi 0.3.8", @@ -1813,9 +1826,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "security-framework" -version = "0.4.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "572dfa3a0785509e7a44b5b4bebcf94d41ba34e9ed9eb9df722545c3b3c4144a" +checksum = "64808902d7d99f78eaddd2b4e2509713babc3dc3c85ad6f4c447680f3c01e535" dependencies = [ "bitflags", "core-foundation", @@ -1826,9 +1839,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ddb15a5fec93b7021b8a9e96009c5d8d51c15673569f7c0f6b7204e5b7b404f" +checksum = "17bf11d99252f512695eb468de5516e5cf75455521e69dfe343f3b74e4748405" dependencies = [ "core-foundation-sys", "libc", @@ -1858,18 +1871,18 @@ checksum = "f638d531eccd6e23b980caf34876660d38e265409d8e99b397ab71eb3612fad0" [[package]] name = "serde" -version = "1.0.110" +version = "1.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99e7b308464d16b56eba9964e4972a3eee817760ab60d88c3f86e1fecb08204c" +checksum = "c9124df5b40cbd380080b2cc6ab894c040a3070d995f5c9dc77e18c34a8ae37d" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.110" +version = "1.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "818fbf6bfa9a42d3bfcaca148547aa00c7b915bec71d1757aa2d44ca68771984" +checksum = "3f2c3ac8e6ca1e9c80b8be1023940162bf81ae3cffbb1809474152f2ce1eb250" dependencies = [ "proc-macro2", "quote", @@ -1920,9 +1933,9 @@ checksum = "2579985fda508104f7587689507983eadd6a6e84dd35d6d115361f530916fa0d" [[package]] name = "sha2" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27044adfd2e1f077f649f59deb9490d3941d674002f7d062870a60ebe9bd47a0" +checksum = "a256f46ea78a0c0d9ff00077504903ac881a1dafdc20da66545699e7776b3e69" dependencies = [ "block-buffer", "digest", @@ -2023,7 +2036,7 @@ dependencies = [ "async-std", "dotenv", "env_logger", - "futures 0.3.4", + "futures 0.3.5", "paste", "serde", "serde_json", @@ -2046,7 +2059,7 @@ dependencies = [ "console", "dialoguer", "dotenv", - "futures 0.3.4", + "futures 0.3.5", "glob", "serde", "serde_json", @@ -2062,7 +2075,7 @@ version = "0.3.5" dependencies = [ "async-stream", "atoi", - "base64 0.12.0", + "base64 0.12.1", "bigdecimal", "bitflags", "byteorder", @@ -2113,7 +2126,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-std", - "futures 0.3.4", + "futures 0.3.5", "paw", "sqlx", "structopt", @@ -2124,7 +2137,7 @@ name = "sqlx-example-postgres-listen" version = "0.1.0" dependencies = [ "async-std", - "futures 0.3.4", + "futures 0.3.5", "sqlx", ] @@ -2135,7 +2148,7 @@ dependencies = [ "anyhow", "async-std", "dotenv", - "futures 0.3.4", + "futures 0.3.5", "paw", "sqlx", "structopt", @@ -2150,7 +2163,7 @@ dependencies = [ "async-trait", "chrono", "env_logger", - "futures 0.3.4", + "futures 0.3.5", "heck", "http", "itertools", @@ -2173,7 +2186,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-std", - "futures 0.3.4", + "futures 0.3.5", "paw", "sqlx", "structopt", @@ -2185,7 +2198,7 @@ version = "0.3.5" dependencies = [ "async-std", "dotenv", - "futures 0.3.4", + "futures 0.3.5", "heck", "hex", "once_cell", @@ -2297,9 +2310,9 @@ checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" [[package]] name = "structopt" -version = "0.3.12" +version = "0.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8faa2719539bbe9d77869bfb15d4ee769f99525e707931452c97b693b3f159d" +checksum = "863246aaf5ddd0d6928dfeb1a9ca65f505599e4e1b399935ef7e75107516b4ef" dependencies = [ "clap", "lazy_static", @@ -2308,9 +2321,9 @@ dependencies = [ [[package]] name = "structopt-derive" -version = "0.4.5" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f88b8e18c69496aad6f9ddf4630dd7d585bcaf765786cb415b9aec2fe5a0430" +checksum = "d239ca4b13aee7a2142e6795cbd69e457665ff8037aed33b3effdc430d2f927a" dependencies = [ "heck", "proc-macro-error", @@ -2370,6 +2383,16 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "terminal_size" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8038f95fc7a6f351163f4b964af631bd26c9e828f7db085f2a84aca56f70d13b" +dependencies = [ + "libc", + "winapi 0.3.8", +] + [[package]] name = "termios" version = "0.3.2" @@ -2419,9 +2442,9 @@ dependencies = [ [[package]] name = "threadpool" -version = "1.7.1" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2f0c90a5f3459330ac8bc0d2f879c693bb7a2f59689c1083fc4ef83834da865" +checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa" dependencies = [ "num_cpus", ] @@ -2434,7 +2457,7 @@ checksum = "e619c99048ae107912703d0efeec4ff4fbff704f064e51d3eee614b28ea7b739" dependencies = [ "async-std", "cookie", - "futures 0.3.4", + "futures 0.3.5", "http", "http-service", "http-service-hyper", @@ -2449,12 +2472,11 @@ dependencies = [ [[package]] name = "time" -version = "0.1.42" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db8dcfca086c1143c9270ac42a2bbd8a7ee477b78ac8e45b19abfb0cbede4b6f" +checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" dependencies = [ "libc", - "redox_syscall", "winapi 0.3.8", ] @@ -2485,13 +2507,14 @@ dependencies = [ [[package]] name = "time-macros-impl" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e987cfe0537f575b5fc99909de6185f6c19c3ad8889e2275e686a873d0869ba1" +checksum = "e5c3be1edfad6027c69f5491cf4cb310d1a71ecd6af742788c6ff8bced86b8fa" dependencies = [ "proc-macro-hack", "proc-macro2", "quote", + "standback", "syn", ] @@ -2689,9 +2712,9 @@ checksum = "e604eb7b43c06650e854be16a2a03155743d3752dd1c943f6829e26b7a36e382" [[package]] name = "trybuild" -version = "1.0.24" +version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24b4e093c5ed1a60b22557090120aa14f90ca801549c0949d775ea07c1407720" +checksum = "744665442556a91933cee5e75b0371376eb03498c4d0bfbcebd2a9882b4fb5ef" dependencies = [ "glob", "lazy_static", @@ -2703,9 +2726,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.11.2" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d2783fe2d6b8c1101136184eb41be8b1ad379e4657050b8aaff0c79ee7575f9" +checksum = "373c8a200f9e67a0c95e62a4f52fbf80c23b4381c05a17845531982fa99e6b33" [[package]] name = "unicode-bidi" @@ -2785,15 +2808,15 @@ checksum = "3fc439f2794e98976c88a2a2dafce96b930fe8010b0a256b3c2199a773933168" [[package]] name = "vec_map" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05c78687fb1a80548ae3250346c3db86a80a7cdd77bda190189f2d0a0987c81a" +checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" [[package]] name = "version_check" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "078775d0255232fb988e6fccf26ddc9d1ac274299aaedcedce21c6f72cc533ce" +checksum = "b5a972e5669d67ba988ce3dc826706fb0a8b01471c088cb0b6110b805cc36aed" [[package]] name = "void" @@ -2949,9 +2972,9 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa515c5163a99cc82bab70fd3bfdd36d827be85de63737b40fcef2ce084a436e" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" dependencies = [ "winapi 0.3.8", ] diff --git a/sqlx-core/src/mysql/arguments.rs b/sqlx-core/src/mysql/arguments.rs index 1669743e7..11ddfe0e5 100644 --- a/sqlx-core/src/mysql/arguments.rs +++ b/sqlx-core/src/mysql/arguments.rs @@ -1,43 +1,51 @@ +use std::ops::{Deref, DerefMut}; + use crate::arguments::Arguments; use crate::encode::{Encode, IsNull}; -use crate::mysql::type_info::MySqlTypeInfo; -use crate::mysql::MySql; -use crate::types::Type; +use crate::mysql::{MySql, MySqlTypeInfo}; -#[derive(Default)] +/// Implementation of [`Arguments`] for MySQL. +#[derive(Debug, Default)] pub struct MySqlArguments { - pub(crate) param_types: Vec, - pub(crate) params: Vec, + pub(crate) values: Vec, + pub(crate) types: Vec, pub(crate) null_bitmap: Vec, } -impl Arguments for MySqlArguments { +impl<'q> Arguments<'q> for MySqlArguments { type Database = MySql; fn reserve(&mut self, len: usize, size: usize) { - self.param_types.reserve(len); - self.params.reserve(size); - - // ensure we have enough size in the bitmap to hold at least `len` extra bits - // the second `& 7` gives us 0 spare bits when param_types.len() is a multiple of 8 - let spare_bits = (8 - (self.param_types.len()) & 7) & 7; - // ensure that if there are no spare bits left, `len = 1` reserves another byte - self.null_bitmap.reserve((len + 7 - spare_bits) / 8); + self.types.reserve(len); + self.values.reserve(size); } fn add(&mut self, value: T) where - T: Type, - T: Encode, + T: Encode<'q, Self::Database>, { - let type_id = >::type_info(); - let index = self.param_types.len(); + let ty = value.produces(); + let index = self.types.len(); - self.param_types.push(type_id); + self.types.push(ty); self.null_bitmap.resize((index / 8) + 1, 0); - if let IsNull::Yes = value.encode_nullable(&mut self.params) { + if let IsNull::Yes = value.encode(self) { self.null_bitmap[index / 8] |= (1 << index % 8) as u8; } } } + +impl Deref for MySqlArguments { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.values + } +} + +impl DerefMut for MySqlArguments { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.values + } +} diff --git a/sqlx-core/src/mysql/connection.rs b/sqlx-core/src/mysql/connection.rs deleted file mode 100644 index 62b18aa97..000000000 --- a/sqlx-core/src/mysql/connection.rs +++ /dev/null @@ -1,345 +0,0 @@ -use std::borrow::Cow; -use std::collections::HashMap; -use std::convert::TryInto; -use std::ops::Range; - -use futures_core::future::BoxFuture; -use sha1::Sha1; - -use crate::connection::{Connect, Connection}; -use crate::executor::Executor; -use crate::mysql::protocol::{ - AuthPlugin, AuthSwitch, Capabilities, ComPing, Handshake, HandshakeResponse, -}; -use crate::mysql::stream::MySqlStream; -use crate::mysql::util::xor_eq; - -use crate::mysql::{rsa, tls}; -use crate::url::Url; - -// Size before a packet is split -pub(super) const MAX_PACKET_SIZE: u32 = 1024; - -pub(super) const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224; - -/// An asynchronous connection to a [`MySql`] database. -/// -/// The connection string expected by `MySqlConnection` should be a MySQL connection -/// string, as documented at -/// -/// -/// ### TLS Support (requires `tls` feature) -/// This connection type supports some of the same flags as the `mysql` CLI application for SSL -/// connections, but they must be specified via the query segment of the connection string -/// rather than as program arguments. -/// -/// The same options for `--ssl-mode` are supported as the `ssl-mode` query parameter: -/// -/// -/// ```text -/// mysql://[:]@[:]/[?ssl-mode=[&ssl-ca=]] -/// ``` -/// where -/// ```text -/// ssl-mode = DISABLED | PREFERRED | REQUIRED | VERIFY_CA | VERIFY_IDENTITY -/// path = percent (URL) encoded path on the local machine -/// ``` -/// -/// If the `tls` feature is not enabled, `ssl-mode=DISABLED` and `ssl-mode=PREFERRED` are no-ops and -/// `ssl-mode=REQUIRED`, `ssl-mode=VERIFY_CA` and `ssl-mode=VERIFY_IDENTITY` are forbidden -/// (attempting to connect with these will return an error). -/// -/// If the `tls` feature is enabled, an upgrade to TLS is attempted on every connection by default -/// (equivalent to `ssl-mode=PREFERRED`). If the server does not support TLS (because `--ssl=0` was -/// passed to the server or an invalid certificate or key was used: -/// ) -/// then it falls back to an unsecured connection and logs a warning. -/// -/// Add `ssl-mode=REQUIRED` to your connection string to emit an error if the TLS upgrade fails. -/// -/// However, like with `mysql` the server certificate is **not** checked for validity by default. -/// -/// Specifying `ssl-mode=VERIFY_CA` will cause the TLS upgrade to verify the server's SSL -/// certificate against a local CA root certificate; this is not the system root certificate -/// but is instead expected to be specified as a local path with the `ssl-ca` query parameter -/// (percent-encoded so the URL remains valid). -/// -/// If you're running MySQL locally it might look something like this (for `VERIFY_CA`): -/// ```text -/// mysql://root:password@localhost/my_database?ssl-mode=VERIFY_CA&ssl-ca=%2Fvar%2Flib%2Fmysql%2Fca.pem -/// ``` -/// -/// `%2F` is the percent-encoding for forward slash (`/`). In the example we give `/var/lib/mysql/ca.pem` -/// as the CA certificate path, which is generated by the MySQL server automatically if -/// no certificate is manually specified. Note that the path may vary based on the default `my.cnf` -/// packaged with MySQL for your Linux distribution. Also note that unlike MySQL, MariaDB does *not* -/// generate certificates automatically and they must always be passed in to enable TLS. -/// -/// If `ssl-ca` is not specified or the file cannot be read, then an error is returned. -/// `ssl-ca` implies `ssl-mode=VERIFY_CA` so you only actually need to specify the former -/// but you may prefer having both to be more explicit. -/// -/// If `ssl-mode=VERIFY_IDENTITY` is specified, in addition to checking the certificate as with -/// `ssl-mode=VERIFY_CA`, the hostname in the connection string will be verified -/// against the hostname in the server certificate, so they must be the same for the TLS -/// upgrade to succeed. `ssl-ca` must still be specified. -pub struct MySqlConnection { - pub(super) stream: MySqlStream, - pub(super) is_ready: bool, - pub(super) cache_statement: HashMap, u32>, - - // Work buffer for the value ranges of the current row - // This is used as the backing memory for each Row's value indexes - pub(super) current_row_values: Vec>>, -} - -fn to_asciz(s: &str) -> Vec { - let mut z = String::with_capacity(s.len() + 1); - z.push_str(s); - z.push('\0'); - - z.into_bytes() -} - -async fn rsa_encrypt_with_nonce( - stream: &mut MySqlStream, - public_key_request_id: u8, - password: &str, - nonce: &[u8], -) -> crate::Result> { - // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/ - - if stream.is_tls() { - // If in a TLS stream, send the password directly in clear text - return Ok(to_asciz(password)); - } - - // client sends a public key request - stream.send(&[public_key_request_id][..], false).await?; - - // server sends a public key response - let packet = stream.receive().await?; - let rsa_pub_key = &packet[1..]; - - // xor the password with the given nonce - let mut pass = to_asciz(password); - xor_eq(&mut pass, nonce); - - // client sends an RSA encrypted password - rsa::encrypt::(rsa_pub_key, &pass) -} - -async fn make_auth_response( - stream: &mut MySqlStream, - plugin: &AuthPlugin, - password: &str, - nonce: &[u8], -) -> crate::Result> { - if password.is_empty() { - // Empty password should not be sent - return Ok(vec![]); - } - - match plugin { - AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => { - Ok(plugin.scramble(password, nonce)) - } - - AuthPlugin::Sha256Password => rsa_encrypt_with_nonce(stream, 0x01, password, nonce).await, - } -} - -async fn establish(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> { - // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html - // https://mariadb.com/kb/en/connection/ - - // Read a [Handshake] packet. When connecting to the database server, this is immediately - // received from the database server. - - let handshake = Handshake::read(stream.receive().await?)?; - let mut auth_plugin = handshake.auth_plugin; - let mut auth_plugin_data = handshake.auth_plugin_data; - - stream.capabilities &= handshake.server_capabilities; - stream.capabilities |= Capabilities::PROTOCOL_41; - - log::trace!("using capability flags: {:?}", stream.capabilities); - - // Depending on the ssl-mode and capabilities we should upgrade - // our connection to TLS - - tls::upgrade_if_needed(stream, url).await?; - - // Send a [HandshakeResponse] packet. This is returned in response to the [Handshake] packet - // that is immediately received. - - let password = &*url.password().unwrap_or_default(); - let auth_response = - make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?; - - stream - .send( - HandshakeResponse { - client_collation: COLLATE_UTF8MB4_UNICODE_CI, - max_packet_size: MAX_PACKET_SIZE, - username: &url.username().unwrap_or(Cow::Borrowed("root")), - database: url.database(), - auth_plugin: &auth_plugin, - auth_response: &auth_response, - }, - false, - ) - .await?; - - loop { - // After sending the handshake response with our assumed auth method the server - // will send OK, fail, or tell us to change auth methods - let packet = stream.receive().await?; - - match packet[0] { - // OK - 0x00 => { - break; - } - - // ERROR - 0xFF => { - return stream.handle_err(); - } - - // AUTH_SWITCH - 0xFE => { - let auth = AuthSwitch::read(packet)?; - auth_plugin = auth.auth_plugin; - auth_plugin_data = auth.auth_plugin_data; - - let auth_response = - make_auth_response(stream, &auth_plugin, password, &auth_plugin_data).await?; - - stream.send(&*auth_response, false).await?; - } - - 0x01 if auth_plugin == AuthPlugin::CachingSha2Password => { - match packet[1] { - // AUTH_OK - 0x03 => {} - - // AUTH_CONTINUE - 0x04 => { - // The specific password is _not_ cached on the server - // We need to send a normal RSA-encrypted password for this - let enc = rsa_encrypt_with_nonce(stream, 0x02, password, &auth_plugin_data) - .await?; - - stream.send(&*enc, false).await?; - } - - unk => { - return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", unk).into()); - } - } - } - - _ => { - return stream.handle_unexpected(); - } - } - } - - Ok(()) -} - -async fn close(mut stream: MySqlStream) -> crate::Result<()> { - // TODO: Actually tell MySQL that we're closing - - stream.flush().await?; - stream.shutdown()?; - - Ok(()) -} - -async fn ping(stream: &mut MySqlStream) -> crate::Result<()> { - stream.wait_until_ready().await?; - stream.is_ready = false; - - stream.send(ComPing, true).await?; - - match stream.receive().await?[0] { - 0x00 | 0xFE => stream.handle_ok().map(drop), - - 0xFF => stream.handle_err(), - - _ => stream.handle_unexpected(), - } -} - -impl MySqlConnection { - pub(super) async fn new(url: std::result::Result) -> crate::Result { - let url = url?; - let mut stream = MySqlStream::new(&url).await?; - - establish(&mut stream, &url).await?; - - let mut self_ = Self { - stream, - current_row_values: Vec::with_capacity(10), - is_ready: true, - cache_statement: HashMap::new(), - }; - - // After the connection is established, we initialize by configuring a few - // connection parameters - - // https://mariadb.com/kb/en/sql-mode/ - - // PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator. - // This means that "A" || "B" can be used in place of CONCAT("A", "B"). - - // NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is - // not available, a warning is given and the default storage - // engine is used instead. - - // NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust. - - // NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust. - - // -- - - // Setting the time zone allows us to assume that the output - // from a TIMESTAMP field is UTC - - // -- - - // https://mathiasbynens.be/notes/mysql-utf8mb4 - - self_.execute(r#" -SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE')); -SET time_zone = '+00:00'; -SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci; - "#).await?; - - Ok(self_) - } -} - -impl Connect for MySqlConnection { - fn connect(url: T) -> BoxFuture<'static, crate::Result> - where - T: TryInto, - Self: Sized, - { - Box::pin(MySqlConnection::new(url.try_into())) - } -} - -impl Connection for MySqlConnection { - #[inline] - fn close(self) -> BoxFuture<'static, crate::Result<()>> { - Box::pin(close(self.stream)) - } - - #[inline] - fn ping(&mut self) -> BoxFuture> { - Box::pin(ping(&mut self.stream)) - } -} diff --git a/sqlx-core/src/mysql/connection/auth.rs b/sqlx-core/src/mysql/connection/auth.rs new file mode 100644 index 000000000..31a7881e0 --- /dev/null +++ b/sqlx-core/src/mysql/connection/auth.rs @@ -0,0 +1,176 @@ +use bytes::buf::ext::Chain; +use bytes::Bytes; +use digest::{Digest, FixedOutput}; +use generic_array::GenericArray; +use sha1::Sha1; +use sha2::Sha256; + +use crate::error::Error; +use crate::mysql::connection::stream::MySqlStream; +use crate::mysql::protocol::auth::AuthPlugin; +use crate::mysql::protocol::rsa; +use crate::mysql::protocol::Packet; + +impl AuthPlugin { + pub(super) async fn scramble( + self, + stream: &mut MySqlStream, + password: &str, + nonce: &Chain, + ) -> Result, Error> { + match self { + // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/ + AuthPlugin::CachingSha2Password => Ok(scramble_sha256(password, nonce).to_vec()), + + AuthPlugin::MySqlNativePassword => Ok(scramble_sha1(password, nonce).to_vec()), + + // https://mariadb.com/kb/en/sha256_password-plugin/ + AuthPlugin::Sha256Password => encrypt_rsa(stream, 0x01, password, nonce).await, + } + } + + pub(super) async fn handle( + self, + stream: &mut MySqlStream, + packet: Packet, + password: &str, + nonce: &Chain, + ) -> Result { + match self { + AuthPlugin::CachingSha2Password if packet[0] == 0x01 => { + match packet[1] { + // AUTH_OK + 0x03 => Ok(true), + + // AUTH_CONTINUE + 0x04 => { + let payload = encrypt_rsa(stream, 0x02, password, nonce).await?; + + stream.write_packet(&*payload); + stream.flush().await?; + + Ok(false) + } + + v => { + Err(err_protocol!("unexpected result from fast authentication 0x{:x} when expecting 0x03 (AUTH_OK) or 0x04 (AUTH_CONTINUE)", v)) + } + } + } + + _ => Err(err_protocol!( + "unexpected packet 0x{:02x} for auth plugin '{}' during authentication", + packet[0], + self.name() + )), + } + } +} + +fn scramble_sha1( + password: &str, + nonce: &Chain, +) -> GenericArray::OutputSize> { + // SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) ) + // https://mariadb.com/kb/en/connection/#mysql_native_password-plugin + + let mut ctx = Sha1::new(); + + ctx.input(password); + + let mut pw_hash = ctx.result_reset(); + + ctx.input(&pw_hash); + + let pw_hash_hash = ctx.result_reset(); + + ctx.input(nonce.first_ref()); + ctx.input(nonce.last_ref()); + ctx.input(pw_hash_hash); + + let pw_seed_hash_hash = ctx.result(); + + xor_eq(&mut pw_hash, &pw_seed_hash_hash); + + pw_hash +} + +fn scramble_sha256( + password: &str, + nonce: &Chain, +) -> GenericArray::OutputSize> { + // XOR(SHA256(password), SHA256(seed, SHA256(SHA256(password)))) + // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/#sha-2-encrypted-password + let mut ctx = Sha256::new(); + + ctx.input(password); + + let mut pw_hash = ctx.result_reset(); + + ctx.input(&pw_hash); + + let pw_hash_hash = ctx.result_reset(); + + ctx.input(nonce.first_ref()); + ctx.input(nonce.last_ref()); + ctx.input(pw_hash_hash); + + let pw_seed_hash_hash = ctx.result(); + + xor_eq(&mut pw_hash, &pw_seed_hash_hash); + + pw_hash +} + +async fn encrypt_rsa<'s>( + stream: &'s mut MySqlStream, + public_key_request_id: u8, + password: &'s str, + nonce: &'s Chain, +) -> Result, Error> { + // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/ + + if stream.is_tls() { + // If in a TLS stream, send the password directly in clear text + return Ok(to_asciz(password)); + } + + // client sends a public key request + stream.write_packet(&[public_key_request_id][..]); + stream.flush().await?; + + // server sends a public key response + let packet = stream.recv_packet().await?; + let rsa_pub_key = &packet[1..]; + + // xor the password with the given nonce + let mut pass = to_asciz(password); + + let (a, b) = (nonce.first_ref(), nonce.last_ref()); + let mut nonce = Vec::with_capacity(a.len() + b.len()); + nonce.extend_from_slice(&*a); + nonce.extend_from_slice(&*b); + + xor_eq(&mut pass, &*nonce); + + // client sends an RSA encrypted password + rsa::encrypt::(rsa_pub_key, &pass) +} + +// XOR(x, y) +// If len(y) < len(x), wrap around inside y +fn xor_eq(x: &mut [u8], y: &[u8]) { + let y_len = y.len(); + + for i in 0..x.len() { + x[i] ^= y[i % y_len]; + } +} + +fn to_asciz(s: &str) -> Vec { + let mut z = String::with_capacity(s.len() + 1); + z.push_str(s); + z.push('\0'); + + z.into_bytes() +} diff --git a/sqlx-core/src/mysql/connection/establish.rs b/sqlx-core/src/mysql/connection/establish.rs new file mode 100644 index 000000000..5bfdb183f --- /dev/null +++ b/sqlx-core/src/mysql/connection/establish.rs @@ -0,0 +1,101 @@ +use bytes::Bytes; +use hashbrown::HashMap; + +use crate::error::Error; +use crate::mysql::connection::{tls, MySqlStream, COLLATE_UTF8MB4_UNICODE_CI, MAX_PACKET_SIZE}; +use crate::mysql::protocol::connect::{ + AuthSwitchRequest, AuthSwitchResponse, Handshake, HandshakeResponse, +}; +use crate::mysql::protocol::Capabilities; +use crate::mysql::{MySqlConnectOptions, MySqlConnection}; +use bytes::buf::BufExt; + +impl MySqlConnection { + pub(crate) async fn establish(options: &MySqlConnectOptions) -> Result { + let mut stream: MySqlStream = MySqlStream::connect(options).await?; + + // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html + // https://mariadb.com/kb/en/connection/ + + let handshake: Handshake = stream.recv_packet().await?.decode()?; + + let mut plugin = handshake.auth_plugin; + let mut nonce = handshake.auth_plugin_data; + + stream.capabilities &= handshake.server_capabilities; + stream.capabilities |= Capabilities::PROTOCOL_41; + + // Upgrade to TLS if we were asked to and the server supports it + tls::maybe_upgrade(&mut stream, options).await?; + + let auth_response = if let (Some(plugin), Some(password)) = (plugin, &options.password) { + Some(plugin.scramble(&mut stream, password, &nonce).await?) + } else { + None + }; + + stream.write_packet(HandshakeResponse { + char_set: COLLATE_UTF8MB4_UNICODE_CI, + max_packet_size: MAX_PACKET_SIZE, + username: &options.username, + database: options.database.as_deref(), + auth_plugin: plugin, + auth_response: auth_response.as_deref(), + }); + + stream.flush().await?; + + loop { + let packet = stream.recv_packet().await?; + match packet[0] { + 0x00 => { + let _ok = packet.ok()?; + + break; + } + + 0xfe => { + let switch: AuthSwitchRequest = packet.decode()?; + + plugin = Some(switch.plugin); + nonce = switch.data.chain(Bytes::new()); + + let response = switch + .plugin + .scramble( + &mut stream, + options.password.as_deref().unwrap_or_default(), + &nonce, + ) + .await?; + + stream.write_packet(AuthSwitchResponse(response)); + stream.flush().await?; + } + + id => { + if let (Some(plugin), Some(password)) = (plugin, &options.password) { + if plugin.handle(&mut stream, packet, password, &nonce).await? { + // plugin signaled authentication is ok + break; + } + + // plugin signaled to continue authentication + } else { + return Err(err_protocol!( + "unexpected packet 0x{:02x} during authentication", + id + )); + } + } + } + } + + Ok(Self { + stream, + cache_statement: HashMap::new(), + scratch_row_columns: Default::default(), + scratch_row_column_names: Default::default(), + }) + } +} diff --git a/sqlx-core/src/mysql/connection/executor.rs b/sqlx-core/src/mysql/connection/executor.rs new file mode 100644 index 000000000..d5d196139 --- /dev/null +++ b/sqlx-core/src/mysql/connection/executor.rs @@ -0,0 +1,285 @@ +use std::sync::Arc; + +use async_stream::try_stream; +use bytes::Bytes; +use either::Either; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use futures_core::Stream; +use futures_util::{pin_mut, TryStreamExt}; + +use crate::describe::{Column, Describe}; +use crate::error::Error; +use crate::executor::{Execute, Executor}; +use crate::ext::ustr::UStr; +use crate::mysql::io::MySqlBufExt; +use crate::mysql::protocol::response::Status; +use crate::mysql::protocol::statement::{ + BinaryRow, Execute as StatementExecute, Prepare, PrepareOk, +}; +use crate::mysql::protocol::text::{ColumnDefinition, ColumnFlags, Query, TextRow}; +use crate::mysql::protocol::Packet; +use crate::mysql::row::MySqlColumn; +use crate::mysql::{ + MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo, MySqlValueFormat, +}; + +impl MySqlConnection { + async fn prepare(&mut self, query: &str) -> Result { + if let Some(&statement) = self.cache_statement.get(query) { + return Ok(statement); + } + + // https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html + // https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK + + self.stream.send_packet(Prepare { query }).await?; + + let ok: PrepareOk = self.stream.recv().await?; + + // the parameter definitions are very unreliable so we skip over them + // as we have little use + + if ok.params > 0 { + for _ in 0..ok.params { + let _def: ColumnDefinition = self.stream.recv().await?; + } + + self.stream.maybe_recv_eof().await?; + } + + // the column definitions are berefit the type information from the + // to-be-bound parameters; we will receive the output column definitions + // once more on execute so we wait for that + + if ok.columns > 0 { + for _ in 0..(ok.columns as usize) { + let _def: ColumnDefinition = self.stream.recv().await?; + } + + self.stream.maybe_recv_eof().await?; + } + + self.cache_statement + .insert(query.to_owned(), ok.statement_id); + + Ok(ok.statement_id) + } + + async fn recv_result_metadata(&mut self, mut packet: Packet) -> Result<(), Error> { + let num_columns: u64 = packet.get_uint_lenenc(); // column count + + // the result-set metadata is primarily a listing of each output + // column in the result-set + + let column_names = Arc::make_mut(&mut self.scratch_row_column_names); + let columns = Arc::make_mut(&mut self.scratch_row_columns); + + columns.clear(); + column_names.clear(); + + for i in 0..num_columns { + let def: ColumnDefinition = self.stream.recv().await?; + + let name = (match (def.name()?, def.alias()?) { + (_, alias) if !alias.is_empty() => Some(alias), + + (name, _) if !name.is_empty() => Some(name), + + _ => None, + }) + .map(UStr::new); + + if let Some(name) = &name { + column_names.insert(name.clone(), i as usize); + } + + let type_info = MySqlTypeInfo::from_column(&def); + + columns.push(MySqlColumn { name, type_info }); + } + + self.stream.maybe_recv_eof().await?; + + Ok(()) + } + + async fn run<'c>( + &'c mut self, + query: &str, + arguments: Option, + ) -> Result, Error>> + 'c, Error> { + self.stream.wait_until_ready().await?; + self.stream.busy = true; + + let format = if let Some(arguments) = arguments { + let statement = self.prepare(query).await?; + + // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html + self.stream + .send_packet(StatementExecute { + statement, + arguments: &arguments, + }) + .await?; + + MySqlValueFormat::Binary + } else { + // https://dev.mysql.com/doc/internals/en/com-query.html + self.stream.send_packet(Query(query)).await?; + + MySqlValueFormat::Text + }; + + Ok(try_stream! { + loop { + // query response is a meta-packet which may be one of: + // Ok, Err, ResultSet, or (unhandled) LocalInfileRequest + let mut packet = self.stream.recv_packet().await?; + + if packet[0] == 0x00 || packet[0] == 0xff { + // first packet in a query response is OK or ERR + // this indicates either a successful query with no rows at all or a failed query + let ok = packet.ok()?; + let v = Either::Left(ok.affected_rows); + + yield v; + + if ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { + // more result sets exist, continue to the next one + continue; + } + + self.stream.busy = false; + return; + } + + // otherwise, this first packet is the start of the result-set metadata, + self.recv_result_metadata(packet).await?; + + // finally, there will be none or many result-rows + loop { + let packet = self.stream.recv_packet().await?; + + if packet[0] == 0xfe && packet.len() < 9 { + let eof = packet.eof(self.stream.capabilities)?; + let v = Either::Left(0); + + yield v; + + if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { + // more result sets exist, continue to the next one + break; + } + + self.stream.busy = false; + return; + } + + let row = match format { + MySqlValueFormat::Binary => packet.decode_with::(&self.scratch_row_columns)?.0, + MySqlValueFormat::Text => packet.decode_with::(&self.scratch_row_columns)?.0, + }; + + let v = Either::Right(MySqlRow { + row, + format, + columns: Arc::clone(&self.scratch_row_columns), + column_names: Arc::clone(&self.scratch_row_column_names), + }); + + yield v; + } + } + }) + } +} + +impl<'c> Executor<'c> for &'c mut MySqlConnection { + type Database = MySql; + + fn fetch_many<'q: 'c, E>( + self, + mut query: E, + ) -> BoxStream<'c, Result, Error>> + where + E: Execute<'q, Self::Database>, + { + let s = query.query(); + let arguments = query.take_arguments(); + + Box::pin(try_stream! { + let s = self.run(s, arguments).await?; + pin_mut!(s); + + while let Some(v) = s.try_next().await? { + yield v; + } + }) + } + + fn fetch_optional<'q: 'c, E>(self, query: E) -> BoxFuture<'c, Result, Error>> + where + E: Execute<'q, Self::Database>, + { + let mut s = self.fetch_many(query); + + Box::pin(async move { + while let Some(v) = s.try_next().await? { + if let Either::Right(r) = v { + return Ok(Some(r)); + } + } + + Ok(None) + }) + } + + #[doc(hidden)] + fn describe<'q: 'c, E>(self, query: E) -> BoxFuture<'c, Result, Error>> + where + E: Execute<'q, Self::Database>, + { + let query = query.query(); + + Box::pin(async move { + self.stream.send_packet(Prepare { query }).await?; + + let ok: PrepareOk = self.stream.recv().await?; + + let mut params = Vec::with_capacity(ok.params as usize); + let mut columns = Vec::with_capacity(ok.columns as usize); + + if ok.params > 0 { + for _ in 0..ok.params { + let def: ColumnDefinition = self.stream.recv().await?; + + params.push(MySqlTypeInfo::from_column(&def)); + } + + self.stream.maybe_recv_eof().await?; + } + + // the column definitions are berefit the type information from the + // to-be-bound parameters; we will receive the output column definitions + // once more on execute so we wait for that + + if ok.columns > 0 { + for _ in 0..(ok.columns as usize) { + let def: ColumnDefinition = self.stream.recv().await?; + let ty = MySqlTypeInfo::from_column(&def); + + columns.push(Column { + name: def.name()?.to_owned(), + type_info: ty, + not_null: Some(def.flags.contains(ColumnFlags::NOT_NULL)), + }) + } + + self.stream.maybe_recv_eof().await?; + } + + Ok(Describe { params, columns }) + }) + } +} diff --git a/sqlx-core/src/mysql/connection/mod.rs b/sqlx-core/src/mysql/connection/mod.rs new file mode 100644 index 000000000..dd353bb0e --- /dev/null +++ b/sqlx-core/src/mysql/connection/mod.rs @@ -0,0 +1,116 @@ +use std::fmt::{self, Debug, Formatter}; +use std::net::Shutdown; +use std::sync::Arc; + +use futures_core::future::BoxFuture; +use hashbrown::HashMap; + +use crate::connection::{Connect, Connection}; +use crate::error::Error; +use crate::executor::Executor; +use crate::ext::ustr::UStr; +use crate::mysql::protocol::text::{Ping, Quit}; +use crate::mysql::row::MySqlColumn; +use crate::mysql::{MySql, MySqlConnectOptions}; + +mod auth; +mod establish; +mod executor; +mod stream; +mod tls; + +pub(crate) use stream::MySqlStream; + +const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224; + +const MAX_PACKET_SIZE: u32 = 1024; + +/// A connection to a MySQL database. +pub struct MySqlConnection { + // underlying TCP stream, + // wrapped in a potentially TLS stream, + // wrapped in a buffered stream + stream: MySqlStream, + + // cache by query string to the statement id + cache_statement: HashMap, + + // working memory for the active row's column information + // this allows us to re-use these allocations unless the user is persisting the + // Row type past a stream iteration (clone-on-write) + scratch_row_columns: Arc>, + scratch_row_column_names: Arc>, +} + +impl Debug for MySqlConnection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("MySqlConnection").finish() + } +} + +impl Connection for MySqlConnection { + type Database = MySql; + + fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { + Box::pin(async move { + self.stream.send_packet(Quit).await?; + self.stream.shutdown(Shutdown::Both)?; + + Ok(()) + }) + } + + fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { + self.stream.wait_until_ready().await?; + self.stream.send_packet(Ping).await?; + self.stream.recv_ok().await?; + + Ok(()) + }) + } +} + +impl Connect for MySqlConnection { + type Options = MySqlConnectOptions; + + #[inline] + fn connect_with(options: &Self::Options) -> BoxFuture<'_, Result> { + Box::pin(async move { + let mut conn = MySqlConnection::establish(options).await?; + + // After the connection is established, we initialize by configuring a few + // connection parameters + + // https://mariadb.com/kb/en/sql-mode/ + + // PIPES_AS_CONCAT - Allows using the pipe character (ASCII 124) as string concatenation operator. + // This means that "A" || "B" can be used in place of CONCAT("A", "B"). + + // NO_ENGINE_SUBSTITUTION - If not set, if the available storage engine specified by a CREATE TABLE is + // not available, a warning is given and the default storage + // engine is used instead. + + // NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust. + + // NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust. + + // -- + + // Setting the time zone allows us to assume that the output + // from a TIMESTAMP field is UTC + + // -- + + // https://mathiasbynens.be/notes/mysql-utf8mb4 + + conn.execute(r#" +SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE')); +SET time_zone = '+00:00'; +SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci; + "#).await?; + + Ok(conn) + }) + } +} diff --git a/sqlx-core/src/mysql/connection/stream.rs b/sqlx-core/src/mysql/connection/stream.rs new file mode 100644 index 000000000..6f6a2190f --- /dev/null +++ b/sqlx-core/src/mysql/connection/stream.rs @@ -0,0 +1,150 @@ +use std::ops::{Deref, DerefMut}; + +use bytes::{Buf, Bytes}; +use sqlx_rt::TcpStream; + +use crate::error::Error; +use crate::io::{BufStream, Decode, Encode}; +use crate::mysql::protocol::response::{EofPacket, ErrPacket, OkPacket}; +use crate::mysql::protocol::{Capabilities, Packet}; +use crate::mysql::{MySqlConnectOptions, MySqlDatabaseError}; +use crate::net::MaybeTlsStream; + +pub struct MySqlStream { + stream: BufStream>, + pub(super) capabilities: Capabilities, + pub(super) sequence_id: u8, + pub(crate) busy: bool, +} + +impl MySqlStream { + pub(super) async fn connect(options: &MySqlConnectOptions) -> Result { + let stream = TcpStream::connect((&*options.host, options.port)).await?; + + let mut capabilities = Capabilities::PROTOCOL_41 + | Capabilities::IGNORE_SPACE + | Capabilities::DEPRECATE_EOF + | Capabilities::FOUND_ROWS + | Capabilities::TRANSACTIONS + | Capabilities::SECURE_CONNECTION + | Capabilities::PLUGIN_AUTH_LENENC_DATA + | Capabilities::MULTI_STATEMENTS + | Capabilities::MULTI_RESULTS + | Capabilities::PLUGIN_AUTH + | Capabilities::PS_MULTI_RESULTS + | Capabilities::SSL; + + if options.database.is_some() { + capabilities |= Capabilities::CONNECT_WITH_DB; + } + + Ok(Self { + busy: false, + capabilities, + sequence_id: 0, + stream: BufStream::new(MaybeTlsStream::Raw(stream)), + }) + } + + pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { + if self.busy { + loop { + let packet = self.recv_packet().await?; + match packet[0] { + 0xfe if packet.len() < 9 => { + // OK or EOF packet + self.busy = false; + break; + } + + _ => { + // Something else; skip + } + } + } + } + + Ok(()) + } + + pub(crate) async fn send_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> + where + T: Encode<'en, Capabilities>, + { + self.sequence_id = 0; + self.write_packet(payload); + self.flush().await + } + + pub(crate) fn write_packet<'en, T>(&mut self, payload: T) + where + T: Encode<'en, Capabilities>, + { + self.stream + .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)); + } + + // receive the next packet from the database server + // may block (async) on more data from the server + pub(crate) async fn recv_packet(&mut self) -> Result, Error> { + // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html + // https://mariadb.com/kb/en/library/0-packet/#standard-packet + + let mut header: Bytes = self.stream.read(4).await?; + + let packet_size = header.get_uint_le(3) as usize; + let sequence_id = header.get_u8(); + + self.sequence_id = sequence_id.wrapping_add(1); + + let payload: Bytes = self.stream.read(packet_size).await?; + + // TODO: packet compression + // TODO: packet joining + + if payload[0] == 0xff { + self.busy = false; + + // instead of letting this packet be looked at everywhere, we check here + // and emit a proper Error + return Err( + MySqlDatabaseError(ErrPacket::decode_with(payload, self.capabilities)?).into(), + ); + } + + Ok(Packet(payload)) + } + + pub(crate) async fn recv<'de, T>(&mut self) -> Result + where + T: Decode<'de, Capabilities>, + { + self.recv_packet().await?.decode_with(self.capabilities) + } + + pub(crate) async fn recv_ok(&mut self) -> Result { + self.recv_packet().await?.ok() + } + + pub(crate) async fn maybe_recv_eof(&mut self) -> Result, Error> { + if self.capabilities.contains(Capabilities::DEPRECATE_EOF) { + Ok(None) + } else { + self.recv().await.map(Some) + } + } +} + +impl Deref for MySqlStream { + type Target = BufStream>; + + fn deref(&self) -> &Self::Target { + &self.stream + } +} + +impl DerefMut for MySqlStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.stream + } +} diff --git a/sqlx-core/src/mysql/connection/tls.rs b/sqlx-core/src/mysql/connection/tls.rs new file mode 100644 index 000000000..256944328 --- /dev/null +++ b/sqlx-core/src/mysql/connection/tls.rs @@ -0,0 +1,79 @@ +use sqlx_rt::{ + fs, + native_tls::{Certificate, TlsConnector}, +}; + +use crate::error::Error; +use crate::mysql::connection::MySqlStream; +use crate::mysql::protocol::connect::SslRequest; +use crate::mysql::protocol::Capabilities; +use crate::mysql::{MySqlConnectOptions, MySqlSslMode}; + +pub(super) async fn maybe_upgrade( + stream: &mut MySqlStream, + options: &MySqlConnectOptions, +) -> Result<(), Error> { + // https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS + match options.ssl_mode { + MySqlSslMode::Disabled => {} + + MySqlSslMode::Preferred => { + // try upgrade, but its okay if we fail + upgrade(stream, options).await?; + } + + MySqlSslMode::Required | MySqlSslMode::VerifyIdentity | MySqlSslMode::VerifyCa => { + if !upgrade(stream, options).await? { + // upgrade failed, die + return Err(Error::Tls("server does not support TLS".into())); + } + } + } + + Ok(()) +} + +async fn upgrade(stream: &mut MySqlStream, options: &MySqlConnectOptions) -> Result { + if !stream.capabilities.contains(Capabilities::SSL) { + // server does not support TLS + return Ok(false); + } + + stream.write_packet(SslRequest { + max_packet_size: super::MAX_PACKET_SIZE, + char_set: super::COLLATE_UTF8MB4_UNICODE_CI, + }); + + stream.flush().await?; + + // FIXME: de-duplicate with postgres/connection/tls.rs + + let accept_invalid_certs = !matches!( + options.ssl_mode, + MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity + ); + + let mut builder = TlsConnector::builder(); + builder + .danger_accept_invalid_certs(accept_invalid_certs) + .danger_accept_invalid_hostnames(!matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity)); + + if !accept_invalid_certs { + if let Some(ca) = &options.ssl_ca { + let data = fs::read(ca).await?; + let cert = Certificate::from_pem(&data).map_err(Error::tls)?; + + builder.add_root_certificate(cert); + } + } + + #[cfg(not(feature = "runtime-async-std"))] + let connector = builder.build().map_err(Error::tls)?; + + #[cfg(feature = "runtime-async-std")] + let connector = builder; + + stream.upgrade(&options.host, connector.into()).await?; + + Ok(true) +} diff --git a/sqlx-core/src/mysql/cursor.rs b/sqlx-core/src/mysql/cursor.rs deleted file mode 100644 index 250738f13..000000000 --- a/sqlx-core/src/mysql/cursor.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use futures_core::future::BoxFuture; - -use crate::connection::ConnectionSource; -use crate::cursor::Cursor; -use crate::executor::Execute; -use crate::mysql::protocol::{ColumnCount, ColumnDefinition, Row, Status}; -use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo}; -use crate::pool::Pool; - -pub struct MySqlCursor<'c, 'q> { - source: ConnectionSource<'c, MySqlConnection>, - query: Option<(&'q str, Option)>, - column_names: Arc, u16>>, - column_types: Vec, - binary: bool, -} - -impl crate::cursor::private::Sealed for MySqlCursor<'_, '_> {} - -impl<'c, 'q> Cursor<'c, 'q> for MySqlCursor<'c, 'q> { - type Database = MySql; - - #[doc(hidden)] - fn from_pool(pool: &Pool, query: E) -> Self - where - Self: Sized, - E: Execute<'q, MySql>, - { - Self { - source: ConnectionSource::Pool(pool.clone()), - column_names: Arc::default(), - column_types: Vec::new(), - binary: true, - query: Some(query.into_parts()), - } - } - - #[doc(hidden)] - fn from_connection(conn: &'c mut MySqlConnection, query: E) -> Self - where - Self: Sized, - E: Execute<'q, MySql>, - { - Self { - source: ConnectionSource::ConnectionRef(conn), - column_names: Arc::default(), - column_types: Vec::new(), - binary: true, - query: Some(query.into_parts()), - } - } - - fn next(&mut self) -> BoxFuture>>> { - Box::pin(next(self)) - } -} - -async fn next<'a, 'c: 'a, 'q: 'a>( - cursor: &'a mut MySqlCursor<'c, 'q>, -) -> crate::Result>> { - let mut conn = cursor.source.resolve().await?; - - // The first time [next] is called we need to actually execute our - // contained query. We guard against this happening on _all_ next calls - // by using [Option::take] which replaces the potential value in the Option with `None - let mut initial = if let Some((query, arguments)) = cursor.query.take() { - let statement = conn.run(query, arguments).await?; - - // No statement ID = TEXT mode - cursor.binary = statement.is_some(); - - true - } else { - false - }; - - loop { - let packet_id = conn.stream.receive().await?[0]; - - match packet_id { - // OK or EOF packet - 0x00 | 0xFE - if conn.stream.packet().len() < 0xFF_FF_FF && (packet_id != 0x00 || initial) => - { - let status = if let Some(eof) = conn.stream.maybe_handle_eof()? { - eof.status - } else { - conn.stream.handle_ok()?.status - }; - - if status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { - // There is more to this query - initial = true; - } else { - conn.is_ready = true; - return Ok(None); - } - } - - // ERR packet - 0xFF => { - conn.is_ready = true; - return conn.stream.handle_err(); - } - - _ if initial => { - // At the start of the results we expect to see a - // COLUMN_COUNT followed by N COLUMN_DEF - - let cc = ColumnCount::read(conn.stream.packet())?; - - // We use these definitions to get the actual column types that is critical - // in parsing the rows coming back soon - - cursor.column_types.clear(); - cursor.column_types.reserve(cc.columns as usize); - - let mut column_names = HashMap::with_capacity(cc.columns as usize); - - for i in 0..cc.columns { - let column = ColumnDefinition::read(conn.stream.receive().await?)?; - - cursor - .column_types - .push(MySqlTypeInfo::from_nullable_column_def(&column)); - - if let Some(name) = column.name() { - column_names.insert(name.to_owned().into_boxed_str(), i as u16); - } - } - - if cc.columns > 0 { - conn.stream.maybe_receive_eof().await?; - } - - cursor.column_names = Arc::new(column_names); - initial = false; - } - - _ if !cursor.binary || packet_id == 0x00 => { - let row = Row::read( - conn.stream.packet(), - &cursor.column_types, - &mut conn.current_row_values, - cursor.binary, - )?; - - let row = MySqlRow { - row, - names: Arc::clone(&cursor.column_names), - }; - - return Ok(Some(row)); - } - - _ => { - return conn.stream.handle_unexpected(); - } - } - } -} diff --git a/sqlx-core/src/mysql/database.rs b/sqlx-core/src/mysql/database.rs index e75d77328..f6cc64f0c 100644 --- a/sqlx-core/src/mysql/database.rs +++ b/sqlx-core/src/mysql/database.rs @@ -1,41 +1,31 @@ -use crate::cursor::HasCursor; -use crate::database::Database; -use crate::mysql::error::MySqlError; -use crate::row::HasRow; -use crate::value::HasRawValue; +use crate::database::{Database, HasArguments, HasValueRef}; +use crate::mysql::value::{MySqlValue, MySqlValueRef}; +use crate::mysql::{MySqlArguments, MySqlConnection, MySqlRow, MySqlTypeInfo}; -/// **MySQL** database driver. +/// PostgreSQL database driver. #[derive(Debug)] pub struct MySql; impl Database for MySql { - type Connection = super::MySqlConnection; + type Connection = MySqlConnection; - type Arguments = super::MySqlArguments; + type Row = MySqlRow; - type TypeInfo = super::MySqlTypeInfo; + type TypeInfo = MySqlTypeInfo; - type TableId = Box; - - type RawBuffer = Vec; - - type Error = MySqlError; + type Value = MySqlValue; } -impl<'c> HasRow<'c> for MySql { +impl<'r> HasValueRef<'r> for MySql { type Database = MySql; - type Row = super::MySqlRow<'c>; + type ValueRef = MySqlValueRef<'r>; } -impl<'c, 'q> HasCursor<'c, 'q> for MySql { +impl HasArguments<'_> for MySql { type Database = MySql; - type Cursor = super::MySqlCursor<'c, 'q>; -} + type Arguments = MySqlArguments; -impl<'c> HasRawValue<'c> for MySql { - type Database = MySql; - - type RawValue = super::MySqlValue<'c>; + type ArgumentBuffer = Vec; } diff --git a/sqlx-core/src/mysql/error.rs b/sqlx-core/src/mysql/error.rs index 63b714240..ae530f1c4 100644 --- a/sqlx-core/src/mysql/error.rs +++ b/sqlx-core/src/mysql/error.rs @@ -1,63 +1,64 @@ -use std::error::Error as StdError; -use std::fmt::{self, Display}; +use std::error::Error; +use std::fmt::{self, Debug, Display, Formatter}; use crate::error::DatabaseError; -use crate::mysql::protocol::ErrPacket; +use crate::mysql::protocol::response::ErrPacket; +use smallvec::alloc::borrow::Cow; -#[derive(Debug)] -pub struct MySqlError(pub(super) ErrPacket); +/// An error returned from the MySQL database. +pub struct MySqlDatabaseError(pub(super) ErrPacket); -impl Display for MySqlError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.pad(self.message()) - } -} - -impl DatabaseError for MySqlError { - fn message(&self) -> &str { - &*self.0.error_message - } - - fn code(&self) -> Option<&str> { +impl MySqlDatabaseError { + /// The [SQLSTATE](https://dev.mysql.com/doc/refman/8.0/en/server-error-reference.html) code for this error. + pub fn code(&self) -> Option<&str> { self.0.sql_state.as_deref() } - fn as_ref_err(&self) -> &(dyn StdError + Send + Sync + 'static) { - self + /// The [number](https://dev.mysql.com/doc/refman/8.0/en/server-error-reference.html) + /// for this error. + /// + /// MySQL tends to use SQLSTATE as a general error category, and the error number as a more + /// granular indication of the error. + pub fn number(&self) -> u16 { + self.0.error_code } - fn as_mut_err(&mut self) -> &mut (dyn StdError + Send + Sync + 'static) { - self - } - - fn into_box_err(self: Box) -> Box { - self + /// The human-readable error message. + pub fn message(&self) -> &str { + &self.0.error_message } } -impl StdError for MySqlError {} - -impl From for crate::Error { - fn from(err: MySqlError) -> Self { - crate::Error::Database(Box::new(err)) +impl Debug for MySqlDatabaseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("MySqlDatabaseError") + .field("code", &self.code()) + .field("number", &self.number()) + .field("message", &self.message()) + .finish() } } -#[test] -fn test_error_downcasting() { - let error = MySqlError(ErrPacket { - error_code: 0xABCD, - sql_state: None, - error_message: "".into(), - }); - - let error = crate::Error::from(error); - - let db_err = match error { - crate::Error::Database(db_err) => db_err, - e => panic!("expected crate::Error::Database, got {:?}", e), - }; - - assert_eq!(db_err.downcast_ref::().0.error_code, 0xABCD); - assert_eq!(db_err.downcast::().0.error_code, 0xABCD); +impl Display for MySqlDatabaseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if let Some(code) = &self.code() { + write!(f, "{} ({}): {}", self.number(), code, self.message()) + } else { + write!(f, "{}: {}", self.number(), self.message()) + } + } +} + +impl Error for MySqlDatabaseError {} + +impl DatabaseError for MySqlDatabaseError { + #[inline] + fn message(&self) -> &str { + self.message() + } + + #[inline] + fn code(&self) -> Option> { + self.code().map(Cow::Borrowed) + } } diff --git a/sqlx-core/src/mysql/executor.rs b/sqlx-core/src/mysql/executor.rs deleted file mode 100644 index 024e6ea73..000000000 --- a/sqlx-core/src/mysql/executor.rs +++ /dev/null @@ -1,223 +0,0 @@ -use futures_core::future::BoxFuture; - -use crate::cursor::Cursor; -use crate::describe::{Column, Describe}; -use crate::executor::{Execute, Executor, RefExecutor}; -use crate::mysql::protocol::{ - self, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare, ComStmtPrepareOk, FieldFlags, - Status, -}; -use crate::mysql::{MySql, MySqlArguments, MySqlCursor, MySqlTypeInfo}; - -impl super::MySqlConnection { - // Creates a prepared statement for the passed query string - async fn prepare(&mut self, query: &str) -> crate::Result { - // https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_prepare.html - self.stream.send(ComStmtPrepare { query }, true).await?; - - // Should receive a COM_STMT_PREPARE_OK or ERR_PACKET - let packet = self.stream.receive().await?; - - if packet[0] == 0xFF { - return self.stream.handle_err(); - } - - ComStmtPrepareOk::read(packet) - } - - async fn drop_column_defs(&mut self, count: usize) -> crate::Result<()> { - for _ in 0..count { - let _column = ColumnDefinition::read(self.stream.receive().await?)?; - } - - if count > 0 { - self.stream.maybe_receive_eof().await?; - } - - Ok(()) - } - - // Gets a cached prepared statement ID _or_ prepares the statement if not in the cache - // At the end we should have [cache_statement] and [cache_statement_columns] filled - async fn get_or_prepare(&mut self, query: &str) -> crate::Result { - if let Some(&id) = self.cache_statement.get(query) { - Ok(id) - } else { - let stmt = self.prepare(query).await?; - - self.cache_statement.insert(query.into(), stmt.statement_id); - - // COM_STMT_PREPARE returns the input columns - // We make no use of that data, so cycle through and drop them - self.drop_column_defs(stmt.params as usize).await?; - - // COM_STMT_PREPARE next returns the output columns - // We just drop these as we get these when we execute the query - self.drop_column_defs(stmt.columns as usize).await?; - - Ok(stmt.statement_id) - } - } - - pub(crate) async fn run( - &mut self, - query: &str, - arguments: Option, - ) -> crate::Result> { - self.stream.wait_until_ready().await?; - self.stream.is_ready = false; - - if let Some(arguments) = arguments { - let statement_id = self.get_or_prepare(query).await?; - - // https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_stmt_execute.html - self.stream - .send( - ComStmtExecute { - cursor: protocol::Cursor::NO_CURSOR, - statement_id, - params: &arguments.params, - null_bitmap: &arguments.null_bitmap, - param_types: &arguments.param_types, - }, - true, - ) - .await?; - - Ok(Some(statement_id)) - } else { - // https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_com_query.html - self.stream.send(ComQuery { query }, true).await?; - - Ok(None) - } - } - - async fn affected_rows(&mut self) -> crate::Result { - let mut rows = 0; - - loop { - let id = self.stream.receive().await?[0]; - - match id { - 0x00 | 0xFE if self.stream.packet().len() < 0xFF_FF_FF => { - // ResultSet row can begin with 0xfe byte (when using text protocol - // with a field length > 0xffffff) - - let status = if let Some(eof) = self.stream.maybe_handle_eof()? { - eof.status - } else { - let ok = self.stream.handle_ok()?; - - rows += ok.affected_rows; - ok.status - }; - - if !status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { - self.is_ready = true; - break; - } - } - - 0xFF => { - return self.stream.handle_err(); - } - - _ => {} - } - } - - Ok(rows) - } - - // method is not named describe to work around an intellijrust bug - // otherwise it marks someone trying to describe the connection as "method is private" - async fn do_describe(&mut self, query: &str) -> crate::Result> { - self.stream.wait_until_ready().await?; - - let stmt = self.prepare(query).await?; - - let mut param_types = Vec::with_capacity(stmt.params as usize); - let mut result_columns = Vec::with_capacity(stmt.columns as usize); - - for _ in 0..stmt.params { - let param = ColumnDefinition::read(self.stream.receive().await?)?; - param_types.push(MySqlTypeInfo::from_column_def(¶m)); - } - - if stmt.params > 0 { - self.stream.maybe_receive_eof().await?; - } - - for _ in 0..stmt.columns { - let column = ColumnDefinition::read(self.stream.receive().await?)?; - - result_columns.push(Column:: { - type_info: MySqlTypeInfo::from_column_def(&column), - name: column.column_alias.or(column.column), - table_id: column.table_alias.or(column.table), - // TODO(@abonander): Should this be None in some cases? - non_null: Some(column.flags.contains(FieldFlags::NOT_NULL)), - }); - } - - if stmt.columns > 0 { - self.stream.maybe_receive_eof().await?; - } - - Ok(Describe { - param_types: param_types.into_boxed_slice(), - result_columns: result_columns.into_boxed_slice(), - }) - } -} - -impl Executor for super::MySqlConnection { - type Database = MySql; - - fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>( - &'c mut self, - query: E, - ) -> BoxFuture<'e, crate::Result> - where - E: Execute<'q, Self::Database>, - { - log_execution!(query, { - Box::pin(async move { - let (query, arguments) = query.into_parts(); - - self.run(query, arguments).await?; - self.affected_rows().await - }) - }) - } - - fn fetch<'q, E>(&mut self, query: E) -> MySqlCursor<'_, 'q> - where - E: Execute<'q, Self::Database>, - { - log_execution!(query, { MySqlCursor::from_connection(self, query) }) - } - - #[doc(hidden)] - fn describe<'e, 'q, E: 'e>( - &'e mut self, - query: E, - ) -> BoxFuture<'e, crate::Result>> - where - E: Execute<'q, Self::Database>, - { - Box::pin(async move { self.do_describe(query.into_parts().0).await }) - } -} - -impl<'c> RefExecutor<'c> for &'c mut super::MySqlConnection { - type Database = MySql; - - fn fetch_by_ref<'q, E>(self, query: E) -> MySqlCursor<'c, 'q> - where - E: Execute<'q, Self::Database>, - { - log_execution!(query, { MySqlCursor::from_connection(self, query) }) - } -} diff --git a/sqlx-core/src/mysql/io/buf.rs b/sqlx-core/src/mysql/io/buf.rs new file mode 100644 index 000000000..d533c088e --- /dev/null +++ b/sqlx-core/src/mysql/io/buf.rs @@ -0,0 +1,40 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::BufExt; + +pub trait MySqlBufExt: Buf { + // Read a length-encoded integer. + // NOTE: 0xfb or NULL is only returned for binary value encoding to indicate NULL. + // NOTE: 0xff is only returned during a result set to indicate ERR. + // + fn get_uint_lenenc(&mut self) -> u64; + + // Read a length-encoded string. + fn get_str_lenenc(&mut self) -> Result; + + // Read a length-encoded byte sequence. + fn get_bytes_lenenc(&mut self) -> Bytes; +} + +impl MySqlBufExt for Bytes { + fn get_uint_lenenc(&mut self) -> u64 { + match self.get_u8() { + 0xfc => u64::from(self.get_u16_le()), + 0xfd => u64::from(self.get_uint_le(3)), + 0xfe => u64::from(self.get_u64_le()), + + v => u64::from(v), + } + } + + fn get_str_lenenc(&mut self) -> Result { + let size = self.get_uint_lenenc(); + self.get_str(size as usize) + } + + fn get_bytes_lenenc(&mut self) -> Bytes { + let size = self.get_uint_lenenc(); + self.split_to(size as usize) + } +} diff --git a/sqlx-core/src/mysql/io/buf_ext.rs b/sqlx-core/src/mysql/io/buf_ext.rs deleted file mode 100644 index fc01624bf..000000000 --- a/sqlx-core/src/mysql/io/buf_ext.rs +++ /dev/null @@ -1,38 +0,0 @@ -use std::io; - -use byteorder::ByteOrder; - -use crate::io::Buf; - -pub trait BufExt { - fn get_uint_lenenc(&mut self) -> io::Result>; - - fn get_str_lenenc(&mut self) -> io::Result>; - - fn get_bytes_lenenc(&mut self) -> io::Result>; -} - -impl BufExt for &'_ [u8] { - fn get_uint_lenenc(&mut self) -> io::Result> { - Ok(match self.get_u8()? { - 0xFB => None, - 0xFC => Some(u64::from(self.get_u16::()?)), - 0xFD => Some(u64::from(self.get_u24::()?)), - 0xFE => Some(self.get_u64::()?), - - value => Some(u64::from(value)), - }) - } - - fn get_str_lenenc(&mut self) -> io::Result> { - self.get_uint_lenenc::()? - .map(move |len| self.get_str(len as usize)) - .transpose() - } - - fn get_bytes_lenenc(&mut self) -> io::Result> { - self.get_uint_lenenc::()? - .map(move |len| self.get_bytes(len as usize)) - .transpose() - } -} diff --git a/sqlx-core/src/mysql/io/buf_mut.rs b/sqlx-core/src/mysql/io/buf_mut.rs new file mode 100644 index 000000000..5b59c8563 --- /dev/null +++ b/sqlx-core/src/mysql/io/buf_mut.rs @@ -0,0 +1,134 @@ +use bytes::BufMut; + +pub trait MySqlBufMutExt: BufMut { + fn put_uint_lenenc(&mut self, v: u64); + + fn put_str_lenenc(&mut self, v: &str); + + fn put_bytes_lenenc(&mut self, v: &[u8]); +} + +impl MySqlBufMutExt for Vec { + fn put_uint_lenenc(&mut self, v: u64) { + // https://dev.mysql.com/doc/internals/en/integer.html + // https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers + + if v < 251 { + self.push(v as u8); + } else if v < 0x1_00_00 { + self.push(0xfc); + self.extend(&(v as u16).to_le_bytes()); + } else if v < 0x1_00_00_00 { + self.push(0xfd); + self.extend(&(v as u32).to_le_bytes()[..3]); + } else { + self.push(0xfe); + self.extend(&v.to_le_bytes()); + } + } + + fn put_str_lenenc(&mut self, v: &str) { + self.put_bytes_lenenc(v.as_bytes()); + } + + fn put_bytes_lenenc(&mut self, v: &[u8]) { + self.put_uint_lenenc(v.len() as u64); + self.extend(v); + } +} + +#[test] +fn test_encodes_int_lenenc_u8() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFA as u64); + + assert_eq!(&buf[..], b"\xFA"); +} + +#[test] +fn test_encodes_int_lenenc_u16() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(std::u16::MAX as u64); + + assert_eq!(&buf[..], b"\xFC\xFF\xFF"); +} + +#[test] +fn test_encodes_int_lenenc_u24() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFF_FF_FF as u64); + + assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF"); +} + +#[test] +fn test_encodes_int_lenenc_u64() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(std::u64::MAX); + + assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); +} + +#[test] +fn test_encodes_int_lenenc_fb() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFB as u64); + + assert_eq!(&buf[..], b"\xFC\xFB\x00"); +} + +#[test] +fn test_encodes_int_lenenc_fc() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFC as u64); + + assert_eq!(&buf[..], b"\xFC\xFC\x00"); +} + +#[test] +fn test_encodes_int_lenenc_fd() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFD as u64); + + assert_eq!(&buf[..], b"\xFC\xFD\x00"); +} + +#[test] +fn test_encodes_int_lenenc_fe() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFE as u64); + + assert_eq!(&buf[..], b"\xFC\xFE\x00"); +} + +#[test] +fn test_encodes_int_lenenc_ff() { + let mut buf = Vec::with_capacity(1024); + buf.put_uint_lenenc(0xFF as u64); + + assert_eq!(&buf[..], b"\xFC\xFF\x00"); +} + +#[test] +fn test_encodes_string_lenenc() { + let mut buf = Vec::with_capacity(1024); + buf.put_str_lenenc("random_string"); + + assert_eq!(&buf[..], b"\x0Drandom_string"); +} + +#[test] +fn test_encodes_string_null() { + let mut buf = Vec::with_capacity(1024); + buf.put_str_nul("random_string"); + + assert_eq!(&buf[..], b"random_string\0"); +} + +#[test] +fn test_encodes_byte_lenenc() { + let mut buf = Vec::with_capacity(1024); + buf.put_bytes_lenenc(b"random_string"); + + assert_eq!(&buf[..], b"\x0Drandom_string"); +} diff --git a/sqlx-core/src/mysql/io/buf_mut_ext.rs b/sqlx-core/src/mysql/io/buf_mut_ext.rs deleted file mode 100644 index 332c6f8cc..000000000 --- a/sqlx-core/src/mysql/io/buf_mut_ext.rs +++ /dev/null @@ -1,217 +0,0 @@ -use std::{u16, u32, u64, u8}; - -use byteorder::ByteOrder; - -use crate::io::BufMut; - -pub trait BufMutExt { - fn put_uint_lenenc>>(&mut self, val: U); - - fn put_str_lenenc(&mut self, val: &str); - - fn put_bytes_lenenc(&mut self, val: &[u8]); -} - -impl BufMutExt for Vec { - fn put_uint_lenenc>>(&mut self, value: U) { - if let Some(value) = value.into() { - // https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers - if value > 0xFF_FF_FF { - // Integer value is encoded in the next 8 bytes (9 bytes total) - self.push(0xFE); - self.put_u64::(value); - } else if value > u64::from(u16::MAX) { - // Integer value is encoded in the next 3 bytes (4 bytes total) - self.push(0xFD); - self.put_u24::(value as u32); - } else if value > u64::from(u8::MAX) { - // Integer value is encoded in the next 2 bytes (3 bytes total) - self.push(0xFC); - self.put_u16::(value as u16); - } else { - match value { - // If the value is of size u8 and one of the key bytes used in length encoding - // we must put that single byte as a u16 - 0xFB | 0xFC | 0xFD | 0xFE | 0xFF => { - self.push(0xFC); - self.put_u16::(value as u16); - } - - _ => { - self.push(value as u8); - } - } - } - } else { - self.push(0xFB); - } - } - - fn put_str_lenenc(&mut self, val: &str) { - self.put_uint_lenenc::(val.len() as u64); - self.extend_from_slice(val.as_bytes()); - } - - fn put_bytes_lenenc(&mut self, val: &[u8]) { - self.put_uint_lenenc::(val.len() as u64); - self.extend_from_slice(val); - } -} - -#[cfg(test)] -mod tests { - use super::{BufMut, BufMutExt}; - use byteorder::LittleEndian; - - #[test] - fn it_encodes_int_lenenc_none() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(None); - - assert_eq!(&buf[..], b"\xFB"); - } - - #[test] - fn it_encodes_int_lenenc_u8() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(0xFA as u64); - - assert_eq!(&buf[..], b"\xFA"); - } - - #[test] - fn it_encodes_int_lenenc_u16() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(std::u16::MAX as u64); - - assert_eq!(&buf[..], b"\xFC\xFF\xFF"); - } - - #[test] - fn it_encodes_int_lenenc_u24() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(0xFF_FF_FF as u64); - - assert_eq!(&buf[..], b"\xFD\xFF\xFF\xFF"); - } - - #[test] - fn it_encodes_int_lenenc_u64() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(std::u64::MAX); - - assert_eq!(&buf[..], b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); - } - - #[test] - fn it_encodes_int_lenenc_fb() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(0xFB as u64); - - assert_eq!(&buf[..], b"\xFC\xFB\x00"); - } - - #[test] - fn it_encodes_int_lenenc_fc() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(0xFC as u64); - - assert_eq!(&buf[..], b"\xFC\xFC\x00"); - } - - #[test] - fn it_encodes_int_lenenc_fd() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(0xFD as u64); - - assert_eq!(&buf[..], b"\xFC\xFD\x00"); - } - - #[test] - fn it_encodes_int_lenenc_fe() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(0xFE as u64); - - assert_eq!(&buf[..], b"\xFC\xFE\x00"); - } - - #[test] - fn it_encodes_int_lenenc_ff() { - let mut buf = Vec::with_capacity(1024); - buf.put_uint_lenenc::(0xFF as u64); - - assert_eq!(&buf[..], b"\xFC\xFF\x00"); - } - - #[test] - fn it_encodes_int_u64() { - let mut buf = Vec::with_capacity(1024); - buf.put_u64::(std::u64::MAX); - - assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"); - } - - #[test] - fn it_encodes_int_u32() { - let mut buf = Vec::with_capacity(1024); - buf.put_u32::(std::u32::MAX); - - assert_eq!(&buf[..], b"\xFF\xFF\xFF\xFF"); - } - - #[test] - fn it_encodes_int_u24() { - let mut buf = Vec::with_capacity(1024); - buf.put_u24::(0xFF_FF_FF as u32); - - assert_eq!(&buf[..], b"\xFF\xFF\xFF"); - } - - #[test] - fn it_encodes_int_u16() { - let mut buf = Vec::with_capacity(1024); - buf.put_u16::(std::u16::MAX); - - assert_eq!(&buf[..], b"\xFF\xFF"); - } - - #[test] - fn it_encodes_int_u8() { - let mut buf = Vec::with_capacity(1024); - buf.put_u8(std::u8::MAX); - - assert_eq!(&buf[..], b"\xFF"); - } - - #[test] - fn it_encodes_string_lenenc() { - let mut buf = Vec::with_capacity(1024); - buf.put_str_lenenc::("random_string"); - - assert_eq!(&buf[..], b"\x0Drandom_string"); - } - - #[test] - fn it_encodes_string_fix() { - let mut buf = Vec::with_capacity(1024); - buf.put_str("random_string"); - - assert_eq!(&buf[..], b"random_string"); - } - - #[test] - fn it_encodes_string_null() { - let mut buf = Vec::with_capacity(1024); - buf.put_str_nul("random_string"); - - assert_eq!(&buf[..], b"random_string\0"); - } - - #[test] - fn it_encodes_byte_lenenc() { - let mut buf = Vec::with_capacity(1024); - buf.put_bytes_lenenc::(b"random_string"); - - assert_eq!(&buf[..], b"\x0Drandom_string"); - } -} diff --git a/sqlx-core/src/mysql/io/mod.rs b/sqlx-core/src/mysql/io/mod.rs index a8867b156..fafc91454 100644 --- a/sqlx-core/src/mysql/io/mod.rs +++ b/sqlx-core/src/mysql/io/mod.rs @@ -1,5 +1,5 @@ -mod buf_ext; -mod buf_mut_ext; +mod buf; +mod buf_mut; -pub use buf_ext::BufExt; -pub use buf_mut_ext::BufMutExt; +pub use buf::MySqlBufExt; +pub use buf_mut::MySqlBufMutExt; diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index 56368292d..75281d84a 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -1,35 +1,25 @@ -//! **MySQL** database and connection types. - -pub use arguments::MySqlArguments; -pub use connection::MySqlConnection; -pub use cursor::MySqlCursor; -pub use database::MySql; -pub use error::MySqlError; -pub use row::MySqlRow; -pub use type_info::MySqlTypeInfo; -pub use value::{MySqlData, MySqlValue}; +//! **MySQL** database driver. mod arguments; mod connection; -mod cursor; mod database; mod error; -mod executor; mod io; +mod options; mod protocol; mod row; -mod rsa; -mod stream; -mod tls; mod type_info; pub mod types; -mod util; mod value; -/// An alias for [`crate::pool::Pool`], specialized for **MySQL**. -#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))] -pub type MySqlPool = crate::pool::Pool; +pub use arguments::MySqlArguments; +pub use connection::MySqlConnection; +pub use database::MySql; +pub use error::MySqlDatabaseError; +pub use options::{MySqlConnectOptions, MySqlSslMode}; +pub use row::MySqlRow; +pub use type_info::MySqlTypeInfo; +pub use value::{MySqlValue, MySqlValueFormat, MySqlValueRef}; -make_query_as!(MySqlQueryAs, MySql, MySqlRow); -impl_map_row_for_row!(MySql, MySqlRow); -impl_from_row_for_tuples!(MySql, MySqlRow); +/// An alias for [`Pool`][crate::pool::Pool], specialized for MySQL. +pub type MySqlPool = crate::pool::Pool; diff --git a/sqlx-core/src/mysql/options.rs b/sqlx-core/src/mysql/options.rs new file mode 100644 index 000000000..aeed5ec22 --- /dev/null +++ b/sqlx-core/src/mysql/options.rs @@ -0,0 +1,232 @@ +use std::path::{Path, PathBuf}; +use std::str::FromStr; +use url::Url; + +use crate::error::{BoxDynError, Error}; + +/// Options for controlling the desired security state of the connection to the MySQL server. +/// +/// It is used by the [`ssl_mode`](MySqlConnectOptions::ssl_mode) method. +#[derive(Debug, Clone, Copy)] +pub enum MySqlSslMode { + /// Establish an unencrypted connection. + Disabled, + + /// Establish an encrypted connection if the server supports encrypted connections, falling + /// back to an unencrypted connection if an encrypted connection cannot be established. + /// + /// This is the default if `ssl_mode` is not specified. + Preferred, + + /// Establish an encrypted connection if the server supports encrypted connections. + /// The connection attempt fails if an encrypted connection cannot be established. + Required, + + /// Like `Required`, but additionally verify the server Certificate Authority (CA) + /// certificate against the configured CA certificates. The connection attempt fails + /// if no valid matching CA certificates are found. + VerifyCa, + + /// Like `VerifyCa`, but additionally perform host name identity verification by + /// checking the host name the client uses for connecting to the server against the + /// identity in the certificate that the server sends to the client. + VerifyIdentity, +} + +impl Default for MySqlSslMode { + fn default() -> Self { + MySqlSslMode::Preferred + } +} + +impl FromStr for MySqlSslMode { + type Err = Error; + + fn from_str(s: &str) -> Result { + Ok(match s { + "DISABLED" => MySqlSslMode::Disabled, + "PREFERRED" => MySqlSslMode::Preferred, + "REQUIRED" => MySqlSslMode::Required, + "VERIFY_CA" => MySqlSslMode::VerifyCa, + "VERIFY_IDENTITY" => MySqlSslMode::VerifyIdentity, + + _ => { + return Err(err_protocol!("unknown SSL mode value: {:?}", s)); + } + }) + } +} + +/// Options and flags which can be used to configure a MySQL connection. +/// +/// A value of `PgConnectOptions` can be parsed from a connection URI, +/// as described by [MySQL](https://dev.mysql.com/doc/connector-j/8.0/en/connector-j-reference-jdbc-url-format.html). +/// +/// The generic format of the connection URL: +/// +/// ```text +/// mysql://[host][/database][?properties] +/// ``` +/// +/// # Example +/// +/// ```rust,no_run +/// # use sqlx_core::error::Error; +/// # use sqlx_core::connection::Connect; +/// # use sqlx_core::mysql::{MySqlConnectOptions, MySqlConnection, MySqlSslMode}; +/// # +/// # #[sqlx_rt::main] +/// # async fn main() -> Result<(), Error> { +/// // URI connection string +/// let conn = MySqlConnection::connect("mysql://root:password@localhost/db").await?; +/// +/// // Manually-constructed options +/// let conn = MySqlConnection::connect_with(&MySqlConnectOptions::new() +/// .host("localhost") +/// .username("root") +/// .password("password") +/// .database("db") +/// ).await?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct MySqlConnectOptions { + pub(crate) host: String, + pub(crate) port: u16, + pub(crate) username: String, + pub(crate) password: Option, + pub(crate) database: Option, + pub(crate) ssl_mode: MySqlSslMode, + pub(crate) ssl_ca: Option, +} + +impl MySqlConnectOptions { + /// Creates a new, default set of options ready for configuration + pub fn new() -> Self { + Self { + port: 3306, + host: String::from("localhost"), + username: String::from("root"), + password: None, + database: None, + ssl_mode: MySqlSslMode::Preferred, + ssl_ca: None, + } + } + + /// Sets the name of the host to connect to. + /// + /// The default behavior when the host is not specified, + /// is to connect to localhost. + pub fn host(mut self, host: &str) -> Self { + self.host = host.to_owned(); + self + } + + /// Sets the port to connect to at the server host. + /// + /// The default port for MySQL is `3306`. + pub fn port(mut self, port: u16) -> Self { + self.port = port; + self + } + + /// Sets the username to connect as. + pub fn username(mut self, username: &str) -> Self { + self.username = username.to_owned(); + self + } + + /// Sets the password to connect with. + pub fn password(mut self, password: &str) -> Self { + self.password = Some(password.to_owned()); + self + } + + /// Sets the database name. + pub fn database(mut self, database: &str) -> Self { + self.database = Some(database.to_owned()); + self + } + + /// Sets whether or with what priority a secure SSL TCP/IP connection will be negotiated + /// with the server. + /// + /// By default, the SSL mode is [`Preferred`](MySqlSslMode::Preferred), and the client will + /// first attempt an SSL connection but fallback to a non-SSL connection on failure. + /// + /// # Example + /// + /// ```rust + /// # use sqlx_core::mysql::{MySqlSslMode, MySqlConnectOptions}; + /// let options = MySqlConnectOptions::new() + /// .ssl_mode(MySqlSslMode::Required); + /// ``` + pub fn ssl_mode(mut self, mode: MySqlSslMode) -> Self { + self.ssl_mode = mode; + self + } + + /// Sets the name of a file containing a list of trusted SSL Certificate Authorities. + /// + /// # Example + /// + /// ```rust + /// # use sqlx_core::mysql::{MySqlSslMode, MySqlConnectOptions}; + /// let options = MySqlConnectOptions::new() + /// .ssl_mode(MySqlSslMode::VerifyCa) + /// .ssl_ca("path/to/ca.crt"); + /// ``` + pub fn ssl_ca(mut self, file_name: impl AsRef) -> Self { + self.ssl_ca = Some(file_name.as_ref().to_owned()); + self + } +} + +impl FromStr for MySqlConnectOptions { + type Err = BoxDynError; + + fn from_str(s: &str) -> Result { + let url: Url = s.parse()?; + let mut options = Self::new(); + + if let Some(host) = url.host_str() { + options = options.host(host); + } + + if let Some(port) = url.port() { + options = options.port(port); + } + + let username = url.username(); + if !username.is_empty() { + options = options.username(username); + } + + if let Some(password) = url.password() { + options = options.password(password); + } + + let path = url.path().trim_start_matches('/'); + if !path.is_empty() { + options = options.database(path); + } + + for (key, value) in url.query_pairs().into_iter() { + match &*key { + "ssl-mode" => { + options = options.ssl_mode(value.parse()?); + } + + "ssl-ca" => { + options = options.ssl_ca(&*value); + } + + _ => {} + } + } + + Ok(options) + } +} diff --git a/sqlx-core/src/mysql/protocol/auth.rs b/sqlx-core/src/mysql/protocol/auth.rs new file mode 100644 index 000000000..261a8817a --- /dev/null +++ b/sqlx-core/src/mysql/protocol/auth.rs @@ -0,0 +1,34 @@ +use std::str::FromStr; + +use crate::error::Error; + +#[derive(Debug, Copy, Clone)] +pub enum AuthPlugin { + MySqlNativePassword, + CachingSha2Password, + Sha256Password, +} + +impl AuthPlugin { + pub(crate) fn name(self) -> &'static str { + match self { + AuthPlugin::MySqlNativePassword => "mysql_native_password", + AuthPlugin::CachingSha2Password => "caching_sha2_password", + AuthPlugin::Sha256Password => "sha256_password", + } + } +} + +impl FromStr for AuthPlugin { + type Err = Error; + + fn from_str(s: &str) -> Result { + match s { + "mysql_native_password" => Ok(AuthPlugin::MySqlNativePassword), + "caching_sha2_password" => Ok(AuthPlugin::CachingSha2Password), + "sha256_password" => Ok(AuthPlugin::Sha256Password), + + _ => Err(err_protocol!("unknown authentication plugin: {}", s)), + } + } +} diff --git a/sqlx-core/src/mysql/protocol/auth_plugin.rs b/sqlx-core/src/mysql/protocol/auth_plugin.rs deleted file mode 100644 index 2ffaa8af0..000000000 --- a/sqlx-core/src/mysql/protocol/auth_plugin.rs +++ /dev/null @@ -1,103 +0,0 @@ -use digest::{Digest, FixedOutput}; -use generic_array::GenericArray; -use memchr::memchr; -use sha1::Sha1; -use sha2::Sha256; - -use crate::mysql::util::xor_eq; - -#[derive(Debug, PartialEq)] -pub enum AuthPlugin { - MySqlNativePassword, - CachingSha2Password, - Sha256Password, -} - -impl AuthPlugin { - pub(crate) fn from_opt_str(s: Option<&str>) -> crate::Result { - match s { - Some("mysql_native_password") | None => Ok(AuthPlugin::MySqlNativePassword), - Some("caching_sha2_password") => Ok(AuthPlugin::CachingSha2Password), - Some("sha256_password") => Ok(AuthPlugin::Sha256Password), - - Some(s) => { - Err(protocol_err!("requires unimplemented authentication plugin: {}", s).into()) - } - } - } - - pub(crate) fn as_str(&self) -> &'static str { - match self { - AuthPlugin::MySqlNativePassword => "mysql_native_password", - AuthPlugin::CachingSha2Password => "caching_sha2_password", - AuthPlugin::Sha256Password => "sha256_password", - } - } - - pub(crate) fn scramble(&self, password: &str, nonce: &[u8]) -> Vec { - match self { - AuthPlugin::MySqlNativePassword => { - // The [nonce] for mysql_native_password is (optionally) nul terminated - let end = memchr(b'\0', nonce).unwrap_or(nonce.len()); - - scramble_sha1(password, &nonce[..end]).to_vec() - } - AuthPlugin::CachingSha2Password => scramble_sha256(password, nonce).to_vec(), - - _ => unimplemented!(), - } - } -} - -fn scramble_sha1( - password: &str, - seed: &[u8], -) -> GenericArray::OutputSize> { - // SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) ) - // https://mariadb.com/kb/en/connection/#mysql_native_password-plugin - - let mut ctx = Sha1::new(); - - ctx.input(password); - - let mut pw_hash = ctx.result_reset(); - - ctx.input(&pw_hash); - - let pw_hash_hash = ctx.result_reset(); - - ctx.input(seed); - ctx.input(pw_hash_hash); - - let pw_seed_hash_hash = ctx.result(); - - xor_eq(&mut pw_hash, &pw_seed_hash_hash); - - pw_hash -} - -fn scramble_sha256( - password: &str, - seed: &[u8], -) -> GenericArray::OutputSize> { - // XOR(SHA256(password), SHA256(seed, SHA256(SHA256(password)))) - // https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/#sha-2-encrypted-password - let mut ctx = Sha256::new(); - - ctx.input(password); - - let mut pw_hash = ctx.result_reset(); - - ctx.input(&pw_hash); - - let pw_hash_hash = ctx.result_reset(); - - ctx.input(seed); - ctx.input(pw_hash_hash); - - let pw_seed_hash_hash = ctx.result(); - - xor_eq(&mut pw_hash, &pw_seed_hash_hash); - - pw_hash -} diff --git a/sqlx-core/src/mysql/protocol/auth_switch.rs b/sqlx-core/src/mysql/protocol/auth_switch.rs deleted file mode 100644 index 23e7dcfa7..000000000 --- a/sqlx-core/src/mysql/protocol/auth_switch.rs +++ /dev/null @@ -1,32 +0,0 @@ -use crate::io::Buf; -use crate::mysql::protocol::AuthPlugin; - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html -#[derive(Debug)] -pub(crate) struct AuthSwitch { - pub(crate) auth_plugin: AuthPlugin, - pub(crate) auth_plugin_data: Box<[u8]>, -} - -impl AuthSwitch { - pub(crate) fn read(mut buf: &[u8]) -> crate::Result - where - Self: Sized, - { - let header = buf.get_u8()?; - if header != 0xFE { - return Err(protocol_err!( - "expected AUTH SWITCH (0xFE); received 0x{:X}", - header - ))?; - } - - let auth_plugin = AuthPlugin::from_opt_str(Some(buf.get_str_nul()?))?; - let auth_plugin_data = buf.get_bytes(buf.len())?.to_owned().into_boxed_slice(); - - Ok(Self { - auth_plugin_data, - auth_plugin, - }) - } -} diff --git a/sqlx-core/src/mysql/protocol/column_count.rs b/sqlx-core/src/mysql/protocol/column_count.rs deleted file mode 100644 index 3ed537d49..000000000 --- a/sqlx-core/src/mysql/protocol/column_count.rs +++ /dev/null @@ -1,16 +0,0 @@ -use byteorder::LittleEndian; - -use crate::mysql::io::BufExt; - -#[derive(Debug)] -pub struct ColumnCount { - pub columns: u64, -} - -impl ColumnCount { - pub(crate) fn read(mut buf: &[u8]) -> crate::Result { - let columns = buf.get_uint_lenenc::()?.unwrap_or(0); - - Ok(Self { columns }) - } -} diff --git a/sqlx-core/src/mysql/protocol/column_def.rs b/sqlx-core/src/mysql/protocol/column_def.rs deleted file mode 100644 index e20a9e414..000000000 --- a/sqlx-core/src/mysql/protocol/column_def.rs +++ /dev/null @@ -1,83 +0,0 @@ -use byteorder::LittleEndian; - -use crate::io::Buf; -use crate::mysql::io::BufExt; -use crate::mysql::protocol::{FieldFlags, TypeId}; - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html -// https://mariadb.com/kb/en/resultset/#column-definition-packet -#[derive(Debug)] -pub struct ColumnDefinition { - pub schema: Option>, - - pub table_alias: Option>, - pub table: Option>, - - pub column_alias: Option>, - pub column: Option>, - - pub char_set: u16, - - pub max_size: u32, - - pub type_id: TypeId, - - pub flags: FieldFlags, - - pub decimals: u8, -} - -impl ColumnDefinition { - pub fn name(&self) -> Option<&str> { - self.column_alias.as_deref().or(self.column.as_deref()) - } -} - -impl ColumnDefinition { - pub(crate) fn read(mut buf: &[u8]) -> crate::Result { - // catalog : string - let catalog = buf.get_str_lenenc::()?; - - if catalog != Some("def") { - return Err(protocol_err!( - "expected ColumnDefinition (\"def\"); received {:?}", - catalog - ))?; - } - - let schema = buf.get_str_lenenc::()?.map(Into::into); - let table_alias = buf.get_str_lenenc::()?.map(Into::into); - let table = buf.get_str_lenenc::()?.map(Into::into); - let column_alias = buf.get_str_lenenc::()?.map(Into::into); - let column = buf.get_str_lenenc::()?.map(Into::into); - - let len_fixed_fields = buf.get_uint_lenenc::()?.unwrap_or(0); - - if len_fixed_fields != 0x0c { - return Err(protocol_err!( - "expected ColumnDefinition (0x0c); received {:?}", - len_fixed_fields - ))?; - } - - let char_set = buf.get_u16::()?; - let max_size = buf.get_u32::()?; - - let type_id = buf.get_u8()?; - let flags = buf.get_u16::()?; - let decimals = buf.get_u8()?; - - Ok(Self { - schema, - table, - table_alias, - column, - column_alias, - char_set, - max_size, - type_id: TypeId(type_id), - flags: FieldFlags::from_bits_truncate(flags), - decimals, - }) - } -} diff --git a/sqlx-core/src/mysql/protocol/com_ping.rs b/sqlx-core/src/mysql/protocol/com_ping.rs deleted file mode 100644 index a90ce62ba..000000000 --- a/sqlx-core/src/mysql/protocol/com_ping.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::io::BufMut; -use crate::mysql::protocol::{Capabilities, Encode}; - -// https://dev.mysql.com/doc/internals/en/com-ping.html -#[derive(Debug)] -pub struct ComPing; - -impl Encode for ComPing { - fn encode(&self, buf: &mut Vec, _: Capabilities) { - // COM_PING : int<1> - buf.put_u8(0x0e); - } -} diff --git a/sqlx-core/src/mysql/protocol/com_query.rs b/sqlx-core/src/mysql/protocol/com_query.rs deleted file mode 100644 index 0a8a20557..000000000 --- a/sqlx-core/src/mysql/protocol/com_query.rs +++ /dev/null @@ -1,18 +0,0 @@ -use crate::io::BufMut; -use crate::mysql::protocol::{Capabilities, Encode}; - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query.html -#[derive(Debug)] -pub struct ComQuery<'a> { - pub query: &'a str, -} - -impl Encode for ComQuery<'_> { - fn encode(&self, buf: &mut Vec, _: Capabilities) { - // COM_QUERY : int<1> - buf.put_u8(0x03); - - // query : string - buf.put_str(self.query); - } -} diff --git a/sqlx-core/src/mysql/protocol/com_stmt_execute.rs b/sqlx-core/src/mysql/protocol/com_stmt_execute.rs deleted file mode 100644 index 5510ab17a..000000000 --- a/sqlx-core/src/mysql/protocol/com_stmt_execute.rs +++ /dev/null @@ -1,61 +0,0 @@ -use byteorder::LittleEndian; - -use crate::io::BufMut; -use crate::mysql::protocol::{Capabilities, Encode}; -use crate::mysql::type_info::MySqlTypeInfo; - -bitflags::bitflags! { - // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/mysql__com_8h.html#a3e5e9e744ff6f7b989a604fd669977da - // https://mariadb.com/kb/en/library/com_stmt_execute/#flag - pub struct Cursor: u8 { - const NO_CURSOR = 0; - const READ_ONLY = 1; - const FOR_UPDATE = 2; - const SCROLLABLE = 4; - } -} - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_execute.html -#[derive(Debug)] -pub struct ComStmtExecute<'a> { - pub statement_id: u32, - pub cursor: Cursor, - pub params: &'a [u8], - pub null_bitmap: &'a [u8], - pub param_types: &'a [MySqlTypeInfo], -} - -impl Encode for ComStmtExecute<'_> { - fn encode(&self, buf: &mut Vec, _: Capabilities) { - // COM_STMT_EXECUTE : int<1> - buf.put_u8(0x17); - - // statement_id : int<4> - buf.put_u32::(self.statement_id); - - // cursor : int<1> - buf.put_u8(self.cursor.bits()); - - // iterations (always 1) : int<4> - buf.put_u32::(1); - - if !self.param_types.is_empty() { - // null bitmap : byte<(param_count + 7)/8> - buf.put_bytes(self.null_bitmap); - - // send type to server (0 / 1) : byte<1> - buf.put_u8(1); - - for ty in self.param_types { - // field type : byte<1> - buf.put_u8(ty.id.0); - - // parameter flag : byte<1> - buf.put_u8(if ty.is_unsigned { 0x80 } else { 0 }); - } - - // byte binary parameter value - buf.put_bytes(self.params); - } - } -} diff --git a/sqlx-core/src/mysql/protocol/com_stmt_prepare.rs b/sqlx-core/src/mysql/protocol/com_stmt_prepare.rs deleted file mode 100644 index 743713370..000000000 --- a/sqlx-core/src/mysql/protocol/com_stmt_prepare.rs +++ /dev/null @@ -1,18 +0,0 @@ -use crate::io::BufMut; -use crate::mysql::protocol::{Capabilities, Encode}; - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html -#[derive(Debug)] -pub struct ComStmtPrepare<'a> { - pub query: &'a str, -} - -impl Encode for ComStmtPrepare<'_> { - fn encode(&self, buf: &mut Vec, _: Capabilities) { - // COM_STMT_PREPARE : int<1> - buf.put_u8(0x16); - - // query : string - buf.put_str(self.query); - } -} diff --git a/sqlx-core/src/mysql/protocol/com_stmt_prepare_ok.rs b/sqlx-core/src/mysql/protocol/com_stmt_prepare_ok.rs deleted file mode 100644 index ae34d2b6f..000000000 --- a/sqlx-core/src/mysql/protocol/com_stmt_prepare_ok.rs +++ /dev/null @@ -1,48 +0,0 @@ -use byteorder::LittleEndian; - -use crate::io::Buf; - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response_ok -#[derive(Debug)] -pub(crate) struct ComStmtPrepareOk { - pub(crate) statement_id: u32, - - /// Number of columns in the returned result set (or 0 if statement - /// does not return result set). - pub(crate) columns: u16, - - /// Number of prepared statement parameters ('?' placeholders). - pub(crate) params: u16, - - /// Number of warnings. - pub(crate) warnings: u16, -} - -impl ComStmtPrepareOk { - pub(crate) fn read(mut buf: &[u8]) -> crate::Result { - let header = buf.get_u8()?; - - if header != 0x00 { - return Err(protocol_err!( - "expected COM_STMT_PREPARE_OK (0x00); received 0x{:X}", - header - ))?; - } - - let statement_id = buf.get_u32::()?; - let columns = buf.get_u16::()?; - let params = buf.get_u16::()?; - - // -not used- : string<1> - buf.advance(1); - - let warnings = buf.get_u16::()?; - - Ok(Self { - statement_id, - columns, - params, - warnings, - }) - } -} diff --git a/sqlx-core/src/mysql/protocol/connect/auth_switch.rs b/sqlx-core/src/mysql/protocol/connect/auth_switch.rs new file mode 100644 index 000000000..da0cc5506 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/connect/auth_switch.rs @@ -0,0 +1,41 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::Encode; +use crate::io::{BufExt, Decode}; +use crate::mysql::protocol::auth::AuthPlugin; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html + +#[derive(Debug)] +pub struct AuthSwitchRequest { + pub plugin: AuthPlugin, + pub data: Bytes, +} + +impl Decode<'_> for AuthSwitchRequest { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let header = buf.get_u8(); + if header != 0xfe { + return Err(err_protocol!( + "expected 0xfe (AUTH_SWITCH) but found 0x{:x}", + header + )); + } + + let plugin = buf.get_str_nul()?.parse()?; + let data = buf.get_bytes(buf.len()); + + Ok(Self { plugin, data }) + } +} + +#[derive(Debug)] +pub struct AuthSwitchResponse(pub Vec); + +impl Encode<'_, Capabilities> for AuthSwitchResponse { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.extend_from_slice(&self.0); + } +} diff --git a/sqlx-core/src/mysql/protocol/connect/handshake.rs b/sqlx-core/src/mysql/protocol/connect/handshake.rs new file mode 100644 index 000000000..d4f6f48c7 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/connect/handshake.rs @@ -0,0 +1,194 @@ +use bytes::buf::ext::Chain; +use bytes::buf::BufExt as _; +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::{BufExt, Decode}; +use crate::mysql::protocol::auth::AuthPlugin; +use crate::mysql::protocol::response::Status; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake +// https://mariadb.com/kb/en/connection/#initial-handshake-packet + +#[derive(Debug)] +pub(crate) struct Handshake { + pub(crate) protocol_version: u8, + pub(crate) server_version: String, + pub(crate) connection_id: u32, + pub(crate) server_capabilities: Capabilities, + pub(crate) server_default_collation: u8, + pub(crate) status: Status, + pub(crate) auth_plugin: Option, + pub(crate) auth_plugin_data: Chain, +} + +impl Decode<'_> for Handshake { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let protocol_version = buf.get_u8(); // int<1> + let server_version = buf.get_str_nul()?; // string + let connection_id = buf.get_u32_le(); // int<4> + let auth_plugin_data_1 = buf.get_bytes(8); // string<8> + + buf.advance(1); // reserved: string<1> + + let capabilities_1 = buf.get_u16_le(); // int<2> + let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into()); + + let collation = buf.get_u8(); // int<1> + let status = Status::from_bits_truncate(buf.get_u16_le()); + + let capabilities_2 = buf.get_u16_le(); // int<2> + capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into()); + + let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) { + buf.get_u8() + } else { + buf.advance(1); // int<1> + 0 + }; + + buf.advance(6); // reserved: string<6> + + if capabilities.contains(Capabilities::MYSQL) { + buf.advance(4); // reserved: string<4> + } else { + let capabilities_3 = buf.get_u32_le(); // int<4> + capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32); + } + + let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) { + let len = ((auth_plugin_data_len as isize) - 9).max(12) as usize; + let v = buf.get_bytes(len); + buf.advance(1); // NUL-terminator + + v + } else { + Bytes::new() + }; + + let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) { + Some(buf.get_str_nul()?.parse()?) + } else { + None + }; + + Ok(Self { + protocol_version, + server_version, + connection_id, + server_default_collation: collation, + status, + server_capabilities: capabilities, + auth_plugin, + auth_plugin_data: auth_plugin_data_1.chain(auth_plugin_data_2), + }) + } +} + +#[test] +fn test_decode_handshake_mysql_8_0_18() { + const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00"; + + let mut p = Handshake::decode(HANDSHAKE_MYSQL_8_0_18.into()).unwrap(); + + assert_eq!(p.protocol_version, 10); + + p.server_capabilities.toggle( + Capabilities::MYSQL + | Capabilities::FOUND_ROWS + | Capabilities::LONG_FLAG + | Capabilities::CONNECT_WITH_DB + | Capabilities::NO_SCHEMA + | Capabilities::COMPRESS + | Capabilities::ODBC + | Capabilities::LOCAL_FILES + | Capabilities::IGNORE_SPACE + | Capabilities::PROTOCOL_41 + | Capabilities::INTERACTIVE + | Capabilities::SSL + | Capabilities::TRANSACTIONS + | Capabilities::SECURE_CONNECTION + | Capabilities::MULTI_STATEMENTS + | Capabilities::MULTI_RESULTS + | Capabilities::PS_MULTI_RESULTS + | Capabilities::PLUGIN_AUTH + | Capabilities::CONNECT_ATTRS + | Capabilities::PLUGIN_AUTH_LENENC_DATA + | Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS + | Capabilities::SESSION_TRACK + | Capabilities::DEPRECATE_EOF + | Capabilities::ZSTD_COMPRESSION_ALGORITHM + | Capabilities::SSL_VERIFY_SERVER_CERT + | Capabilities::OPTIONAL_RESULTSET_METADATA + | Capabilities::REMEMBER_OPTIONS, + ); + + assert!(p.server_capabilities.is_empty()); + + assert_eq!(p.server_default_collation, 255); + assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT)); + + assert!(matches!( + p.auth_plugin, + Some(AuthPlugin::CachingSha2Password) + )); + + assert_eq!( + &*p.auth_plugin_data.to_bytes(), + &[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,] + ); +} + +#[test] +fn test_decode_handshake_mariadb_10_4_7() { + const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\" { + pub database: Option<&'a str>, + + /// Max size of a command packet that the client wants to send to the server + pub max_packet_size: u32, + + /// Default character set for the connection + pub char_set: u8, + + /// Name of the SQL account which client wants to log in + pub username: &'a str, + + /// Authentication method used by the client + pub auth_plugin: Option, + + /// Opaque authentication response + pub auth_response: Option<&'a [u8]>, +} + +impl Encode<'_, Capabilities> for HandshakeResponse<'_> { + fn encode_with(&self, buf: &mut Vec, mut capabilities: Capabilities) { + if self.auth_plugin.is_none() { + // ensure PLUGIN_AUTH is set *only* if we have a defined plugin + capabilities.remove(Capabilities::PLUGIN_AUTH); + } + + // NOTE: Half of this packet is identical to the SSL Request packet + SslRequest { + max_packet_size: self.max_packet_size, + char_set: self.char_set, + } + .encode_with(buf, capabilities); + + buf.put_str_nul(self.username); + + if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { + buf.put_bytes_lenenc(self.auth_response.unwrap_or_default()); + } else if capabilities.contains(Capabilities::SECURE_CONNECTION) { + let response = self.auth_response.unwrap_or_default(); + + buf.push(response.len() as u8); + buf.extend(response); + } else { + buf.push(0); + } + + if capabilities.contains(Capabilities::CONNECT_WITH_DB) { + if let Some(database) = &self.database { + buf.put_str_nul(database); + } else { + buf.push(0); + } + } + + if capabilities.contains(Capabilities::PLUGIN_AUTH) { + if let Some(plugin) = &self.auth_plugin { + buf.put_str_nul(plugin.name()); + } else { + buf.push(0); + } + } + } +} diff --git a/sqlx-core/src/mysql/protocol/connect/mod.rs b/sqlx-core/src/mysql/protocol/connect/mod.rs new file mode 100644 index 000000000..0222ee89a --- /dev/null +++ b/sqlx-core/src/mysql/protocol/connect/mod.rs @@ -0,0 +1,13 @@ +//! Connection Phase +//! +//! + +mod auth_switch; +mod handshake; +mod handshake_response; +mod ssl_request; + +pub(crate) use auth_switch::{AuthSwitchRequest, AuthSwitchResponse}; +pub(crate) use handshake::Handshake; +pub(crate) use handshake_response::HandshakeResponse; +pub(crate) use ssl_request::SslRequest; diff --git a/sqlx-core/src/mysql/protocol/connect/ssl_request.rs b/sqlx-core/src/mysql/protocol/connect/ssl_request.rs new file mode 100644 index 000000000..8b927be5d --- /dev/null +++ b/sqlx-core/src/mysql/protocol/connect/ssl_request.rs @@ -0,0 +1,30 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html +// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + +#[derive(Debug)] +pub struct SslRequest { + pub max_packet_size: u32, + pub char_set: u8, +} + +impl Encode<'_, Capabilities> for SslRequest { + fn encode_with(&self, buf: &mut Vec, capabilities: Capabilities) { + buf.extend(&(capabilities.bits() as u32).to_le_bytes()); + buf.extend(&self.max_packet_size.to_le_bytes()); + buf.push(self.char_set); + + // reserved: string<19> + buf.extend(&[0_u8; 19]); + + if capabilities.contains(Capabilities::MYSQL) { + // reserved: string<4> + buf.extend(&[0_u8; 4]); + } else { + // extended client capabilities (MariaDB-specified): int<4> + buf.extend(&((capabilities.bits() >> 32) as u32).to_le_bytes()); + } + } +} diff --git a/sqlx-core/src/mysql/protocol/eof.rs b/sqlx-core/src/mysql/protocol/eof.rs deleted file mode 100644 index f8e80031a..000000000 --- a/sqlx-core/src/mysql/protocol/eof.rs +++ /dev/null @@ -1,35 +0,0 @@ -use byteorder::LittleEndian; - -use crate::io::Buf; -use crate::mysql::protocol::Status; - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_eof_packet.html -// https://mariadb.com/kb/en/eof_packet/ -#[derive(Debug)] -pub struct EofPacket { - pub warnings: u16, - pub status: Status, -} - -impl EofPacket { - pub(crate) fn read(mut buf: &[u8]) -> crate::Result - where - Self: Sized, - { - let header = buf.get_u8()?; - if header != 0xFE { - return Err(protocol_err!( - "expected EOF (0xFE); received 0x{:X}", - header - ))?; - } - - let warnings = buf.get_u16::()?; - let status = buf.get_u16::()?; - - Ok(Self { - warnings, - status: Status::from_bits_truncate(status), - }) - } -} diff --git a/sqlx-core/src/mysql/protocol/err.rs b/sqlx-core/src/mysql/protocol/err.rs deleted file mode 100644 index 595560e74..000000000 --- a/sqlx-core/src/mysql/protocol/err.rs +++ /dev/null @@ -1,75 +0,0 @@ -use byteorder::LittleEndian; - -use crate::io::Buf; -use crate::mysql::protocol::Capabilities; - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html -// https://mariadb.com/kb/en/err_packet/ -#[derive(Debug)] -pub struct ErrPacket { - pub error_code: u16, - pub sql_state: Option>, - pub error_message: Box, -} - -impl ErrPacket { - pub(crate) fn read(mut buf: &[u8], capabilities: Capabilities) -> crate::Result - where - Self: Sized, - { - let header = buf.get_u8()?; - if header != 0xFF { - return Err(protocol_err!( - "expected 0xFF for ERR_PACKET; received 0x{:X}", - header - ))?; - } - - let error_code = buf.get_u16::()?; - - let mut sql_state = None; - - if capabilities.contains(Capabilities::PROTOCOL_41) { - // If the next byte is '#' then we have a SQL STATE - if buf.get(0) == Some(&0x23) { - buf.advance(1); - sql_state = Some(buf.get_str(5)?.into()) - } - } - - let error_message = buf.get_str(buf.len())?.into(); - - Ok(Self { - error_code, - sql_state, - error_message, - }) - } -} - -#[cfg(test)] -mod tests { - use super::{Capabilities, ErrPacket}; - - const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order"; - - const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'"; - - #[test] - fn it_decodes_packets_out_of_order() { - let p = ErrPacket::read(ERR_PACKETS_OUT_OF_ORDER, Capabilities::PROTOCOL_41).unwrap(); - - assert_eq!(&*p.error_message, "Got packets out of order"); - assert_eq!(p.error_code, 1156); - assert_eq!(p.sql_state, None); - } - - #[test] - fn it_decodes_ok_handshake() { - let p = ErrPacket::read(ERR_HANDSHAKE_UNKNOWN_DB, Capabilities::PROTOCOL_41).unwrap(); - - assert_eq!(p.error_code, 1049); - assert_eq!(p.sql_state.as_deref(), Some("42000")); - assert_eq!(&*p.error_message, "Unknown database \'unknown\'"); - } -} diff --git a/sqlx-core/src/mysql/protocol/field.rs b/sqlx-core/src/mysql/protocol/field.rs deleted file mode 100644 index 1b3cc98de..000000000 --- a/sqlx-core/src/mysql/protocol/field.rs +++ /dev/null @@ -1,50 +0,0 @@ -// https://mariadb.com/kb/en/library/resultset/#field-detail-flag -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html -bitflags::bitflags! { - pub struct FieldFlags: u16 { - /// Field cannot be NULL - const NOT_NULL = 1; - - /// Field is **part of** a primary key - const PRIMARY_KEY = 2; - - /// Field is **part of** a unique key/constraint - const UNIQUE_KEY = 4; - - /// Field is **part of** a unique or primary key - const MULTIPLE_KEY = 8; - - /// Field is a blob. - const BLOB = 16; - - /// Field is unsigned - const UNSIGNED = 32; - - /// Field is zero filled. - const ZEROFILL = 64; - - /// Field is binary (set for strings) - const BINARY = 128; - - /// Field is an enumeration - const ENUM = 256; - - /// Field is an auto-increment field - const AUTO_INCREMENT = 512; - - /// Field is a timestamp - const TIMESTAMP = 1024; - - /// Field is a set - const SET = 2048; - - /// Field does not have a default value - const NO_DEFAULT_VALUE = 4096; - - /// Field is set to NOW on UPDATE - const ON_UPDATE_NOW = 8192; - - /// Field is a number - const NUM = 32768; - } -} diff --git a/sqlx-core/src/mysql/protocol/handshake.rs b/sqlx-core/src/mysql/protocol/handshake.rs index f060dd329..e69de29bb 100644 --- a/sqlx-core/src/mysql/protocol/handshake.rs +++ b/sqlx-core/src/mysql/protocol/handshake.rs @@ -1,208 +0,0 @@ -use byteorder::LittleEndian; - -use crate::io::Buf; -use crate::mysql::protocol::{AuthPlugin, Capabilities, Status}; - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_v10.html -// https://mariadb.com/kb/en/connection/#initial-handshake-packet -#[derive(Debug)] -pub(crate) struct Handshake { - pub(crate) protocol_version: u8, - pub(crate) server_version: Box, - pub(crate) connection_id: u32, - pub(crate) server_capabilities: Capabilities, - pub(crate) server_default_collation: u8, - pub(crate) status: Status, - pub(crate) auth_plugin: AuthPlugin, - pub(crate) auth_plugin_data: Box<[u8]>, -} - -impl Handshake { - pub(crate) fn read(mut buf: &[u8]) -> crate::Result - where - Self: Sized, - { - let protocol_version = buf.get_u8()?; - let server_version = buf.get_str_nul()?.into(); - let connection_id = buf.get_u32::()?; - - let mut scramble = Vec::with_capacity(8); - - // scramble first part : string<8> - scramble.extend_from_slice(&buf[..8]); - buf.advance(8); - - // reserved : string<1> - buf.advance(1); - - // capability_flags_1 : int<2> - let capabilities_1 = buf.get_u16::()?; - let mut capabilities = Capabilities::from_bits_truncate(capabilities_1.into()); - - // character_set : int<1> - let char_set = buf.get_u8()?; - - // status_flags : int<2> - let status = buf.get_u16::()?; - let status = Status::from_bits_truncate(status); - - // capability_flags_2 : int<2> - let capabilities_2 = buf.get_u16::()?; - capabilities |= Capabilities::from_bits_truncate(((capabilities_2 as u32) << 16).into()); - - let auth_plugin_data_len = if capabilities.contains(Capabilities::PLUGIN_AUTH) { - // plugin data length : int<1> - buf.get_u8()? - } else { - // 0x00 : int<1> - buf.advance(1); - 0 - }; - - // reserved: string<6> - buf.advance(6); - - if capabilities.contains(Capabilities::MYSQL) { - // reserved: string<4> - buf.advance(4); - } else { - // capability_flags_3 : int<4> - let capabilities_3 = buf.get_u32::()?; - capabilities |= Capabilities::from_bits_truncate((capabilities_3 as u64) << 32); - } - - if capabilities.contains(Capabilities::SECURE_CONNECTION) { - // scramble 2nd part : string ( Length = max(12, plugin data length - 9) ) - let len = ((auth_plugin_data_len as isize) - 9).max(12) as usize; - scramble.extend_from_slice(&buf[..len]); - buf.advance(len); - - // reserved : string<1> - buf.advance(1); - } - - let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) { - AuthPlugin::from_opt_str(Some(buf.get_str_nul()?))? - } else { - AuthPlugin::from_opt_str(None)? - }; - - Ok(Self { - protocol_version, - server_capabilities: capabilities, - server_version, - server_default_collation: char_set, - connection_id, - auth_plugin_data: scramble.into_boxed_slice(), - auth_plugin, - status, - }) - } -} - -#[cfg(test)] -mod tests { - use super::{AuthPlugin, Capabilities, Handshake, Status}; - - const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\" { - pub max_packet_size: u32, - pub client_collation: u8, - pub username: &'a str, - pub database: Option<&'a str>, - pub auth_plugin: &'a AuthPlugin, - pub auth_response: &'a [u8], -} - -impl Encode for HandshakeResponse<'_> { - fn encode(&self, buf: &mut Vec, capabilities: Capabilities) { - // client capabilities : int<4> - buf.put_u32::(capabilities.bits() as u32); - - // max packet size : int<4> - buf.put_u32::(self.max_packet_size); - - // client character collation : int<1> - buf.put_u8(self.client_collation); - - // reserved : string<19> - buf.advance(19); - - if capabilities.contains(Capabilities::MYSQL) { - // reserved : string<4> - buf.advance(4); - } else { - // extended client capabilities : int<4> - buf.put_u32::((capabilities.bits() >> 32) as u32); - } - - // username : string - buf.put_str_nul(self.username); - - if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) { - // auth_response : string - buf.put_bytes_lenenc::(self.auth_response); - } else if capabilities.contains(Capabilities::SECURE_CONNECTION) { - let auth_response = self.auth_response; - - // auth_response_length : int<1> - buf.put_u8(auth_response.len() as u8); - - // auth_response : string<{auth_response_length}> - buf.put_bytes(auth_response); - } else { - // no auth : int<1> - buf.put_u8(0); - } - - if capabilities.contains(Capabilities::CONNECT_WITH_DB) { - if let Some(database) = self.database { - // database : string - buf.put_str_nul(database); - } - } - - if capabilities.contains(Capabilities::PLUGIN_AUTH) { - // client_plugin_name : string - buf.put_str_nul(self.auth_plugin.as_str()); - } - } -} diff --git a/sqlx-core/src/mysql/protocol/mod.rs b/sqlx-core/src/mysql/protocol/mod.rs index f890c1e5c..9b19f3089 100644 --- a/sqlx-core/src/mysql/protocol/mod.rs +++ b/sqlx-core/src/mysql/protocol/mod.rs @@ -1,59 +1,13 @@ -mod auth_plugin; +pub(crate) mod auth; mod capabilities; -mod field; -mod status; -mod r#type; - -pub(crate) use auth_plugin::AuthPlugin; -pub(crate) use capabilities::Capabilities; -pub(crate) use field::FieldFlags; -pub(crate) use r#type::TypeId; -pub(crate) use status::Status; - -mod com_ping; -mod com_query; -mod com_stmt_execute; -mod com_stmt_prepare; -mod handshake; - -pub(crate) use com_ping::ComPing; -pub(crate) use com_query::ComQuery; -pub(crate) use com_stmt_execute::{ComStmtExecute, Cursor}; -pub(crate) use com_stmt_prepare::ComStmtPrepare; -pub(crate) use handshake::Handshake; - -mod auth_switch; -mod column_count; -mod column_def; -mod com_stmt_prepare_ok; -mod eof; -mod err; -mod handshake_response; -mod ok; +pub(crate) mod connect; +mod packet; +pub(crate) mod response; mod row; -#[cfg_attr(not(feature = "tls"), allow(unused_imports, dead_code))] -mod ssl_request; +pub(crate) mod rsa; +pub(crate) mod statement; +pub(crate) mod text; -pub(crate) use auth_switch::AuthSwitch; -pub(crate) use column_count::ColumnCount; -pub(crate) use column_def::ColumnDefinition; -pub(crate) use com_stmt_prepare_ok::ComStmtPrepareOk; -pub(crate) use eof::EofPacket; -pub(crate) use err::ErrPacket; -pub(crate) use handshake_response::HandshakeResponse; -pub(crate) use ok::OkPacket; +pub(crate) use capabilities::Capabilities; +pub(crate) use packet::Packet; pub(crate) use row::Row; -#[cfg_attr(not(feature = "tls"), allow(unused_imports, dead_code))] -pub(crate) use ssl_request::SslRequest; - -pub(crate) trait Encode { - fn encode(&self, buf: &mut Vec, capabilities: Capabilities); -} - -impl Encode for &'_ [u8] { - fn encode(&self, buf: &mut Vec, _: Capabilities) { - use crate::io::BufMut; - - buf.put_bytes(self); - } -} diff --git a/sqlx-core/src/mysql/protocol/ok.rs b/sqlx-core/src/mysql/protocol/ok.rs deleted file mode 100644 index 0639d09dd..000000000 --- a/sqlx-core/src/mysql/protocol/ok.rs +++ /dev/null @@ -1,64 +0,0 @@ -use byteorder::LittleEndian; - -use crate::io::Buf; -use crate::mysql::io::BufExt; -use crate::mysql::protocol::Status; - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_ok_packet.html -// https://mariadb.com/kb/en/ok_packet/ -#[derive(Debug)] -pub(crate) struct OkPacket { - pub(crate) affected_rows: u64, - pub(crate) last_insert_id: u64, - pub(crate) status: Status, - pub(crate) warnings: u16, - pub(crate) info: Box, -} - -impl OkPacket { - pub(crate) fn read(mut buf: &[u8]) -> crate::Result - where - Self: Sized, - { - let header = buf.get_u8()?; - if header != 0 && header != 0xFE { - return Err(protocol_err!( - "expected 0x00 or 0xFE; received 0x{:X}", - header - ))?; - } - - let affected_rows = buf.get_uint_lenenc::()?.unwrap_or(0); // 0 - let last_insert_id = buf.get_uint_lenenc::()?.unwrap_or(0); // 2 - let status = Status::from_bits_truncate(buf.get_u16::()?); // - let warnings = buf.get_u16::()?; - let info = buf.get_str(buf.len())?.into(); - - Ok(Self { - affected_rows, - last_insert_id, - status, - warnings, - info, - }) - } -} - -#[cfg(test)] -mod tests { - use super::{OkPacket, Status}; - - const OK_HANDSHAKE: &[u8] = b"\x00\x00\x00\x02@\x00\x00"; - - #[test] - fn it_decodes_ok_handshake() { - let p = OkPacket::read(OK_HANDSHAKE).unwrap(); - - assert_eq!(p.affected_rows, 0); - assert_eq!(p.last_insert_id, 0); - assert_eq!(p.warnings, 0); - assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT)); - assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED)); - assert!(p.info.is_empty()); - } -} diff --git a/sqlx-core/src/mysql/protocol/packet.rs b/sqlx-core/src/mysql/protocol/packet.rs new file mode 100644 index 000000000..8c49fcc33 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/packet.rs @@ -0,0 +1,89 @@ +use std::ops::{Deref, DerefMut}; + +use bytes::Bytes; + +use crate::error::Error; +use crate::io::{Decode, Encode}; +use crate::mysql::protocol::response::{EofPacket, OkPacket}; +use crate::mysql::protocol::Capabilities; + +#[derive(Debug)] +pub struct Packet(pub(crate) T); + +impl<'en, 'stream, T> Encode<'stream, (Capabilities, &'stream mut u8)> for Packet +where + T: Encode<'en, Capabilities>, +{ + fn encode_with( + &self, + buf: &mut Vec, + (capabilities, sequence_id): (Capabilities, &'stream mut u8), + ) { + // reserve space to write the prefixed length + let offset = buf.len(); + buf.extend(&[0_u8; 4]); + + // encode the payload + self.0.encode_with(buf, capabilities); + + // determine the length of the encoded payload + // and write to our reserved space + let len = buf.len() - offset - 4; + let header = &mut buf[offset..]; + + // FIXME: Support larger packets + assert!(len < 0xFF_FF_FF); + + header[..4].copy_from_slice(&(len as u32).to_le_bytes()); + header[3] = *sequence_id; + + *sequence_id = sequence_id.wrapping_add(1); + } +} + +impl Packet { + pub(crate) fn decode<'de, T>(self) -> Result + where + T: Decode<'de, ()>, + { + self.decode_with(()) + } + + pub(crate) fn decode_with<'de, T, C>(self, context: C) -> Result + where + T: Decode<'de, C>, + { + T::decode_with(self.0, context) + } + + pub(crate) fn ok(self) -> Result { + self.decode() + } + + pub(crate) fn eof(self, capabilities: Capabilities) -> Result { + if capabilities.contains(Capabilities::DEPRECATE_EOF) { + let ok = self.ok()?; + + Ok(EofPacket { + warnings: ok.warnings, + status: ok.status, + }) + } else { + self.decode_with(capabilities) + } + } +} + +impl Deref for Packet { + type Target = Bytes; + + fn deref(&self) -> &Bytes { + &self.0 + } +} + +impl DerefMut for Packet { + fn deref_mut(&mut self) -> &mut Bytes { + &mut self.0 + } +} diff --git a/sqlx-core/src/mysql/protocol/response/eof.rs b/sqlx-core/src/mysql/protocol/response/eof.rs new file mode 100644 index 000000000..756c370a6 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/response/eof.rs @@ -0,0 +1,35 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::Decode; +use crate::mysql::protocol::response::Status; +use crate::mysql::protocol::Capabilities; + +/// Marks the end of a result set, returning status and warnings. +/// +/// # Note +/// +/// The EOF packet is deprecated as of MySQL 5.7.5. SQLx only uses this packet for MySQL +/// prior MySQL versions. +#[derive(Debug)] +pub struct EofPacket { + pub warnings: u16, + pub status: Status, +} + +impl Decode<'_, Capabilities> for EofPacket { + fn decode_with(mut buf: Bytes, _: Capabilities) -> Result { + let header = buf.get_u8(); + if header != 0xfe { + return Err(err_protocol!( + "expected 0xfe (EOF_Packet) but found 0x{:x}", + header + )); + } + + let warnings = buf.get_u16_le(); + let status = Status::from_bits_truncate(buf.get_u16_le()); + + Ok(Self { status, warnings }) + } +} diff --git a/sqlx-core/src/mysql/protocol/response/err.rs b/sqlx-core/src/mysql/protocol/response/err.rs new file mode 100644 index 000000000..7cc2d8d0f --- /dev/null +++ b/sqlx-core/src/mysql/protocol/response/err.rs @@ -0,0 +1,71 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::{BufExt, Decode}; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_err_packet.html +// https://mariadb.com/kb/en/err_packet/ + +/// Indicates that an error occurred. +#[derive(Debug)] +pub struct ErrPacket { + pub error_code: u16, + pub sql_state: Option, + pub error_message: String, +} + +impl Decode<'_, Capabilities> for ErrPacket { + fn decode_with(mut buf: Bytes, capabilities: Capabilities) -> Result { + let header = buf.get_u8(); + if header != 0xff { + return Err(err_protocol!( + "expected 0xff (ERR_Packet) but found 0x{:x}", + header + )); + } + + let error_code = buf.get_u16_le(); + let mut sql_state = None; + + if capabilities.contains(Capabilities::PROTOCOL_41) { + // If the next byte is '#' then we have a SQL STATE + if buf.get(0) == Some(&0x23) { + buf.advance(1); + sql_state = Some(buf.get_str(5)?.to_owned()); + } + } + + let error_message = buf.get_str(buf.len())?.to_owned(); + + Ok(Self { + error_code, + sql_state, + error_message, + }) + } +} + +#[test] +fn test_decode_err_packet_out_of_order() { + const ERR_PACKETS_OUT_OF_ORDER: &[u8] = b"\xff\x84\x04Got packets out of order"; + + let p = + ErrPacket::decode_with(ERR_PACKETS_OUT_OF_ORDER.into(), Capabilities::PROTOCOL_41).unwrap(); + + assert_eq!(&p.error_message, "Got packets out of order"); + assert_eq!(p.error_code, 1156); + assert_eq!(p.sql_state, None); +} + +#[test] +fn test_decode_err_packet_unknown_database() { + const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'"; + + let p = + ErrPacket::decode_with(ERR_HANDSHAKE_UNKNOWN_DB.into(), Capabilities::PROTOCOL_41).unwrap(); + + assert_eq!(p.error_code, 1049); + assert_eq!(p.sql_state.as_deref(), Some("42000")); + assert_eq!(&p.error_message, "Unknown database \'unknown\'"); +} diff --git a/sqlx-core/src/mysql/protocol/response/mod.rs b/sqlx-core/src/mysql/protocol/response/mod.rs new file mode 100644 index 000000000..79767dc60 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/response/mod.rs @@ -0,0 +1,14 @@ +//! Generic Response Packets +//! +//! +//! + +mod eof; +mod err; +mod ok; +mod status; + +pub use eof::EofPacket; +pub use err::ErrPacket; +pub use ok::OkPacket; +pub use status::Status; diff --git a/sqlx-core/src/mysql/protocol/response/ok.rs b/sqlx-core/src/mysql/protocol/response/ok.rs new file mode 100644 index 000000000..8ad607b79 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/response/ok.rs @@ -0,0 +1,52 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::Decode; +use crate::mysql::io::MySqlBufExt; +use crate::mysql::protocol::response::Status; + +/// Indicates successful completion of a previous command sent by the client. +#[derive(Debug)] +pub struct OkPacket { + pub affected_rows: u64, + pub last_insert_id: u64, + pub status: Status, + pub warnings: u16, +} + +impl Decode<'_> for OkPacket { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let header = buf.get_u8(); + if header != 0 && header != 0xfe { + return Err(err_protocol!( + "expected 0x00 or 0xfe (OK_Packet) but found 0x{:02x}", + header + )); + } + + let affected_rows = buf.get_uint_lenenc(); + let last_insert_id = buf.get_uint_lenenc(); + let status = Status::from_bits_truncate(buf.get_u16_le()); + let warnings = buf.get_u16_le(); + + Ok(Self { + affected_rows, + last_insert_id, + status, + warnings, + }) + } +} + +#[test] +fn test_decode_ok_packet() { + const DATA: &[u8] = b"\x00\x00\x00\x02@\x00\x00"; + + let p = OkPacket::decode(DATA.into()).unwrap(); + + assert_eq!(p.affected_rows, 0); + assert_eq!(p.last_insert_id, 0); + assert_eq!(p.warnings, 0); + assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT)); + assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED)); +} diff --git a/sqlx-core/src/mysql/protocol/status.rs b/sqlx-core/src/mysql/protocol/response/status.rs similarity index 100% rename from sqlx-core/src/mysql/protocol/status.rs rename to sqlx-core/src/mysql/protocol/response/status.rs diff --git a/sqlx-core/src/mysql/protocol/row.rs b/sqlx-core/src/mysql/protocol/row.rs index 2fdd0fa88..60e79c30b 100644 --- a/sqlx-core/src/mysql/protocol/row.rs +++ b/sqlx-core/src/mysql/protocol/row.rs @@ -1,323 +1,21 @@ use std::ops::Range; -use byteorder::{ByteOrder, LittleEndian}; +use bytes::Bytes; -use crate::io::Buf; -use crate::mysql::protocol::TypeId; -use crate::mysql::MySqlTypeInfo; - -pub(crate) struct Row<'c> { - buffer: &'c [u8], - values: &'c [Option>], - pub(crate) columns: &'c [MySqlTypeInfo], - pub(crate) binary: bool, +#[derive(Debug)] +pub(crate) struct Row { + pub(crate) storage: Bytes, + pub(crate) values: Vec>>, } -impl<'c> Row<'c> { +impl Row { pub(crate) fn len(&self) -> usize { self.values.len() } - pub(crate) fn get(&self, index: usize) -> Option<&'c [u8]> { - let range = self.values[index].as_ref()?; - - Some(&self.buffer[(range.start as usize)..(range.end as usize)]) + pub(crate) fn get(&self, index: usize) -> Option<&[u8]> { + self.values[index] + .as_ref() + .map(|col| &self.storage[(col.start as usize)..(col.end as usize)]) } } - -fn get_lenenc(buf: &[u8]) -> (usize, Option) { - match buf[0] { - 0xFB => (1, None), - - 0xFC => { - let len_size = 1 + 2; - let len = LittleEndian::read_u16(&buf[1..]); - - (len_size, Some(len as usize)) - } - - 0xFD => { - let len_size = 1 + 3; - let len = LittleEndian::read_u24(&buf[1..]); - - (len_size, Some(len as usize)) - } - - 0xFE => { - let len_size = 1 + 8; - let len = LittleEndian::read_u64(&buf[1..]); - - (len_size, Some(len as usize)) - } - - len => (1, Some(len as usize)), - } -} - -impl<'c> Row<'c> { - pub(crate) fn read( - mut buf: &'c [u8], - columns: &'c [MySqlTypeInfo], - values: &'c mut Vec>>, - binary: bool, - ) -> crate::Result { - let buffer = &*buf; - - values.clear(); - values.reserve(columns.len()); - - if !binary { - let mut index = 0; - - for _ in 0..columns.len() { - let (len_size, size) = get_lenenc(&buf[index..]); - - if let Some(size) = size { - values.push(Some((index + len_size)..(index + len_size + size))); - } else { - values.push(None); - } - - index += len_size + size.unwrap_or_default(); - } - - return Ok(Self { - buffer, - columns, - values: &*values, - binary: false, - }); - } - - // 0x00 header : byte<1> - let header = buf.get_u8()?; - if header != 0 { - return Err(protocol_err!("expected ROW (0x00), got: {:#04X}", header).into()); - } - - // NULL-Bitmap : byte<(number_of_columns + 9) / 8> - let null_len = (columns.len() + 9) / 8; - let null_bitmap = &buf[..]; - buf.advance(null_len); - - let buffer: Box<[u8]> = buf.into(); - let mut index = 0; - - for column_idx in 0..columns.len() { - // the null index for a column starts at the 3rd bit in the null bitmap - // for no reason at all besides mysql probably - let column_null_idx = column_idx + 2; - let is_null = - null_bitmap[column_null_idx / 8] & (1 << (column_null_idx % 8) as u8) != 0; - - if is_null { - values.push(None); - } else { - let (offset, size) = match columns[column_idx].id { - TypeId::TINY_INT => (0, 1), - TypeId::SMALL_INT => (0, 2), - TypeId::INT | TypeId::FLOAT => (0, 4), - TypeId::BIG_INT | TypeId::DOUBLE => (0, 8), - - TypeId::DATE => (0, 5), - TypeId::TIME => (0, 1 + buffer[index] as usize), - - TypeId::TIMESTAMP | TypeId::DATETIME => (0, 1 + buffer[index] as usize), - - TypeId::TINY_BLOB - | TypeId::MEDIUM_BLOB - | TypeId::LONG_BLOB - | TypeId::CHAR - | TypeId::TEXT - | TypeId::ENUM - | TypeId::VAR_CHAR => { - let (len_size, len) = get_lenenc(&buffer[index..]); - - (len_size, len.unwrap_or_default()) - } - - TypeId::NEWDECIMAL => (0, 1 + buffer[index] as usize), - - id => { - unimplemented!("encountered unknown field type id: {:?}", id); - } - }; - - values.push(Some((index + offset)..(index + offset + size))); - index += size + offset; - } - } - - Ok(Self { - buffer: buf, - values: &*values, - columns, - binary, - }) - } -} - -// #[cfg(test)] -// mod test { -// use super::super::column_count::ColumnCount; -// use super::super::column_def::ColumnDefinition; -// use super::super::eof::EofPacket; -// use super::*; -// -// #[test] -// fn null_bitmap_test() -> crate::Result<()> { -// let column_len = ColumnCount::decode(&[26])?; -// assert_eq!(column_len.columns, 26); -// -// let types: Vec = vec![ -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 2, 105, 100, 2, 105, 100, 12, 63, 0, 11, 0, 0, -// 0, 3, 11, 66, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 50, 6, 102, 105, -// 101, 108, 100, 50, 12, 224, 0, 120, 0, 0, 0, 253, 5, 64, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 51, 6, 102, 105, -// 101, 108, 100, 51, 12, 224, 0, 252, 3, 0, 0, 253, 1, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 52, 6, 102, 105, -// 101, 108, 100, 52, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 53, 6, 102, 105, -// 101, 108, 100, 53, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 54, 6, 102, 105, -// 101, 108, 100, 54, 12, 63, 0, 19, 0, 0, 0, 7, 128, 4, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 55, 6, 102, 105, -// 101, 108, 100, 55, 12, 63, 0, 4, 0, 0, 0, 1, 1, 64, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 56, 6, 102, 105, -// 101, 108, 100, 56, 12, 224, 0, 252, 255, 3, 0, 252, 16, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 6, 102, 105, 101, 108, 100, 57, 6, 102, 105, -// 101, 108, 100, 57, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 48, 7, 102, -// 105, 101, 108, 100, 49, 48, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 49, 7, 102, -// 105, 101, 108, 100, 49, 49, 12, 224, 0, 252, 3, 0, 0, 252, 16, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 50, 7, 102, -// 105, 101, 108, 100, 49, 50, 12, 63, 0, 19, 0, 0, 0, 7, 129, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 51, 7, 102, -// 105, 101, 108, 100, 49, 51, 12, 63, 0, 4, 0, 0, 0, 1, 0, 64, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 52, 7, 102, -// 105, 101, 108, 100, 49, 52, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 53, 7, 102, -// 105, 101, 108, 100, 49, 53, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 54, 7, 102, -// 105, 101, 108, 100, 49, 54, 12, 63, 0, 4, 0, 0, 0, 1, 1, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 55, 7, 102, -// 105, 101, 108, 100, 49, 55, 12, 224, 0, 0, 1, 0, 0, 253, 0, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 56, 7, 102, -// 105, 101, 108, 100, 49, 56, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 49, 57, 7, 102, -// 105, 101, 108, 100, 49, 57, 12, 63, 0, 11, 0, 0, 0, 3, 1, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 48, 7, 102, -// 105, 101, 108, 100, 50, 48, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 49, 7, 102, -// 105, 101, 108, 100, 50, 49, 12, 63, 0, 19, 0, 0, 0, 7, 128, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 50, 7, 102, -// 105, 101, 108, 100, 50, 50, 12, 63, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 51, 7, 102, -// 105, 101, 108, 100, 50, 51, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 52, 7, 102, -// 105, 101, 108, 100, 50, 52, 12, 63, 0, 6, 0, 0, 0, 3, 1, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 53, 7, 102, -// 105, 101, 108, 100, 50, 53, 12, 63, 0, 20, 0, 0, 0, 8, 1, 0, 0, 0, 0, -// ])?, -// ColumnDefinition::decode(&[ -// 3, 100, 101, 102, 4, 115, 113, 108, 120, 8, 97, 99, 99, 111, 117, 110, 116, 115, 8, -// 97, 99, 99, 111, 117, 110, 116, 115, 7, 102, 105, 101, 108, 100, 50, 54, 7, 102, -// 105, 101, 108, 100, 50, 54, 12, 63, 0, 11, 0, 0, 0, 3, 0, 0, 0, 0, 0, -// ])?, -// ] -// .into_iter() -// .map(|def| def.type_id) -// .collect(); -// -// EofPacket::decode(&[254, 0, 0, 34, 0])?; -// -// Row::read( -// &[ -// 0, 64, 90, 229, 0, 4, 0, 0, 0, 4, 114, 117, 115, 116, 0, 0, 7, 228, 7, 1, 16, 8, -// 10, 17, 0, 0, 4, 208, 7, 1, 1, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -// 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -// ], -// &types, -// true, -// )?; -// -// EofPacket::decode(&[254, 0, 0, 34, 0])?; -// Ok(()) -// } -// } diff --git a/sqlx-core/src/mysql/rsa.rs b/sqlx-core/src/mysql/protocol/rsa.rs similarity index 94% rename from sqlx-core/src/mysql/rsa.rs rename to sqlx-core/src/mysql/protocol/rsa.rs index a4e242a00..21c7b8839 100644 --- a/sqlx-core/src/mysql/rsa.rs +++ b/sqlx-core/src/mysql/protocol/rsa.rs @@ -2,14 +2,16 @@ use digest::Digest; use num_bigint::BigUint; use rand::{thread_rng, Rng}; +use crate::error::Error; + // This is mostly taken from https://github.com/RustCrypto/RSA/pull/18 // For the love of crypto, please delete as much of this as possible and use the RSA crate // directly when that PR is merged -pub fn encrypt(key: &[u8], message: &[u8]) -> crate::Result> { +pub fn encrypt(key: &[u8], message: &[u8]) -> Result, Error> { let key = std::str::from_utf8(key).map_err(|_err| { - // TODO(@abonander): protocol_err doesn't like referring to [err] - protocol_err!("unexpected error decoding what should be UTF-8") + // TODO(@abonander): err_protocol doesn't like referring to [err] + err_protocol!("unexpected error decoding what should be UTF-8") })?; let key = parse(key)?; @@ -96,7 +98,7 @@ fn oaep_encrypt( rng: &mut R, pub_key: &PublicKey, msg: &[u8], -) -> crate::Result> { +) -> Result, Error> { // size of [n] in bytes let k = (pub_key.n.bits() + 7) / 8; @@ -104,7 +106,7 @@ fn oaep_encrypt( let h_size = D::output_size(); if msg.len() > k - 2 * h_size - 2 { - return Err(protocol_err!("mysql: password too long").into()); + return Err(err_protocol!("mysql: password too long")); } let mut em = vec![0u8; k]; @@ -140,13 +142,13 @@ struct PublicKey { e: BigUint, } -fn parse(key: &str) -> crate::Result { +fn parse(key: &str) -> Result { // This takes advantage of the knowledge that we know // we are receiving a PKCS#8 RSA Public Key at all // times from MySQL if !key.starts_with("-----BEGIN PUBLIC KEY-----\n") { - return Err(protocol_err!( + return Err(err_protocol!( "unexpected format for RSA Public Key from MySQL (expected PKCS#8); first line: {:?}", key.splitn(1, '\n').next() ) @@ -158,8 +160,8 @@ fn parse(key: &str) -> crate::Result { let inner_key = key_with_trailer[..trailer_pos].replace('\n', ""); let inner = base64::decode(&inner_key).map_err(|_err| { - // TODO(@abonander): protocol_err doesn't like referring to [err] - protocol_err!("unexpected error decoding what should be base64-encoded data") + // TODO(@abonander): err_protocol doesn't like referring to [err] + err_protocol!("unexpected error decoding what should be base64-encoded data") })?; let len = inner.len(); diff --git a/sqlx-core/src/mysql/protocol/ssl_request.rs b/sqlx-core/src/mysql/protocol/ssl_request.rs deleted file mode 100644 index bbf781c65..000000000 --- a/sqlx-core/src/mysql/protocol/ssl_request.rs +++ /dev/null @@ -1,34 +0,0 @@ -use byteorder::LittleEndian; - -use crate::io::BufMut; -use crate::mysql::protocol::{Capabilities, Encode}; - -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html -// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest -#[derive(Debug)] -pub struct SslRequest { - pub max_packet_size: u32, - pub client_collation: u8, -} - -impl Encode for SslRequest { - fn encode(&self, buf: &mut Vec, capabilities: Capabilities) { - // SSL must be set or else it makes no sense to ask for an upgrade - assert!( - capabilities.contains(Capabilities::SSL), - "SSL bit must be set for Capabilities" - ); - - // client capabilities : int<4> - buf.put_u32::(capabilities.bits() as u32); - - // max packet size : int<4> - buf.put_u32::(self.max_packet_size); - - // client character collation : int<1> - buf.put_u8(self.client_collation); - - // reserved : string<23> - buf.advance(23); - } -} diff --git a/sqlx-core/src/mysql/protocol/statement/execute.rs b/sqlx-core/src/mysql/protocol/statement/execute.rs new file mode 100644 index 000000000..2186a927b --- /dev/null +++ b/sqlx-core/src/mysql/protocol/statement/execute.rs @@ -0,0 +1,38 @@ +use crate::io::Encode; +use crate::mysql::protocol::text::ColumnFlags; +use crate::mysql::protocol::Capabilities; +use crate::mysql::MySqlArguments; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_execute.html + +#[derive(Debug)] +pub struct Execute<'q> { + pub statement: u32, + pub arguments: &'q MySqlArguments, +} + +impl<'q> Encode<'_, Capabilities> for Execute<'q> { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x17); // COM_STMT_EXECUTE + buf.extend(&self.statement.to_le_bytes()); + buf.push(0); // NO_CURSOR + buf.extend(&0_u32.to_le_bytes()); // iterations (always 1): int<4> + + if !self.arguments.is_empty() { + buf.extend(&*self.arguments.null_bitmap); + buf.push(1); // send type to server + + for ty in &self.arguments.types { + buf.push(ty.r#type as u8); + + buf.push(if ty.flags.contains(ColumnFlags::UNSIGNED) { + 0x80 + } else { + 0 + }); + } + + buf.extend(&*self.arguments.values); + } + } +} diff --git a/sqlx-core/src/mysql/protocol/statement/mod.rs b/sqlx-core/src/mysql/protocol/statement/mod.rs new file mode 100644 index 000000000..5ad292f56 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/statement/mod.rs @@ -0,0 +1,9 @@ +mod execute; +mod prepare; +mod prepare_ok; +mod row; + +pub(crate) use execute::Execute; +pub(crate) use prepare::Prepare; +pub(crate) use prepare_ok::PrepareOk; +pub(crate) use row::BinaryRow; diff --git a/sqlx-core/src/mysql/protocol/statement/prepare.rs b/sqlx-core/src/mysql/protocol/statement/prepare.rs new file mode 100644 index 000000000..325b29faf --- /dev/null +++ b/sqlx-core/src/mysql/protocol/statement/prepare.rs @@ -0,0 +1,15 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html#packet-COM_STMT_PREPARE + +pub struct Prepare<'a> { + pub query: &'a str, +} + +impl Encode<'_, Capabilities> for Prepare<'_> { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x16); // COM_STMT_PREPARE + buf.extend(self.query.as_bytes()); + } +} diff --git a/sqlx-core/src/mysql/protocol/statement/prepare_ok.rs b/sqlx-core/src/mysql/protocol/statement/prepare_ok.rs new file mode 100644 index 000000000..cac4cbbed --- /dev/null +++ b/sqlx-core/src/mysql/protocol/statement/prepare_ok.rs @@ -0,0 +1,42 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::Decode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK + +#[derive(Debug)] +pub(crate) struct PrepareOk { + pub(crate) statement_id: u32, + pub(crate) columns: u16, + pub(crate) params: u16, + pub(crate) warnings: u16, +} + +impl Decode<'_, Capabilities> for PrepareOk { + fn decode_with(mut buf: Bytes, _: Capabilities) -> Result { + let status = buf.get_u8(); + if status != 0x00 { + return Err(err_protocol!( + "expected 0x00 (COM_STMT_PREPARE_OK) but found 0x{:02x}", + status + )); + } + + let statement_id = buf.get_u32_le(); + let columns = buf.get_u16_le(); + let params = buf.get_u16_le(); + + buf.advance(1); // reserved: string<1> + + let warnings = buf.get_u16_le(); + + Ok(Self { + statement_id, + columns, + params, + warnings, + }) + } +} diff --git a/sqlx-core/src/mysql/protocol/statement/row.rs b/sqlx-core/src/mysql/protocol/statement/row.rs new file mode 100644 index 000000000..29723daf0 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/statement/row.rs @@ -0,0 +1,92 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::{BufExt, Decode}; +use crate::mysql::io::MySqlBufExt; +use crate::mysql::protocol::text::ColumnType; +use crate::mysql::protocol::Row; +use crate::mysql::row::MySqlColumn; + +// https://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html#packet-ProtocolBinary::ResultsetRow +// https://dev.mysql.com/doc/internals/en/binary-protocol-value.html + +#[derive(Debug)] +pub(crate) struct BinaryRow(pub(crate) Row); + +impl<'de> Decode<'de, &'de [MySqlColumn]> for BinaryRow { + fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result { + let header = buf.get_u8(); + if header != 0 { + return Err(err_protocol!( + "exepcted 0x00 (ROW) but found 0x{:02x}", + header + )); + } + + let storage = buf.clone(); + let offset = buf.len(); + + let null_bitmap_len = (columns.len() + 9) / 8; + let null_bitmap = buf.get_bytes(null_bitmap_len); + + let mut values = Vec::with_capacity(columns.len()); + + for (column_idx, column) in columns.iter().enumerate() { + // NOTE: the column index starts at the 3rd bit + let column_null_idx = column_idx + 2; + let is_null = + null_bitmap[column_null_idx / 8] & (1 << (column_null_idx % 8) as u8) != 0; + + if is_null { + values.push(None); + continue; + } + + // NOTE: MySQL will never generate NULL types for non-NULL values + let type_info = column.type_info.as_ref().unwrap(); + + let size: usize = match type_info.r#type { + ColumnType::String + | ColumnType::VarChar + | ColumnType::VarString + | ColumnType::Enum + | ColumnType::Set + | ColumnType::LongBlob + | ColumnType::MediumBlob + | ColumnType::Blob + | ColumnType::TinyBlob + | ColumnType::Geometry + | ColumnType::Bit + | ColumnType::Decimal + | ColumnType::Json + | ColumnType::NewDecimal => buf.get_uint_lenenc() as usize, + + ColumnType::LongLong => 8, + ColumnType::Long | ColumnType::Int24 => 4, + ColumnType::Short | ColumnType::Year => 2, + ColumnType::Tiny => 1, + ColumnType::Float => 4, + ColumnType::Double => 8, + + ColumnType::Time + | ColumnType::Timestamp + | ColumnType::Date + | ColumnType::Datetime => { + // The size of this type is important for decoding + buf[0] as usize + 1 + } + + // NOTE: MySQL will never generate NULL types for non-NULL values + ColumnType::Null => unreachable!(), + }; + + let offset = offset - buf.len(); + + values.push(Some(offset..(offset + size))); + + buf.advance(size); + } + + Ok(BinaryRow(Row { values, storage })) + } +} diff --git a/sqlx-core/src/mysql/protocol/text/column.rs b/sqlx-core/src/mysql/protocol/text/column.rs new file mode 100644 index 000000000..d7d421879 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/text/column.rs @@ -0,0 +1,244 @@ +use std::str::from_utf8; + +use bitflags::bitflags; +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::Decode; +use crate::mysql::io::MySqlBufExt; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/group__group__cs__column__definition__flags.html + +bitflags! { + #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] + pub(crate) struct ColumnFlags: u16 { + /// Field can't be `NULL`. + const NOT_NULL = 1; + + /// Field is part of a primary key. + const PRIMARY_KEY = 2; + + /// Field is part of a unique key. + const UNIQUE_KEY = 4; + + /// Field is part of a multi-part unique or primary key. + const MULTIPLE_KEY = 8; + + /// Field is a blob. + const BLOB = 16; + + /// Field is unsigned. + const UNSIGNED = 32; + + /// Field is zero filled. + const ZEROFILL = 64; + + /// Field is binary. + const BINARY = 128; + + /// Field is an enumeration. + const ENUM = 256; + + /// Field is an auto-incement field. + const AUTO_INCREMENT = 512; + + /// Field is a timestamp. + const TIMESTAMP = 1024; + + /// Field is a set. + const SET = 2048; + + /// Field does not have a default value. + const NO_DEFAULT_VALUE = 4096; + + /// Field is set to NOW on UPDATE. + const ON_UPDATE_NOW = 8192; + + /// Field is a number. + const NUM = 32768; + } +} + +// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type + +#[derive(Debug, Copy, Clone, PartialEq)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +#[repr(u8)] +pub enum ColumnType { + Decimal = 0x00, + Tiny = 0x01, + Short = 0x02, + Long = 0x03, + Float = 0x04, + Double = 0x05, + Null = 0x06, + Timestamp = 0x07, + LongLong = 0x08, + Int24 = 0x09, + Date = 0x0a, + Time = 0x0b, + Datetime = 0x0c, + Year = 0x0d, + VarChar = 0x0f, + Bit = 0x10, + Json = 0xf5, + NewDecimal = 0xf6, + Enum = 0xf7, + Set = 0xf8, + TinyBlob = 0xf9, + MediumBlob = 0xfa, + LongBlob = 0xfb, + Blob = 0xfc, + VarString = 0xfd, + String = 0xfe, + Geometry = 0xff, +} + +// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_query_response_text_resultset_column_definition.html +// https://mariadb.com/kb/en/resultset/#column-definition-packet +// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 + +#[derive(Debug)] +pub(crate) struct ColumnDefinition { + catalog: Bytes, + schema: Bytes, + table_alias: Bytes, + table: Bytes, + alias: Bytes, + name: Bytes, + pub(crate) char_set: u16, + max_size: u32, + pub(crate) r#type: ColumnType, + pub(crate) flags: ColumnFlags, + decimals: u8, +} + +impl ColumnDefinition { + // NOTE: strings in-protocol are transmitted according to the client character set + // as this is UTF-8, all these strings should be UTF-8 + + pub(crate) fn name(&self) -> Result<&str, Error> { + from_utf8(&self.name).map_err(Error::protocol) + } + + pub(crate) fn alias(&self) -> Result<&str, Error> { + from_utf8(&self.alias).map_err(Error::protocol) + } +} + +impl Decode<'_, Capabilities> for ColumnDefinition { + fn decode_with(mut buf: Bytes, _: Capabilities) -> Result { + let catalog = buf.get_bytes_lenenc(); + let schema = buf.get_bytes_lenenc(); + let table_alias = buf.get_bytes_lenenc(); + let table = buf.get_bytes_lenenc(); + let alias = buf.get_bytes_lenenc(); + let name = buf.get_bytes_lenenc(); + let _next_len = buf.get_uint_lenenc(); // always 0x0c + let char_set = buf.get_u16_le(); + let max_size = buf.get_u32_le(); + let type_id = buf.get_u8(); + let flags = buf.get_u16_le(); + let decimals = buf.get_u8(); + + Ok(Self { + catalog, + schema, + table_alias, + table, + alias, + name, + char_set, + max_size, + r#type: ColumnType::try_from_u16(type_id)?, + flags: ColumnFlags::from_bits_truncate(flags), + decimals, + }) + } +} + +impl ColumnType { + pub(crate) fn name(self, char_set: u16) -> &'static str { + let is_binary = char_set == 63; + match self { + ColumnType::Tiny => "TINYINT", + ColumnType::Short => "SMALLINT", + ColumnType::Long => "INT", + ColumnType::Float => "FLOAT", + ColumnType::Double => "DOUBLE", + ColumnType::Null => "NULL", + ColumnType::Timestamp => "TIMESTAMP", + ColumnType::LongLong => "BIGINT", + ColumnType::Int24 => "MEDIUMINT", + ColumnType::Date => "DATE", + ColumnType::Time => "TIME", + ColumnType::Datetime => "DATETIME", + ColumnType::Year => "YEAR", + ColumnType::Bit => "BIT", + ColumnType::Enum => "ENUM", + ColumnType::Set => "SET", + ColumnType::Decimal | ColumnType::NewDecimal => "DECIMAL", + ColumnType::Geometry => "GEOMETRY", + ColumnType::Json => "JSON", + + ColumnType::String if is_binary => "BINARY", + ColumnType::VarChar | ColumnType::VarString if is_binary => "VARBINARY", + + ColumnType::String => "CHAR", + ColumnType::VarChar | ColumnType::VarString => "VARCHAR", + + ColumnType::TinyBlob if is_binary => "TINYBLOB", + ColumnType::TinyBlob => "TINYTEXT", + + ColumnType::Blob if is_binary => "BLOB", + ColumnType::Blob => "TEXT", + + ColumnType::MediumBlob if is_binary => "MEDIUMBLOB", + ColumnType::MediumBlob => "MEDIUMTEXT", + + ColumnType::LongBlob if is_binary => "LONGBLOB", + ColumnType::LongBlob => "LONGTEXT", + } + } + + pub(crate) fn try_from_u16(id: u8) -> Result { + Ok(match id { + 0x00 => ColumnType::Decimal, + 0x01 => ColumnType::Tiny, + 0x02 => ColumnType::Short, + 0x03 => ColumnType::Long, + 0x04 => ColumnType::Float, + 0x05 => ColumnType::Double, + 0x06 => ColumnType::Null, + 0x07 => ColumnType::Timestamp, + 0x08 => ColumnType::LongLong, + 0x09 => ColumnType::Int24, + 0x0a => ColumnType::Date, + 0x0b => ColumnType::Time, + 0x0c => ColumnType::Datetime, + 0x0d => ColumnType::Year, + // [internal] 0x0e => ColumnType::NewDate, + 0x0f => ColumnType::VarChar, + 0x10 => ColumnType::Bit, + // [internal] 0x11 => ColumnType::Timestamp2, + // [internal] 0x12 => ColumnType::Datetime2, + // [internal] 0x13 => ColumnType::Time2, + 0xf5 => ColumnType::Json, + 0xf6 => ColumnType::NewDecimal, + 0xf7 => ColumnType::Enum, + 0xf8 => ColumnType::Set, + 0xf9 => ColumnType::TinyBlob, + 0xfa => ColumnType::MediumBlob, + 0xfb => ColumnType::LongBlob, + 0xfc => ColumnType::Blob, + 0xfd => ColumnType::VarString, + 0xfe => ColumnType::String, + 0xff => ColumnType::Geometry, + + _ => { + return Err(err_protocol!("unknown column type 0x{:02x}", id)); + } + }) + } +} diff --git a/sqlx-core/src/mysql/protocol/text/mod.rs b/sqlx-core/src/mysql/protocol/text/mod.rs new file mode 100644 index 000000000..2286ee890 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/text/mod.rs @@ -0,0 +1,11 @@ +mod column; +mod ping; +mod query; +mod quit; +mod row; + +pub(crate) use column::{ColumnDefinition, ColumnFlags, ColumnType}; +pub(crate) use ping::Ping; +pub(crate) use query::Query; +pub(crate) use quit::Quit; +pub(crate) use row::TextRow; diff --git a/sqlx-core/src/mysql/protocol/text/ping.rs b/sqlx-core/src/mysql/protocol/text/ping.rs new file mode 100644 index 000000000..ad3b08844 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/text/ping.rs @@ -0,0 +1,13 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-ping.html + +#[derive(Debug)] +pub(crate) struct Ping; + +impl Encode<'_, Capabilities> for Ping { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x0e); // COM_PING + } +} diff --git a/sqlx-core/src/mysql/protocol/text/query.rs b/sqlx-core/src/mysql/protocol/text/query.rs new file mode 100644 index 000000000..15edeb6ff --- /dev/null +++ b/sqlx-core/src/mysql/protocol/text/query.rs @@ -0,0 +1,14 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-query.html + +#[derive(Debug)] +pub(crate) struct Query<'q>(pub(crate) &'q str); + +impl Encode<'_, Capabilities> for Query<'_> { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x03); // COM_QUERY + buf.extend(self.0.as_bytes()) + } +} diff --git a/sqlx-core/src/mysql/protocol/text/quit.rs b/sqlx-core/src/mysql/protocol/text/quit.rs new file mode 100644 index 000000000..86a8c49a6 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/text/quit.rs @@ -0,0 +1,13 @@ +use crate::io::Encode; +use crate::mysql::protocol::Capabilities; + +// https://dev.mysql.com/doc/internals/en/com-quit.html + +#[derive(Debug)] +pub(crate) struct Quit; + +impl Encode<'_, Capabilities> for Quit { + fn encode_with(&self, buf: &mut Vec, _: Capabilities) { + buf.push(0x01); // COM_QUIT + } +} diff --git a/sqlx-core/src/mysql/protocol/text/row.rs b/sqlx-core/src/mysql/protocol/text/row.rs new file mode 100644 index 000000000..c1a1e1568 --- /dev/null +++ b/sqlx-core/src/mysql/protocol/text/row.rs @@ -0,0 +1,36 @@ +use bytes::{Buf, Bytes}; + +use crate::error::Error; +use crate::io::Decode; +use crate::mysql::io::MySqlBufExt; +use crate::mysql::protocol::Row; +use crate::mysql::row::MySqlColumn; + +#[derive(Debug)] +pub(crate) struct TextRow(pub(crate) Row); + +impl<'de> Decode<'de, &'de [MySqlColumn]> for TextRow { + fn decode_with(mut buf: Bytes, columns: &'de [MySqlColumn]) -> Result { + let storage = buf.clone(); + let offset = buf.len(); + + let mut values = Vec::with_capacity(columns.len()); + + for _ in columns { + if buf[0] == 0xfb { + // NULL is sent as 0xfb + values.push(None); + buf.advance(1); + } else { + let size = buf.get_uint_lenenc() as usize; + let offset = offset - buf.len(); + + values.push(Some(offset..(offset + size))); + + buf.advance(size); + } + } + + Ok(TextRow(Row { values, storage })) + } +} diff --git a/sqlx-core/src/mysql/protocol/type.rs b/sqlx-core/src/mysql/protocol/type.rs index 5b92a951d..8b1378917 100644 --- a/sqlx-core/src/mysql/protocol/type.rs +++ b/sqlx-core/src/mysql/protocol/type.rs @@ -1,48 +1 @@ -// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/binary__log__types_8h.html -// https://mariadb.com/kb/en/library/resultset/#field-types -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] -pub struct TypeId(pub u8); -// https://github.com/google/mysql/blob/c01fc2134d439282a21a2ddf687566e198ddee28/include/mysql_com.h#L429 -impl TypeId { - pub const NULL: TypeId = TypeId(6); - - // String: CHAR, VARCHAR, TEXT - // Bytes: BINARY, VARBINARY, BLOB - pub const CHAR: TypeId = TypeId(254); // or BINARY - pub const VAR_CHAR: TypeId = TypeId(253); // or VAR_BINARY - pub const TEXT: TypeId = TypeId(252); // or BLOB - - // Enum - pub const ENUM: TypeId = TypeId(247); - - // More Bytes - pub const TINY_BLOB: TypeId = TypeId(249); - pub const MEDIUM_BLOB: TypeId = TypeId(250); - pub const LONG_BLOB: TypeId = TypeId(251); - - // Numeric: TINYINT, SMALLINT, INT, BIGINT - pub const TINY_INT: TypeId = TypeId(1); - pub const SMALL_INT: TypeId = TypeId(2); - pub const INT: TypeId = TypeId(3); - pub const BIG_INT: TypeId = TypeId(8); - // pub const MEDIUM_INT: TypeId = TypeId(9); - - // Numeric: FLOAT, DOUBLE - pub const FLOAT: TypeId = TypeId(4); - pub const DOUBLE: TypeId = TypeId(5); - pub const NEWDECIMAL: TypeId = TypeId(246); - - // Date/Time: DATE, TIME, DATETIME, TIMESTAMP - pub const DATE: TypeId = TypeId(10); - pub const TIME: TypeId = TypeId(11); - pub const DATETIME: TypeId = TypeId(12); - pub const TIMESTAMP: TypeId = TypeId(7); -} - -impl Default for TypeId { - fn default() -> TypeId { - TypeId::NULL - } -} diff --git a/sqlx-core/src/mysql/row.rs b/sqlx-core/src/mysql/row.rs index 82cb141ce..5fe703cd6 100644 --- a/sqlx-core/src/mysql/row.rs +++ b/sqlx-core/src/mysql/row.rs @@ -1,59 +1,60 @@ -use std::collections::HashMap; use std::sync::Arc; -use crate::mysql::protocol; -use crate::mysql::{MySql, MySqlValue}; +use hashbrown::HashMap; + +use crate::error::Error; +use crate::ext::ustr::UStr; +use crate::mysql::{protocol, MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef}; use crate::row::{ColumnIndex, Row}; -pub struct MySqlRow<'c> { - pub(super) row: protocol::Row<'c>, - pub(super) names: Arc, u16>>, +// TODO: Merge with the other XXColumn types +#[derive(Debug, Clone)] +pub(crate) struct MySqlColumn { + pub(crate) name: Option, + pub(crate) type_info: Option, } -impl crate::row::private_row::Sealed for MySqlRow<'_> {} +/// Implementation of [`Row`] for MySQL. +#[derive(Debug)] +pub struct MySqlRow { + pub(crate) row: protocol::Row, + pub(crate) columns: Arc>, + pub(crate) column_names: Arc>, + pub(crate) format: MySqlValueFormat, +} -impl<'c> Row<'c> for MySqlRow<'c> { +impl crate::row::private_row::Sealed for MySqlRow {} + +impl Row for MySqlRow { type Database = MySql; + #[inline] fn len(&self) -> usize { self.row.len() } - #[doc(hidden)] - fn try_get_raw(&self, index: I) -> crate::Result> + fn try_get_raw(&self, index: I) -> Result where - I: ColumnIndex<'c, Self>, + I: ColumnIndex, { let index = index.index(self)?; - let column_ty = self.row.columns[index].clone(); - let buffer = self.row.get(index); - let value = match (self.row.binary, buffer) { - (_, None) => MySqlValue::null(), - (true, Some(buf)) => MySqlValue::binary(column_ty, buf), - (false, Some(buf)) => MySqlValue::text(column_ty, buf), - }; + let column = &self.columns[index]; + let value = self.row.get(index); - Ok(value) + Ok(MySqlValueRef { + format: self.format, + row: Some(&self.row.storage), + type_info: column.type_info.clone(), + value, + }) } } -impl<'c> ColumnIndex<'c, MySqlRow<'c>> for usize { - fn index(&self, row: &MySqlRow<'c>) -> crate::Result { - let len = Row::len(row); - - if *self >= len { - return Err(crate::Error::ColumnIndexOutOfBounds { len, index: *self }); - } - - Ok(*self) - } -} - -impl<'c> ColumnIndex<'c, MySqlRow<'c>> for str { - fn index(&self, row: &MySqlRow<'c>) -> crate::Result { - row.names - .get(self) - .ok_or_else(|| crate::Error::ColumnNotFound((*self).into())) - .map(|&index| index as usize) +impl ColumnIndex for &'_ str { + fn index(&self, row: &MySqlRow) -> Result { + row.column_names + .get(*self) + .ok_or_else(|| Error::ColumnNotFound((*self).into())) + .map(|v| *v) } } diff --git a/sqlx-core/src/mysql/stream.rs b/sqlx-core/src/mysql/stream.rs deleted file mode 100644 index 5e0b4d13a..000000000 --- a/sqlx-core/src/mysql/stream.rs +++ /dev/null @@ -1,228 +0,0 @@ -use std::net::Shutdown; - -use byteorder::{ByteOrder, LittleEndian}; - -use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream}; -use crate::mysql::protocol::{Capabilities, Encode, EofPacket, ErrPacket, OkPacket}; - -use crate::mysql::MySqlError; -use crate::url::Url; - -// Size before a packet is split -const MAX_PACKET_SIZE: u32 = 1024; - -pub(crate) struct MySqlStream { - pub(super) stream: BufStream, - - // Is the stream ready to send commands - // Put another way, are we still expecting an EOF or OK packet to terminate - pub(super) is_ready: bool, - - // Active capabilities - pub(super) capabilities: Capabilities, - - // Packets in a command sequence have an incrementing sequence number - // This number must be 0 at the start of each command - pub(super) seq_no: u8, - - // Packets are buffered into a second buffer from the stream - // as we may have compressed or split packets to figure out before - // decoding - packet_buf: Vec, - packet_len: usize, -} - -impl MySqlStream { - pub(super) async fn new(url: &Url) -> crate::Result { - let host = url.host().unwrap_or("localhost"); - let port = url.port(3306); - let stream = MaybeTlsStream::connect(host, port).await?; - - let mut capabilities = Capabilities::PROTOCOL_41 - | Capabilities::IGNORE_SPACE - | Capabilities::DEPRECATE_EOF - | Capabilities::FOUND_ROWS - | Capabilities::TRANSACTIONS - | Capabilities::SECURE_CONNECTION - | Capabilities::PLUGIN_AUTH_LENENC_DATA - | Capabilities::MULTI_STATEMENTS - | Capabilities::MULTI_RESULTS - | Capabilities::PLUGIN_AUTH; - - if url.database().is_some() { - capabilities |= Capabilities::CONNECT_WITH_DB; - } - - if cfg!(feature = "tls") { - capabilities |= Capabilities::SSL; - } - - Ok(Self { - capabilities, - stream: BufStream::new(stream), - packet_buf: Vec::with_capacity(MAX_PACKET_SIZE as usize), - packet_len: 0, - seq_no: 0, - is_ready: true, - }) - } - - pub(super) fn is_tls(&self) -> bool { - self.stream.is_tls() - } - - pub(super) fn shutdown(&self) -> crate::Result<()> { - Ok(self.stream.shutdown(Shutdown::Both)?) - } - - #[inline] - pub(super) async fn send(&mut self, packet: T, initial: bool) -> crate::Result<()> - where - T: Encode + std::fmt::Debug, - { - if initial { - self.seq_no = 0; - } - - self.write(packet); - self.flush().await - } - - #[inline] - pub(super) async fn flush(&mut self) -> crate::Result<()> { - Ok(self.stream.flush().await?) - } - - /// Write the packet to the buffered stream ( do not send to the server ) - pub(super) fn write(&mut self, packet: T) - where - T: Encode, - { - let buf = self.stream.buffer_mut(); - - // Allocate room for the header that we write after the packet; - // so, we can get an accurate and cheap measure of packet length - - let header_offset = buf.len(); - buf.advance(4); - - packet.encode(buf, self.capabilities); - - // Determine length of encoded packet - // and write to allocated header - - let len = buf.len() - header_offset - 4; - let mut header = &mut buf[header_offset..]; - - LittleEndian::write_u32(&mut header, len as u32); - - // Take the last sequence number received, if any, and increment by 1 - // If there was no sequence number, we only increment if we split packets - header[3] = self.seq_no; - self.seq_no = self.seq_no.wrapping_add(1); - } - - #[inline] - pub(super) async fn receive(&mut self) -> crate::Result<&[u8]> { - self.read().await?; - - Ok(self.packet()) - } - - pub(super) async fn read(&mut self) -> crate::Result<()> { - self.packet_buf.clear(); - self.packet_len = 0; - - // Read the packet header which contains the length and the sequence number - // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html - // https://mariadb.com/kb/en/library/0-packet/#standard-packet - let mut header = self.stream.peek(4_usize).await?; - - self.packet_len = header.get_uint::(3)? as usize; - self.seq_no = header.get_u8()?.wrapping_add(1); - - self.stream.consume(4); - - // Read the packet body and copy it into our internal buf - // We must have a separate buffer around the stream as we can't operate directly - // on bytes returned from the stream. We have various kinds of payload manipulation - // that must be handled before decoding. - let payload = self.stream.peek(self.packet_len).await?; - - self.packet_buf.reserve(payload.len()); - self.packet_buf.extend_from_slice(payload); - - self.stream.consume(self.packet_len); - - // TODO: Implement packet compression - // TODO: Implement packet joining - - Ok(()) - } - - /// Returns a reference to the most recently received packet data. - /// A call to `read` invalidates this buffer. - #[inline] - pub(super) fn packet(&self) -> &[u8] { - &self.packet_buf[..self.packet_len] - } -} - -impl MySqlStream { - pub(crate) async fn maybe_receive_eof(&mut self) -> crate::Result<()> { - if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) { - let _eof = EofPacket::read(self.receive().await?)?; - } - - Ok(()) - } - - pub(crate) fn maybe_handle_eof(&mut self) -> crate::Result> { - if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) && self.packet()[0] == 0xFE { - Ok(Some(EofPacket::read(self.packet())?)) - } else { - Ok(None) - } - } - - pub(crate) fn handle_unexpected(&mut self) -> crate::Result { - Err(protocol_err!("unexpected packet identifier 0x{:X?}", self.packet()[0]).into()) - } - - pub(crate) fn handle_err(&mut self) -> crate::Result { - self.is_ready = true; - Err(MySqlError(ErrPacket::read(self.packet(), self.capabilities)?).into()) - } - - pub(crate) fn handle_ok(&mut self) -> crate::Result { - self.is_ready = true; - OkPacket::read(self.packet()) - } - - pub(crate) async fn wait_until_ready(&mut self) -> crate::Result<()> { - if !self.is_ready { - loop { - let packet_id = self.receive().await?[0]; - match packet_id { - 0xFE if self.packet().len() < 0xFF_FF_FF => { - // OK or EOF packet - self.is_ready = true; - break; - } - - 0xFF => { - // ERR packet - self.is_ready = true; - return self.handle_err(); - } - - _ => { - // Something else; skip - } - } - } - } - - Ok(()) - } -} diff --git a/sqlx-core/src/mysql/tls.rs b/sqlx-core/src/mysql/tls.rs deleted file mode 100644 index 49a218498..000000000 --- a/sqlx-core/src/mysql/tls.rs +++ /dev/null @@ -1,123 +0,0 @@ -use crate::mysql::stream::MySqlStream; -use crate::url::Url; - -#[cfg_attr(not(feature = "tls"), allow(unused_variables))] -pub(super) async fn upgrade_if_needed(stream: &mut MySqlStream, url: &Url) -> crate::Result<()> { - #[cfg_attr(not(feature = "tls"), allow(unused_imports))] - use crate::mysql::protocol::Capabilities; - - let ca_file = url.param("ssl-ca"); - let ssl_mode = url.param("ssl-mode"); - - // https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#option_general_ssl-mode - match ssl_mode.as_deref() { - Some("DISABLED") => {} - - #[cfg(feature = "tls")] - Some("PREFERRED") | None if !stream.capabilities.contains(Capabilities::SSL) => {} - - #[cfg(feature = "tls")] - Some("PREFERRED") => { - if let Err(_error) = try_upgrade(stream, &url, None, true).await { - // TLS upgrade failed; fall back to a normal connection - } - } - - #[cfg(feature = "tls")] - None => { - if let Err(_error) = try_upgrade(stream, &url, ca_file.as_deref(), true).await { - // TLS upgrade failed; fall back to a normal connection - } - } - - #[cfg(feature = "tls")] - Some("REQUIRED") | Some("VERIFY_CA") | Some("VERIFY_IDENTITY") - if !stream.capabilities.contains(Capabilities::SSL) => - { - return Err(tls_err!("server does not support TLS").into()); - } - - #[cfg(feature = "tls")] - Some("VERIFY_CA") | Some("VERIFY_IDENTITY") if ca_file.is_none() => { - return Err( - tls_err!("`ssl-mode` of {:?} requires `ssl-ca` to be set", ssl_mode).into(), - ); - } - - #[cfg(feature = "tls")] - Some(mode @ "REQUIRED") | Some(mode @ "VERIFY_CA") | Some(mode @ "VERIFY_IDENTITY") => { - try_upgrade( - stream, - url, - // false for both verify-ca and verify-full - ca_file.as_deref(), - // false for only verify-full - mode != "VERIFY_IDENTITY", - ) - .await?; - } - - #[cfg(not(feature = "tls"))] - None => { - // The user neither explicitly enabled TLS in the connection string - // nor did they turn the `tls` feature on - } - - #[cfg(not(feature = "tls"))] - Some(mode @ "PREFERRED") - | Some(mode @ "REQUIRED") - | Some(mode @ "VERIFY_CA") - | Some(mode @ "VERIFY_IDENTITY") => { - return Err(tls_err!( - "ssl-mode {:?} unsupported; SQLx was compiled without `tls` feature", - mode - ) - .into()); - } - - Some(mode) => { - return Err(tls_err!("unknown `ssl-mode` value: {:?}", mode).into()); - } - } - - Ok(()) -} - -#[cfg(feature = "tls")] -async fn try_upgrade( - stream: &mut MySqlStream, - url: &Url, - ca_file: Option<&str>, - accept_invalid_hostnames: bool, -) -> crate::Result<()> { - use crate::mysql::protocol::SslRequest; - use crate::runtime::fs; - - use async_native_tls::{Certificate, TlsConnector}; - - let mut connector = TlsConnector::new() - .danger_accept_invalid_certs(ca_file.is_none()) - .danger_accept_invalid_hostnames(accept_invalid_hostnames); - - if let Some(ca_file) = ca_file { - let root_cert = fs::read(ca_file).await?; - - connector = connector.add_root_certificate(Certificate::from_pem(&root_cert)?); - } - - // send upgrade request and then immediately try TLS handshake - stream - .send( - SslRequest { - client_collation: super::connection::COLLATE_UTF8MB4_UNICODE_CI, - max_packet_size: super::connection::MAX_PACKET_SIZE, - }, - false, - ) - .await?; - - stream - .stream - .upgrade(url.host().unwrap_or("localhost"), connector) - .await -} diff --git a/sqlx-core/src/mysql/type_info.rs b/sqlx-core/src/mysql/type_info.rs index 10c8ba0c4..d3f882aa2 100644 --- a/sqlx-core/src/mysql/type_info.rs +++ b/sqlx-core/src/mysql/type_info.rs @@ -1,144 +1,67 @@ -use std::fmt::{self, Display}; +use std::fmt::{self, Display, Formatter}; -use crate::mysql::protocol::{ColumnDefinition, FieldFlags, TypeId}; -use crate::types::TypeInfo; +use crate::mysql::protocol::text::{ColumnDefinition, ColumnFlags, ColumnType}; +use crate::type_info::TypeInfo; -#[derive(Clone, Debug, Default)] +/// Type information for a MySql type. +#[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct MySqlTypeInfo { - pub(crate) id: TypeId, - pub(crate) is_unsigned: bool, - pub(crate) is_binary: bool, + pub(crate) r#type: ColumnType, + pub(crate) flags: ColumnFlags, pub(crate) char_set: u16, } impl MySqlTypeInfo { - pub(crate) const fn new(id: TypeId) -> Self { + pub(crate) const fn binary(ty: ColumnType) -> Self { Self { - id, - is_unsigned: false, - is_binary: true, - char_set: 0, + r#type: ty, + flags: ColumnFlags::BINARY, + char_set: 63, } } - pub(crate) const fn unsigned(id: TypeId) -> Self { - Self { - id, - is_unsigned: true, - is_binary: false, - char_set: 0, - } - } - - #[doc(hidden)] - pub const fn r#enum() -> Self { - Self { - id: TypeId::ENUM, - is_unsigned: false, - is_binary: false, - char_set: 0, - } - } - - pub(crate) fn from_nullable_column_def(def: &ColumnDefinition) -> Self { - Self { - id: def.type_id, - is_unsigned: def.flags.contains(FieldFlags::UNSIGNED), - is_binary: def.flags.contains(FieldFlags::BINARY), - char_set: def.char_set, - } - } - - pub(crate) fn from_column_def(def: &ColumnDefinition) -> Option { - if def.type_id == TypeId::NULL { - return None; - } - - Some(Self::from_nullable_column_def(def)) - } - - #[doc(hidden)] - pub fn type_feature_gate(&self) -> Option<&'static str> { - match self.id { - TypeId::DATE | TypeId::TIME | TypeId::DATETIME | TypeId::TIMESTAMP => Some("chrono"), - _ => None, + pub(crate) fn from_column(column: &ColumnDefinition) -> Option { + if column.r#type == ColumnType::Null { + None + } else { + Some(Self { + r#type: column.r#type, + flags: column.flags, + char_set: column.char_set, + }) } } } impl Display for MySqlTypeInfo { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.id { - TypeId::NULL => f.write_str("NULL"), + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(self.r#type.name(self.char_set))?; - TypeId::TINY_INT if self.is_unsigned => f.write_str("TINYINT UNSIGNED"), - TypeId::SMALL_INT if self.is_unsigned => f.write_str("SMALLINT UNSIGNED"), - TypeId::INT if self.is_unsigned => f.write_str("INT UNSIGNED"), - TypeId::BIG_INT if self.is_unsigned => f.write_str("BIGINT UNSIGNED"), - - TypeId::TINY_INT => f.write_str("TINYINT"), - TypeId::SMALL_INT => f.write_str("SMALLINT"), - TypeId::INT => f.write_str("INT"), - TypeId::BIG_INT => f.write_str("BIGINT"), - - TypeId::FLOAT => f.write_str("FLOAT"), - TypeId::DOUBLE => f.write_str("DOUBLE"), - - TypeId::CHAR if self.is_binary => f.write_str("BINARY"), - TypeId::VAR_CHAR if self.is_binary => f.write_str("VARBINARY"), - TypeId::TEXT if self.is_binary => f.write_str("BLOB"), - - TypeId::CHAR => f.write_str("CHAR"), - TypeId::VAR_CHAR => f.write_str("VARCHAR"), - TypeId::TEXT => f.write_str("TEXT"), - - TypeId::DATE => f.write_str("DATE"), - TypeId::TIME => f.write_str("TIME"), - TypeId::DATETIME => f.write_str("DATETIME"), - TypeId::TIMESTAMP => f.write_str("TIMESTAMP"), - - id => write!(f, "<{:#x}>", id.0), + if self.flags.contains(ColumnFlags::UNSIGNED) { + f.write_str(" UNSIGNED")?; } + + Ok(()) } } +impl TypeInfo for MySqlTypeInfo {} + impl PartialEq for MySqlTypeInfo { fn eq(&self, other: &MySqlTypeInfo) -> bool { - match self.id { - TypeId::VAR_CHAR - | TypeId::TEXT - | TypeId::CHAR - | TypeId::TINY_BLOB - | TypeId::MEDIUM_BLOB - | TypeId::LONG_BLOB - | TypeId::ENUM - if (self.is_binary == other.is_binary) - && match other.id { - TypeId::VAR_CHAR - | TypeId::TEXT - | TypeId::CHAR - | TypeId::TINY_BLOB - | TypeId::MEDIUM_BLOB - | TypeId::LONG_BLOB - | TypeId::ENUM => true, - - _ => false, - } => - { - return true; - } - - _ => {} - } - - if self.id.0 != other.id.0 { + if self.r#type != other.r#type { return false; } - match self.id { - TypeId::TINY_INT | TypeId::SMALL_INT | TypeId::INT | TypeId::BIG_INT => { - return self.is_unsigned == other.is_unsigned; + match self.r#type { + ColumnType::Tiny + | ColumnType::Short + | ColumnType::Long + | ColumnType::Int24 + | ColumnType::LongLong => { + return self.flags.contains(ColumnFlags::UNSIGNED) + == other.flags.contains(ColumnFlags::UNSIGNED); } _ => {} @@ -148,103 +71,4 @@ impl PartialEq for MySqlTypeInfo { } } -impl TypeInfo for MySqlTypeInfo { - fn compatible(&self, other: &Self) -> bool { - // NOTE: MySQL is weakly typed so much of this may be surprising to a Rust developer. - - if self.id == TypeId::NULL || other.id == TypeId::NULL { - // NULL is the "bottom" type - // If the user is trying to select into a non-Option, we catch this soon with an - // UnexpectedNull error message - return true; - } - - match self.id { - // All integer types should be considered compatible - TypeId::TINY_INT | TypeId::SMALL_INT | TypeId::INT | TypeId::BIG_INT - if (self.is_unsigned == other.is_unsigned) - && match other.id { - TypeId::TINY_INT | TypeId::SMALL_INT | TypeId::INT | TypeId::BIG_INT => { - true - } - - _ => false, - } => - { - true - } - - // All textual types should be considered compatible - TypeId::VAR_CHAR - | TypeId::TEXT - | TypeId::CHAR - | TypeId::TINY_BLOB - | TypeId::MEDIUM_BLOB - | TypeId::LONG_BLOB - if match other.id { - TypeId::VAR_CHAR - | TypeId::TEXT - | TypeId::CHAR - | TypeId::TINY_BLOB - | TypeId::MEDIUM_BLOB - | TypeId::LONG_BLOB => true, - - _ => false, - } => - { - true - } - - // Enums are considered compatible with other text/binary types - TypeId::ENUM - if match other.id { - TypeId::VAR_CHAR - | TypeId::TEXT - | TypeId::CHAR - | TypeId::TINY_BLOB - | TypeId::MEDIUM_BLOB - | TypeId::LONG_BLOB - | TypeId::ENUM => true, - - _ => false, - } => - { - true - } - - TypeId::VAR_CHAR - | TypeId::TEXT - | TypeId::CHAR - | TypeId::TINY_BLOB - | TypeId::MEDIUM_BLOB - | TypeId::LONG_BLOB - | TypeId::ENUM - if other.id == TypeId::ENUM => - { - true - } - - // FLOAT is compatible with DOUBLE - TypeId::FLOAT | TypeId::DOUBLE - if match other.id { - TypeId::FLOAT | TypeId::DOUBLE => true, - _ => false, - } => - { - true - } - - // DATETIME is compatible with TIMESTAMP - TypeId::DATETIME | TypeId::TIMESTAMP - if match other.id { - TypeId::DATETIME | TypeId::TIMESTAMP => true, - _ => false, - } => - { - true - } - - _ => self.eq(other), - } - } -} +impl Eq for MySqlTypeInfo {} diff --git a/sqlx-core/src/mysql/types/bigdecimal.rs b/sqlx-core/src/mysql/types/bigdecimal.rs index 5cdcc2136..cef2baf7e 100644 --- a/sqlx-core/src/mysql/types/bigdecimal.rs +++ b/sqlx-core/src/mysql/types/bigdecimal.rs @@ -1,92 +1,30 @@ use bigdecimal::BigDecimal; +use crate::database::{Database, HasArguments}; use crate::decode::Decode; -use crate::encode::Encode; -use crate::io::Buf; -use crate::mysql::protocol::TypeId; -use crate::mysql::{MySql, MySqlData, MySqlTypeInfo, MySqlValue}; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::mysql::io::MySqlBufMutExt; +use crate::mysql::protocol::text::{ColumnFlags, ColumnType}; +use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef}; use crate::types::Type; -use crate::Error; -use std::str::FromStr; impl Type for BigDecimal { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::NEWDECIMAL) + MySqlTypeInfo::binary(ColumnType::NewDecimal) } } -impl Encode for BigDecimal { - fn encode(&self, buf: &mut Vec) { - let size = Encode::::size_hint(self) - 1; - assert!(size <= std::u8::MAX as usize, "Too large size"); - buf.push(size as u8); - let s = self.to_string(); - buf.extend_from_slice(s.as_bytes()); - } +impl Encode<'_, MySql> for BigDecimal { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.put_str_lenenc(&self.to_string()); - fn size_hint(&self) -> usize { - let s = self.to_string(); - s.as_bytes().len() + 1 + IsNull::No } } impl Decode<'_, MySql> for BigDecimal { - fn decode(value: MySqlValue) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(mut binary) => { - let _len = binary.get_u8()?; - let s = std::str::from_utf8(binary).map_err(Error::decode)?; - Ok(BigDecimal::from_str(s).map_err(Error::decode)?) - } - MySqlData::Text(s) => { - let s = std::str::from_utf8(s).map_err(Error::decode)?; - Ok(BigDecimal::from_str(s).map_err(Error::decode)?) - } - } + fn decode(value: MySqlValueRef<'_>) -> Result { + Ok(value.as_str()?.parse()?) } } - -#[test] -fn test_encode_decimal() { - let v: BigDecimal = BigDecimal::from_str("-1.05").unwrap(); - let mut buf: Vec = vec![]; - >::encode(&v, &mut buf); - assert_eq!(buf, vec![0x05, b'-', b'1', b'.', b'0', b'5']); - - let v: BigDecimal = BigDecimal::from_str("-105000").unwrap(); - let mut buf: Vec = vec![]; - >::encode(&v, &mut buf); - assert_eq!(buf, vec![0x07, b'-', b'1', b'0', b'5', b'0', b'0', b'0']); - - let v: BigDecimal = BigDecimal::from_str("0.00105").unwrap(); - let mut buf: Vec = vec![]; - >::encode(&v, &mut buf); - assert_eq!(buf, vec![0x07, b'0', b'.', b'0', b'0', b'1', b'0', b'5']); -} - -#[test] -fn test_decode_decimal() { - let buf: Vec = vec![0x05, b'-', b'1', b'.', b'0', b'5']; - let v = >::decode(MySqlValue::binary( - MySqlTypeInfo::new(TypeId::NEWDECIMAL), - buf.as_slice(), - )) - .unwrap(); - assert_eq!(v.to_string(), "-1.05"); - - let buf: Vec = vec![0x04, b'0', b'.', b'0', b'5']; - let v = >::decode(MySqlValue::binary( - MySqlTypeInfo::new(TypeId::NEWDECIMAL), - buf.as_slice(), - )) - .unwrap(); - assert_eq!(v.to_string(), "0.05"); - - let buf: Vec = vec![0x06, b'-', b'9', b'0', b'0', b'0', b'0']; - let v = >::decode(MySqlValue::binary( - MySqlTypeInfo::new(TypeId::NEWDECIMAL), - buf.as_slice(), - )) - .unwrap(); - assert_eq!(v.to_string(), "-90000"); -} diff --git a/sqlx-core/src/mysql/types/bool.rs b/sqlx-core/src/mysql/types/bool.rs index 56d8657eb..d47a121a0 100644 --- a/sqlx-core/src/mysql/types/bool.rs +++ b/sqlx-core/src/mysql/types/bool.rs @@ -1,34 +1,28 @@ use crate::decode::Decode; -use crate::encode::Encode; -use crate::mysql::protocol::TypeId; -use crate::mysql::type_info::MySqlTypeInfo; -use crate::mysql::{MySql, MySqlData, MySqlValue}; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef}; use crate::types::Type; impl Type for bool { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::TINY_INT) + // MySQL has no actual `BOOLEAN` type, the type is an alias of `TINYINT(1)` + >::type_info() } } -impl Encode for bool { - fn encode(&self, buf: &mut Vec) { - buf.push(*self as u8); +impl Encode<'_, MySql> for bool { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + >::encode(*self as i8, buf) } } -impl<'de> Decode<'de, MySql> for bool { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(buf) => Ok(buf.get(0).map(|&b| b != 0).unwrap_or_default()), +impl Decode<'_, MySql> for bool { + fn accepts(ty: &MySqlTypeInfo) -> bool { + >::accepts(ty) + } - MySqlData::Text(b"0") => Ok(false), - - MySqlData::Text(b"1") => Ok(true), - - MySqlData::Text(s) => Err(crate::Error::Decode( - format!("unexpected value {:?} for boolean", s).into(), - )), - } + fn decode(value: MySqlValueRef<'_>) -> Result { + Ok(>::decode(value)? != 0) } } diff --git a/sqlx-core/src/mysql/types/bytes.rs b/sqlx-core/src/mysql/types/bytes.rs index 850b1348f..0988f4660 100644 --- a/sqlx-core/src/mysql/types/bytes.rs +++ b/sqlx-core/src/mysql/types/bytes.rs @@ -1,21 +1,42 @@ -use byteorder::LittleEndian; - use crate::decode::Decode; -use crate::encode::Encode; -use crate::mysql::io::BufMutExt; -use crate::mysql::protocol::TypeId; -use crate::mysql::type_info::MySqlTypeInfo; -use crate::mysql::{MySql, MySqlData, MySqlValue}; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::mysql::io::MySqlBufMutExt; +use crate::mysql::protocol::text::ColumnType; +use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef}; use crate::types::Type; impl Type for [u8] { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo { - id: TypeId::TEXT, - is_binary: true, - is_unsigned: false, - char_set: 63, // binary - } + MySqlTypeInfo::binary(ColumnType::Blob) + } +} + +impl Encode<'_, MySql> for &'_ [u8] { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.put_bytes_lenenc(self); + + IsNull::No + } +} + +impl<'r> Decode<'r, MySql> for &'r [u8] { + fn accepts(ty: &MySqlTypeInfo) -> bool { + matches!( + ty.r#type, + ColumnType::VarChar + | ColumnType::Blob + | ColumnType::TinyBlob + | ColumnType::MediumBlob + | ColumnType::LongBlob + | ColumnType::String + | ColumnType::VarString + | ColumnType::Enum + ) + } + + fn decode(value: MySqlValueRef<'r>) -> Result { + value.as_bytes() } } @@ -25,30 +46,18 @@ impl Type for Vec { } } -impl Encode for [u8] { - fn encode(&self, buf: &mut Vec) { - buf.put_bytes_lenenc::(self); +impl Encode<'_, MySql> for Vec { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + <&[u8] as Encode>::encode(&**self, buf) } } -impl Encode for Vec { - fn encode(&self, buf: &mut Vec) { - <[u8] as Encode>::encode(self, buf); +impl Decode<'_, MySql> for Vec { + fn accepts(ty: &MySqlTypeInfo) -> bool { + <&[u8] as Decode>::accepts(ty) } -} -impl<'de> Decode<'de, MySql> for Vec { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(buf) | MySqlData::Text(buf) => Ok(buf.to_vec()), - } - } -} - -impl<'de> Decode<'de, MySql> for &'de [u8] { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(buf) | MySqlData::Text(buf) => Ok(buf), - } + fn decode(value: MySqlValueRef<'_>) -> Result { + <&[u8] as Decode>::decode(value).map(ToOwned::to_owned) } } diff --git a/sqlx-core/src/mysql/types/chrono.rs b/sqlx-core/src/mysql/types/chrono.rs index 490219026..1c77d4fb9 100644 --- a/sqlx-core/src/mysql/types/chrono.rs +++ b/sqlx-core/src/mysql/types/chrono.rs @@ -1,32 +1,35 @@ use std::convert::TryFrom; +use std::str::from_utf8; -use byteorder::{ByteOrder, LittleEndian}; +use bytes::Buf; use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; use crate::decode::Decode; -use crate::encode::Encode; -use crate::io::{Buf, BufMut}; -use crate::mysql::protocol::TypeId; +use crate::encode::{Encode, IsNull}; +use crate::error::{BoxDynError, Error}; +use crate::mysql::protocol::text::ColumnType; use crate::mysql::type_info::MySqlTypeInfo; -use crate::mysql::{MySql, MySqlData, MySqlValue}; +use crate::mysql::{MySql, MySqlValue, MySqlValueFormat, MySqlValueRef}; use crate::types::Type; -use crate::Error; -use std::str::from_utf8; impl Type for DateTime { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::TIMESTAMP) + MySqlTypeInfo::binary(ColumnType::Timestamp) } } -impl Encode for DateTime { - fn encode(&self, buf: &mut Vec) { - Encode::::encode(&self.naive_utc(), buf); +impl Encode<'_, MySql> for DateTime { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + Encode::::encode(&self.naive_utc(), buf) } } -impl<'de> Decode<'de, MySql> for DateTime { - fn decode(value: MySqlValue<'de>) -> crate::Result { +impl<'r> Decode<'r, MySql> for DateTime { + fn accepts(ty: &MySqlTypeInfo) -> bool { + matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) + } + + fn decode(value: MySqlValueRef<'r>) -> Result { let naive: NaiveDateTime = Decode::::decode(value)?; Ok(DateTime::from_utc(naive, Utc)) @@ -35,12 +38,12 @@ impl<'de> Decode<'de, MySql> for DateTime { impl Type for NaiveTime { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::TIME) + MySqlTypeInfo::binary(ColumnType::Time) } } -impl Encode for NaiveTime { - fn encode(&self, buf: &mut Vec) { +impl Encode<'_, MySql> for NaiveTime { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { let len = Encode::::size_hint(self) - 1; buf.push(len as u8); @@ -49,9 +52,11 @@ impl Encode for NaiveTime { // "date on 4 bytes little-endian format" (?) // https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding - buf.advance(4); + buf.extend_from_slice(&[0_u8; 4]); encode_time(self, len > 9, buf); + + IsNull::No } fn size_hint(&self) -> usize { @@ -65,27 +70,29 @@ impl Encode for NaiveTime { } } -impl<'de> Decode<'de, MySql> for NaiveTime { - fn decode(buf: MySqlValue<'de>) -> crate::Result { - match buf.try_get()? { - MySqlData::Binary(mut buf) => { +impl<'r> Decode<'r, MySql> for NaiveTime { + fn decode(value: MySqlValueRef<'r>) -> Result { + match value.format() { + MySqlValueFormat::Binary => { + let mut buf = value.as_bytes()?; + // data length, expecting 8 or 12 (fractional seconds) - let len = buf.get_u8()?; + let len = buf.get_u8(); // is negative : int<1> - let is_negative = buf.get_u8()?; - assert_eq!(is_negative, 0, "Negative dates/times are not supported"); + let is_negative = buf.get_u8(); + debug_assert_eq!(is_negative, 0, "Negative dates/times are not supported"); // "date on 4 bytes little-endian format" (?) // https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding buf.advance(4); - decode_time(len - 5, buf) + Ok(decode_time(len - 5, buf)) } - MySqlData::Text(buf) => { - let s = from_utf8(buf).map_err(Error::decode)?; - NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Error::decode) + MySqlValueFormat::Text => { + let s = value.as_str()?; + NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Into::into) } } } @@ -93,15 +100,17 @@ impl<'de> Decode<'de, MySql> for NaiveTime { impl Type for NaiveDate { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::DATE) + MySqlTypeInfo::binary(ColumnType::Date) } } -impl Encode for NaiveDate { - fn encode(&self, buf: &mut Vec) { +impl Encode<'_, MySql> for NaiveDate { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { buf.push(4); encode_date(self, buf); + + IsNull::No } fn size_hint(&self) -> usize { @@ -109,14 +118,14 @@ impl Encode for NaiveDate { } } -impl<'de> Decode<'de, MySql> for NaiveDate { - fn decode(buf: MySqlValue<'de>) -> crate::Result { - match buf.try_get()? { - MySqlData::Binary(buf) => Ok(decode_date(&buf[1..])), +impl<'r> Decode<'r, MySql> for NaiveDate { + fn decode(value: MySqlValueRef<'r>) -> Result { + match value.format() { + MySqlValueFormat::Binary => Ok(decode_date(&value.as_bytes()?[1..])), - MySqlData::Text(buf) => { - let s = from_utf8(buf).map_err(Error::decode)?; - NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Error::decode) + MySqlValueFormat::Text => { + let s = value.as_str()?; + NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Into::into) } } } @@ -124,12 +133,12 @@ impl<'de> Decode<'de, MySql> for NaiveDate { impl Type for NaiveDateTime { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::DATETIME) + MySqlTypeInfo::binary(ColumnType::Datetime) } } -impl Encode for NaiveDateTime { - fn encode(&self, buf: &mut Vec) { +impl Encode<'_, MySql> for NaiveDateTime { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { let len = Encode::::size_hint(self) - 1; buf.push(len as u8); @@ -138,6 +147,8 @@ impl Encode for NaiveDateTime { if len > 4 { encode_time(&self.time(), len > 8, buf); } + + IsNull::No } fn size_hint(&self) -> usize { @@ -162,15 +173,21 @@ impl Encode for NaiveDateTime { } } -impl<'de> Decode<'de, MySql> for NaiveDateTime { - fn decode(buf: MySqlValue<'de>) -> crate::Result { - match buf.try_get()? { - MySqlData::Binary(buf) => { +impl<'r> Decode<'r, MySql> for NaiveDateTime { + fn accepts(ty: &MySqlTypeInfo) -> bool { + matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) + } + + fn decode(value: MySqlValueRef<'r>) -> Result { + match value.format() { + MySqlValueFormat::Binary => { + let mut buf = value.as_bytes()?; + let len = buf[0]; let date = decode_date(&buf[1..]); let dt = if len > 4 { - date.and_time(decode_time(len - 4, &buf[5..])?) + date.and_time(decode_time(len - 4, &buf[5..])) } else { date.and_hms(0, 0, 0) }; @@ -178,9 +195,9 @@ impl<'de> Decode<'de, MySql> for NaiveDateTime { Ok(dt) } - MySqlData::Text(buf) => { - let s = from_utf8(buf).map_err(Error::decode)?; - NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").map_err(Error::decode) + MySqlValueFormat::Text => { + let s = value.as_str()?; + NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").map_err(Into::into) } } } @@ -196,12 +213,9 @@ fn encode_date(date: &NaiveDate, buf: &mut Vec) { buf.push(date.day() as u8); } -fn decode_date(buf: &[u8]) -> NaiveDate { - NaiveDate::from_ymd( - LittleEndian::read_u16(buf) as i32, - buf[2] as u32, - buf[3] as u32, - ) +fn decode_date(mut buf: &[u8]) -> NaiveDate { + let year = buf.get_u16_le(); + NaiveDate::from_ymd(year as i32, buf[0] as u32, buf[1] as u32) } fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec) { @@ -210,93 +224,21 @@ fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec) { buf.push(time.second() as u8); if include_micros { - buf.put_u32::((time.nanosecond() / 1000) as u32); + buf.extend(&((time.nanosecond() / 1000) as u32).to_le_bytes()); } } -fn decode_time(len: u8, mut buf: &[u8]) -> crate::Result { - let hour = buf.get_u8()?; - let minute = buf.get_u8()?; - let seconds = buf.get_u8()?; +fn decode_time(len: u8, mut buf: &[u8]) -> NaiveTime { + let hour = buf.get_u8(); + let minute = buf.get_u8(); + let seconds = buf.get_u8(); let micros = if len > 3 { // microseconds : int - buf.get_uint::(buf.len())? + buf.get_uint_le(buf.len()) } else { 0 }; - Ok(NaiveTime::from_hms_micro( - hour as u32, - minute as u32, - seconds as u32, - micros as u32, - )) -} - -#[test] -fn test_encode_date_time() { - let mut buf = Vec::new(); - - // test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html - let date1: NaiveDateTime = "2010-10-17T19:27:30.000001".parse().unwrap(); - Encode::::encode(&date1, &mut buf); - assert_eq!(*buf, [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0]); - - buf.clear(); - - let date2: NaiveDateTime = "2010-10-17T19:27:30".parse().unwrap(); - Encode::::encode(&date2, &mut buf); - assert_eq!(*buf, [7, 218, 7, 10, 17, 19, 27, 30]); - - buf.clear(); - - let date3: NaiveDateTime = "2010-10-17T00:00:00".parse().unwrap(); - Encode::::encode(&date3, &mut buf); - assert_eq!(*buf, [4, 218, 7, 10, 17]); -} - -#[test] -fn test_decode_date_time() { - // test values from https://dev.mysql.com/doc/internals/en/binary-protocol-value.html - let buf = [11, 218, 7, 10, 17, 19, 27, 30, 1, 0, 0, 0]; - let date1 = >::decode(MySqlValue::binary( - MySqlTypeInfo::default(), - &buf, - )) - .unwrap(); - assert_eq!(date1.to_string(), "2010-10-17 19:27:30.000001"); - - let buf = [7, 218, 7, 10, 17, 19, 27, 30]; - let date2 = >::decode(MySqlValue::binary( - MySqlTypeInfo::default(), - &buf, - )) - .unwrap(); - assert_eq!(date2.to_string(), "2010-10-17 19:27:30"); - - let buf = [4, 218, 7, 10, 17]; - let date3 = >::decode(MySqlValue::binary( - MySqlTypeInfo::default(), - &buf, - )) - .unwrap(); - assert_eq!(date3.to_string(), "2010-10-17 00:00:00"); -} - -#[test] -fn test_encode_date() { - let mut buf = Vec::new(); - let date: NaiveDate = "2010-10-17".parse().unwrap(); - Encode::::encode(&date, &mut buf); - assert_eq!(*buf, [4, 218, 7, 10, 17]); -} - -#[test] -fn test_decode_date() { - let buf = [4, 218, 7, 10, 17]; - let date = - >::decode(MySqlValue::binary(MySqlTypeInfo::default(), &buf)) - .unwrap(); - assert_eq!(date.to_string(), "2010-10-17"); + NaiveTime::from_hms_micro(hour as u32, minute as u32, seconds as u32, micros as u32) } diff --git a/sqlx-core/src/mysql/types/float.rs b/sqlx-core/src/mysql/types/float.rs index cdf2d308a..990bf20ca 100644 --- a/sqlx-core/src/mysql/types/float.rs +++ b/sqlx-core/src/mysql/types/float.rs @@ -1,83 +1,77 @@ -use byteorder::{LittleEndian, ReadBytesExt}; +use byteorder::{ByteOrder, LittleEndian}; use crate::decode::Decode; -use crate::encode::Encode; -use crate::mysql::protocol::TypeId; -use crate::mysql::type_info::MySqlTypeInfo; -use crate::mysql::{MySql, MySqlData, MySqlValue}; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::mysql::protocol::text::ColumnType; +use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef}; use crate::types::Type; -use crate::Error; -use std::str::from_utf8; -/// The equivalent MySQL type for `f32` is `FLOAT`. -/// -/// ### Note -/// While we added support for `f32` as `FLOAT` for completeness, we don't recommend using -/// it for any real-life applications as it cannot precisely represent some fractional values, -/// and may be implicitly widened to `DOUBLE` in some cases, resulting in a slightly different -/// value: -/// -/// ```rust -/// // Widening changes the equivalent decimal value, these two expressions are not equal -/// // (This is expected behavior for floating points and happens both in Rust and in MySQL) -/// assert_ne!(10.2f32 as f64, 10.2f64); -/// ``` +fn real_accepts(ty: &MySqlTypeInfo) -> bool { + matches!(ty.r#type, ColumnType::Float | ColumnType::Double) +} + impl Type for f32 { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::FLOAT) + MySqlTypeInfo::binary(ColumnType::Float) } } -impl Encode for f32 { - fn encode(&self, buf: &mut Vec) { - >::encode(&(self.to_bits() as i32), buf); - } -} - -impl<'de> Decode<'de, MySql> for f32 { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(mut buf) => buf - .read_i32::() - .map_err(crate::Error::decode) - .map(|value| f32::from_bits(value as u32)), - - MySqlData::Text(s) => from_utf8(s) - .map_err(Error::decode)? - .parse() - .map_err(Error::decode), - } - } -} - -/// The equivalent MySQL type for `f64` is `DOUBLE`. -/// -/// Note that `DOUBLE` is a floating-point type and cannot represent some fractional values -/// exactly. impl Type for f64 { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::DOUBLE) + MySqlTypeInfo::binary(ColumnType::Double) } } -impl Encode for f64 { - fn encode(&self, buf: &mut Vec) { - >::encode(&(self.to_bits() as i64), buf); +impl Encode<'_, MySql> for f32 { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.extend(&self.to_le_bytes()); + + IsNull::No } } -impl<'de> Decode<'de, MySql> for f64 { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(mut buf) => buf - .read_i64::() - .map_err(crate::Error::decode) - .map(|value| f64::from_bits(value as u64)), +impl Encode<'_, MySql> for f64 { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.extend(&self.to_le_bytes()); - MySqlData::Text(s) => from_utf8(s) - .map_err(Error::decode)? - .parse() - .map_err(Error::decode), - } + IsNull::No + } +} + +impl Decode<'_, MySql> for f32 { + fn accepts(ty: &MySqlTypeInfo) -> bool { + real_accepts(ty) + } + + fn decode(value: MySqlValueRef<'_>) -> Result { + Ok(match value.format() { + MySqlValueFormat::Binary => { + let buf = value.as_bytes()?; + + if buf.len() == 8 { + // MySQL can return 8-byte DOUBLE values for a FLOAT + // We take and truncate to f32 as that's the same behavior as *in* MySQL + LittleEndian::read_f64(buf) as f32 + } else { + LittleEndian::read_f32(buf) + } + } + + MySqlValueFormat::Text => value.as_str()?.parse()?, + }) + } +} + +impl Decode<'_, MySql> for f64 { + fn accepts(ty: &MySqlTypeInfo) -> bool { + real_accepts(ty) + } + + fn decode(value: MySqlValueRef<'_>) -> Result { + Ok(match value.format() { + MySqlValueFormat::Binary => LittleEndian::read_f64(value.as_bytes()?), + MySqlValueFormat::Text => value.as_str()?.parse()?, + }) } } diff --git a/sqlx-core/src/mysql/types/int.rs b/sqlx-core/src/mysql/types/int.rs index 981517847..aa54b7c72 100644 --- a/sqlx-core/src/mysql/types/int.rs +++ b/sqlx-core/src/mysql/types/int.rs @@ -1,111 +1,127 @@ -use std::str::from_utf8; - -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use byteorder::{ByteOrder, LittleEndian}; use crate::decode::Decode; -use crate::encode::Encode; -use crate::mysql::protocol::TypeId; -use crate::mysql::type_info::MySqlTypeInfo; -use crate::mysql::{MySql, MySqlData, MySqlValue}; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::mysql::protocol::text::{ColumnFlags, ColumnType}; +use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef}; use crate::types::Type; -use crate::Error; + +fn int_accepts(ty: &MySqlTypeInfo) -> bool { + matches!( + ty.r#type, + ColumnType::Tiny + | ColumnType::Short + | ColumnType::Long + | ColumnType::Int24 + | ColumnType::LongLong + ) && !ty.flags.contains(ColumnFlags::UNSIGNED) +} impl Type for i8 { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::TINY_INT) - } -} - -impl Encode for i8 { - fn encode(&self, buf: &mut Vec) { - let _ = buf.write_i8(*self); - } -} - -impl<'de> Decode<'de, MySql> for i8 { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(mut buf) => buf.read_i8().map_err(Into::into), - - MySqlData::Text(s) => from_utf8(s) - .map_err(Error::decode)? - .parse() - .map_err(Error::decode), - } + MySqlTypeInfo::binary(ColumnType::Tiny) } } impl Type for i16 { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::SMALL_INT) - } -} - -impl Encode for i16 { - fn encode(&self, buf: &mut Vec) { - let _ = buf.write_i16::(*self); - } -} - -impl<'de> Decode<'de, MySql> for i16 { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(mut buf) => buf.read_i16::().map_err(Into::into), - - MySqlData::Text(s) => from_utf8(s) - .map_err(Error::decode)? - .parse() - .map_err(Error::decode), - } + MySqlTypeInfo::binary(ColumnType::Short) } } impl Type for i32 { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::INT) - } -} - -impl Encode for i32 { - fn encode(&self, buf: &mut Vec) { - let _ = buf.write_i32::(*self); - } -} - -impl<'de> Decode<'de, MySql> for i32 { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(mut buf) => buf.read_i32::().map_err(Into::into), - - MySqlData::Text(s) => from_utf8(s) - .map_err(Error::decode)? - .parse() - .map_err(Error::decode), - } + MySqlTypeInfo::binary(ColumnType::Long) } } impl Type for i64 { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::BIG_INT) + MySqlTypeInfo::binary(ColumnType::LongLong) } } -impl Encode for i64 { - fn encode(&self, buf: &mut Vec) { - let _ = buf.write_i64::(*self); +impl Encode<'_, MySql> for i8 { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.extend(&self.to_le_bytes()); + + IsNull::No } } -impl<'de> Decode<'de, MySql> for i64 { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(mut buf) => buf.read_i64::().map_err(Into::into), +impl Encode<'_, MySql> for i16 { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.extend(&self.to_le_bytes()); - MySqlData::Text(s) => from_utf8(s) - .map_err(Error::decode)? - .parse() - .map_err(Error::decode), - } + IsNull::No + } +} + +impl Encode<'_, MySql> for i32 { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.extend(&self.to_le_bytes()); + + IsNull::No + } +} + +impl Encode<'_, MySql> for i64 { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.extend(&self.to_le_bytes()); + + IsNull::No + } +} + +impl Decode<'_, MySql> for i8 { + fn accepts(ty: &MySqlTypeInfo) -> bool { + int_accepts(ty) + } + + fn decode(value: MySqlValueRef<'_>) -> Result { + Ok(match value.format() { + MySqlValueFormat::Binary => value.as_bytes()?[0] as i8, + MySqlValueFormat::Text => value.as_str()?.parse()?, + }) + } +} + +impl Decode<'_, MySql> for i16 { + fn accepts(ty: &MySqlTypeInfo) -> bool { + int_accepts(ty) + } + + fn decode(value: MySqlValueRef<'_>) -> Result { + Ok(match value.format() { + MySqlValueFormat::Binary => LittleEndian::read_i16(value.as_bytes()?), + MySqlValueFormat::Text => value.as_str()?.parse()?, + }) + } +} + +impl Decode<'_, MySql> for i32 { + fn accepts(ty: &MySqlTypeInfo) -> bool { + int_accepts(ty) + } + + fn decode(value: MySqlValueRef<'_>) -> Result { + Ok(match value.format() { + MySqlValueFormat::Binary => LittleEndian::read_i32(value.as_bytes()?), + MySqlValueFormat::Text => value.as_str()?.parse()?, + }) + } +} + +impl Decode<'_, MySql> for i64 { + fn accepts(ty: &MySqlTypeInfo) -> bool { + int_accepts(ty) + } + + fn decode(value: MySqlValueRef<'_>) -> Result { + Ok(match value.format() { + MySqlValueFormat::Binary => LittleEndian::read_i64(value.as_bytes()?), + MySqlValueFormat::Text => value.as_str()?.parse()?, + }) } } diff --git a/sqlx-core/src/mysql/types/json.rs b/sqlx-core/src/mysql/types/json.rs index 5be6d68c8..029f2f605 100644 --- a/sqlx-core/src/mysql/types/json.rs +++ b/sqlx-core/src/mysql/types/json.rs @@ -1,46 +1,47 @@ -use crate::decode::Decode; -use crate::encode::Encode; -use crate::mysql::database::MySql; -use crate::mysql::protocol::TypeId; -use crate::mysql::types::*; -use crate::mysql::{MySqlTypeInfo, MySqlValue}; -use crate::types::{Json, Type}; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -use serde_json::Value as JsonValue; -impl Type for JsonValue { - fn type_info() -> MySqlTypeInfo { - as Type>::type_info() - } -} +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::mysql::protocol::text::ColumnType; +use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef}; +use crate::types::{Json, Type}; impl Type for Json { fn type_info() -> MySqlTypeInfo { - // MySql uses the CHAR type to pass JSON data from and to the client - MySqlTypeInfo::new(TypeId::CHAR) + // MySql uses the `CHAR` type to pass JSON data from and to the client + // NOTE: This is forwards-compatible with MySQL v8+ as CHAR is a common transmission format + // and has nothing to do with the native storage ability of MySQL v8+ + MySqlTypeInfo::binary(ColumnType::String) } } -impl Encode for Json +impl Encode<'_, MySql> for Json where T: Serialize, { - fn encode(&self, buf: &mut Vec) { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { let json_string_value = serde_json::to_string(&self.0).expect("serde_json failed to convert to string"); - >::encode(json_string_value.as_str(), buf); + + <&str as Encode>::encode(json_string_value.as_str(), buf) } } -impl<'de, T> Decode<'de, MySql> for Json +impl<'r, T> Decode<'r, MySql> for Json where - T: 'de, - T: for<'de1> Deserialize<'de1>, + T: 'r + DeserializeOwned, { - fn decode(value: MySqlValue<'de>) -> crate::Result { - let string_value = <&'de str as Decode>::decode(value).unwrap(); + fn accepts(ty: &MySqlTypeInfo) -> bool { + ty.r#type == ColumnType::Json || <&str as Decode>::accepts(ty) + } + + fn decode(value: MySqlValueRef<'r>) -> Result { + let string_value = <&str as Decode>::decode(value)?; + serde_json::from_str(&string_value) .map(Json) - .map_err(crate::Error::decode) + .map_err(Into::into) } } diff --git a/sqlx-core/src/mysql/types/mod.rs b/sqlx-core/src/mysql/types/mod.rs index 2324c5849..940ff4942 100644 --- a/sqlx-core/src/mysql/types/mod.rs +++ b/sqlx-core/src/mysql/types/mod.rs @@ -4,7 +4,7 @@ //! //! | Rust type | MySQL type(s) | //! |---------------------------------------|------------------------------------------------------| -//! | `bool` | TINYINT(1) | +//! | `bool` | TINYINT(1), BOOLEAN | //! | `i8` | TINYINT | //! | `i16` | SMALLINT | //! | `i32` | INT | @@ -47,6 +47,7 @@ //! | Rust type | MySQL type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `bigdecimal::BigDecimal` | DECIMAL | +//! //! ### [`json`](https://crates.io/crates/json) //! //! Requires the `json` Cargo feature flag. @@ -79,19 +80,3 @@ mod time; #[cfg(feature = "json")] mod json; - -use crate::decode::Decode; -use crate::mysql::{MySql, MySqlValue}; - -impl<'de, T> Decode<'de, MySql> for Option -where - T: Decode<'de, MySql>, -{ - fn decode(value: MySqlValue<'de>) -> crate::Result { - Ok(if value.get().is_some() { - Some(>::decode(value)?) - } else { - None - }) - } -} diff --git a/sqlx-core/src/mysql/types/str.rs b/sqlx-core/src/mysql/types/str.rs index a13f50956..86210901e 100644 --- a/sqlx-core/src/mysql/types/str.rs +++ b/sqlx-core/src/mysql/types/str.rs @@ -1,30 +1,46 @@ -use std::str; - -use byteorder::LittleEndian; - use crate::decode::Decode; -use crate::encode::Encode; -use crate::mysql::io::BufMutExt; -use crate::mysql::protocol::TypeId; -use crate::mysql::type_info::MySqlTypeInfo; -use crate::mysql::{MySql, MySqlData, MySqlValue}; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::mysql::io::MySqlBufMutExt; +use crate::mysql::protocol::text::{ColumnFlags, ColumnType}; +use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef}; use crate::types::Type; -use std::str::from_utf8; impl Type for str { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo { - id: TypeId::TEXT, - is_binary: false, - is_unsigned: false, - char_set: 224, // utf8mb4_unicode_ci + r#type: ColumnType::Blob, // TEXT + char_set: 224, // utf8mb4_unicode_ci + flags: ColumnFlags::empty(), } } } -impl Encode for str { - fn encode(&self, buf: &mut Vec) { - buf.put_str_lenenc::(self); +impl Encode<'_, MySql> for &'_ str { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + buf.put_str_lenenc(self); + + IsNull::No + } +} + +impl<'r> Decode<'r, MySql> for &'r str { + fn accepts(ty: &MySqlTypeInfo) -> bool { + matches!( + ty.r#type, + ColumnType::VarChar + | ColumnType::Blob + | ColumnType::TinyBlob + | ColumnType::MediumBlob + | ColumnType::LongBlob + | ColumnType::String + | ColumnType::VarString + | ColumnType::Enum + ) + } + + fn decode(value: MySqlValueRef<'r>) -> Result { + value.as_str() } } @@ -34,24 +50,18 @@ impl Type for String { } } -impl Encode for String { - fn encode(&self, buf: &mut Vec) { - >::encode(self.as_str(), buf) +impl Encode<'_, MySql> for String { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + <&str as Encode>::encode(&**self, buf) } } -impl<'de> Decode<'de, MySql> for &'de str { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(buf) | MySqlData::Text(buf) => { - from_utf8(buf).map_err(crate::Error::decode) - } - } +impl Decode<'_, MySql> for String { + fn accepts(ty: &MySqlTypeInfo) -> bool { + <&str as Decode>::accepts(ty) } -} -impl<'de> Decode<'de, MySql> for String { - fn decode(value: MySqlValue<'de>) -> crate::Result { - <&'de str as Decode>::decode(value).map(ToOwned::to_owned) + fn decode(value: MySqlValueRef<'_>) -> Result { + <&str as Decode>::decode(value).map(ToOwned::to_owned) } } diff --git a/sqlx-core/src/mysql/types/time.rs b/sqlx-core/src/mysql/types/time.rs index 81d612dec..87cfd1e91 100644 --- a/sqlx-core/src/mysql/types/time.rs +++ b/sqlx-core/src/mysql/types/time.rs @@ -2,33 +2,38 @@ use std::borrow::Cow; use std::convert::TryFrom; use byteorder::{ByteOrder, LittleEndian}; +use bytes::Buf; use time::{Date, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset}; use crate::decode::Decode; -use crate::encode::Encode; -use crate::io::{Buf, BufMut}; -use crate::mysql::protocol::TypeId; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::mysql::protocol::text::ColumnType; use crate::mysql::type_info::MySqlTypeInfo; -use crate::mysql::{MySql, MySqlData, MySqlValue}; +use crate::mysql::{MySql, MySqlValueFormat, MySqlValueRef}; use crate::types::Type; impl Type for OffsetDateTime { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::TIMESTAMP) + MySqlTypeInfo::binary(ColumnType::Timestamp) } } -impl Encode for OffsetDateTime { - fn encode(&self, buf: &mut Vec) { +impl Encode<'_, MySql> for OffsetDateTime { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { let utc_dt = self.to_offset(UtcOffset::UTC); let primitive_dt = PrimitiveDateTime::new(utc_dt.date(), utc_dt.time()); - Encode::::encode(&primitive_dt, buf); + Encode::::encode(&primitive_dt, buf) } } -impl<'de> Decode<'de, MySql> for OffsetDateTime { - fn decode(value: MySqlValue<'de>) -> crate::Result { +impl<'r> Decode<'r, MySql> for OffsetDateTime { + fn accepts(ty: &MySqlTypeInfo) -> bool { + matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) + } + + fn decode(value: MySqlValueRef<'r>) -> Result { let primitive: PrimitiveDateTime = Decode::::decode(value)?; Ok(primitive.assume_utc()) @@ -37,12 +42,12 @@ impl<'de> Decode<'de, MySql> for OffsetDateTime { impl Type for Time { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::TIME) + MySqlTypeInfo::binary(ColumnType::Time) } } -impl Encode for Time { - fn encode(&self, buf: &mut Vec) { +impl Encode<'_, MySql> for Time { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { let len = Encode::::size_hint(self) - 1; buf.push(len as u8); @@ -51,9 +56,11 @@ impl Encode for Time { // "date on 4 bytes little-endian format" (?) // https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding - buf.advance(4); + buf.extend_from_slice(&[0_u8; 4]); encode_time(self, len > 9, buf); + + IsNull::No } fn size_hint(&self) -> usize { @@ -67,15 +74,17 @@ impl Encode for Time { } } -impl<'de> Decode<'de, MySql> for Time { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(mut buf) => { +impl<'r> Decode<'r, MySql> for Time { + fn decode(value: MySqlValueRef<'r>) -> Result { + match value.format() { + MySqlValueFormat::Binary => { + let mut buf = value.as_bytes()?; + // data length, expecting 8 or 12 (fractional seconds) - let len = buf.get_u8()?; + let len = buf.get_u8(); // is negative : int<1> - let is_negative = buf.get_u8()?; + let is_negative = buf.get_u8(); assert_eq!(is_negative, 0, "Negative dates/times are not supported"); // "date on 4 bytes little-endian format" (?) @@ -85,8 +94,8 @@ impl<'de> Decode<'de, MySql> for Time { decode_time(len - 5, buf) } - MySqlData::Text(buf) => { - let s = from_utf8(buf).map_err(crate::Error::decode)?; + MySqlValueFormat::Text => { + let s = value.as_str()?; // If there are less than 9 digits after the decimal point // We need to zero-pad @@ -98,7 +107,7 @@ impl<'de> Decode<'de, MySql> for Time { Cow::Borrowed(s) }; - Time::parse(&*s, "%H:%M:%S.%N").map_err(crate::Error::decode) + Time::parse(&*s, "%H:%M:%S.%N").map_err(Into::into) } } } @@ -106,15 +115,17 @@ impl<'de> Decode<'de, MySql> for Time { impl Type for Date { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::DATE) + MySqlTypeInfo::binary(ColumnType::Date) } } -impl Encode for Date { - fn encode(&self, buf: &mut Vec) { +impl Encode<'_, MySql> for Date { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { buf.push(4); encode_date(self, buf); + + IsNull::No } fn size_hint(&self) -> usize { @@ -122,13 +133,13 @@ impl Encode for Date { } } -impl<'de> Decode<'de, MySql> for Date { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(buf) => decode_date(&buf[1..]), - MySqlData::Text(buf) => { - let s = from_utf8(buf).map_err(crate::Error::decode)?; - Date::parse(s, "%Y-%m-%d").map_err(crate::Error::decode) +impl<'r> Decode<'r, MySql> for Date { + fn decode(value: MySqlValueRef<'r>) -> Result { + match value.format() { + MySqlValueFormat::Binary => decode_date(&value.as_bytes()?[1..]), + MySqlValueFormat::Text => { + let s = value.as_str()?; + Date::parse(s, "%Y-%m-%d").map_err(Into::into) } } } @@ -136,12 +147,12 @@ impl<'de> Decode<'de, MySql> for Date { impl Type for PrimitiveDateTime { fn type_info() -> MySqlTypeInfo { - MySqlTypeInfo::new(TypeId::DATETIME) + MySqlTypeInfo::binary(ColumnType::Datetime) } } -impl Encode for PrimitiveDateTime { - fn encode(&self, buf: &mut Vec) { +impl Encode<'_, MySql> for PrimitiveDateTime { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { let len = Encode::::size_hint(self) - 1; buf.push(len as u8); @@ -150,6 +161,8 @@ impl Encode for PrimitiveDateTime { if len > 4 { encode_time(&self.time(), len > 8, buf); } + + IsNull::No } fn size_hint(&self) -> usize { @@ -169,10 +182,15 @@ impl Encode for PrimitiveDateTime { } } -impl<'de> Decode<'de, MySql> for PrimitiveDateTime { - fn decode(value: MySqlValue<'de>) -> crate::Result { - match value.try_get()? { - MySqlData::Binary(buf) => { +impl<'r> Decode<'r, MySql> for PrimitiveDateTime { + fn accepts(ty: &MySqlTypeInfo) -> bool { + matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) + } + + fn decode(value: MySqlValueRef<'r>) -> Result { + match value.format() { + MySqlValueFormat::Binary => { + let mut buf = value.as_bytes()?; let len = buf[0]; let date = decode_date(&buf[1..])?; @@ -185,8 +203,8 @@ impl<'de> Decode<'de, MySql> for PrimitiveDateTime { Ok(dt) } - MySqlData::Text(buf) => { - let s = from_utf8(buf).map_err(crate::Error::decode)?; + MySqlValueFormat::Text => { + let s = value.as_str()?; // If there are less than 9 digits after the decimal point // We need to zero-pad @@ -202,7 +220,7 @@ impl<'de> Decode<'de, MySql> for PrimitiveDateTime { Cow::Borrowed(s) }; - PrimitiveDateTime::parse(&*s, "%Y-%m-%d %H:%M:%S.%N").map_err(crate::Error::decode) + PrimitiveDateTime::parse(&*s, "%Y-%m-%d %H:%M:%S.%N").map_err(Into::into) } } } @@ -218,13 +236,13 @@ fn encode_date(date: &Date, buf: &mut Vec) { buf.push(date.day()); } -fn decode_date(buf: &[u8]) -> crate::Result { +fn decode_date(buf: &[u8]) -> Result { Date::try_from_ymd( LittleEndian::read_u16(buf) as i32, buf[2] as u8, buf[3] as u8, ) - .map_err(|e| decode_err!("Error while decoding Date: {}", e)) + .map_err(Into::into) } fn encode_time(time: &Time, include_micros: bool, buf: &mut Vec) { @@ -233,92 +251,22 @@ fn encode_time(time: &Time, include_micros: bool, buf: &mut Vec) { buf.push(time.second()); if include_micros { - buf.put_u32::((time.nanosecond() / 1000) as u32); + buf.extend(&((time.nanosecond() / 1000) as u32).to_le_bytes()); } } -fn decode_time(len: u8, mut buf: &[u8]) -> crate::Result