diff --git a/Cargo.lock b/Cargo.lock index c3b04fa9..82349ed7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -611,6 +611,15 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +[[package]] +name = "encoding_rs" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8ac63f94732332f44fe654443c46f6375d1939684c17b0afb6cb56b0456e171" +dependencies = [ + "cfg-if", +] + [[package]] name = "env_logger" version = "0.7.1" @@ -2087,6 +2096,7 @@ dependencies = [ "crossbeam-utils 0.7.2", "digest", "either", + "encoding_rs", "futures-channel", "futures-core", "futures-util", diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 1ccff94a..105a39c8 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -19,7 +19,7 @@ default = [ "runtime-async-std" ] postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ] mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ] sqlite = [ "libsqlite3-sys" ] -mssql = [ "uuid" ] +mssql = [ "uuid", "encoding_rs" ] # types all-types = [ "chrono", "time", "bigdecimal", "ipnetwork", "json", "uuid" ] @@ -48,6 +48,7 @@ crossbeam-queue = "0.2.1" crossbeam-channel = "0.4.2" crossbeam-utils = { version = "0.7.2", default-features = false } digest = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] } +encoding_rs = { version = "0.8.23", optional = true } either = "1.5.3" futures-channel = { version = "0.3.4", default-features = false, features = [ "alloc", "std" ] } futures-core = { version = "0.3.4", default-features = false } diff --git a/sqlx-core/src/mssql/protocol/type_info.rs b/sqlx-core/src/mssql/protocol/type_info.rs index ef15f186..23995280 100644 --- a/sqlx-core/src/mssql/protocol/type_info.rs +++ b/sqlx-core/src/mssql/protocol/type_info.rs @@ -2,11 +2,11 @@ use std::borrow::Cow; use bitflags::bitflags; use bytes::{Buf, Bytes}; +use encoding_rs::Encoding; use crate::encode::Encode; use crate::error::Error; use crate::mssql::MsSql; -use url::quirks::set_search; bitflags! { pub(crate) struct CollationFlags: u8 { @@ -106,6 +106,31 @@ impl TypeInfo { } } + pub(crate) fn encoding(&self) -> Result<&'static Encoding, Error> { + match self.ty { + DataType::NChar | DataType::NVarChar => Ok(encoding_rs::UTF_16LE), + + DataType::VarChar | DataType::Char | DataType::BigChar | DataType::BigVarChar => { + // unwrap: impossible to unwrap here, collation will be set + Ok(match self.collation.unwrap().locale { + // This is the Western encoding for Windows. It is an extension of ISO-8859-1, + // which is known as Latin 1. + 0x0409 => encoding_rs::WINDOWS_1252, + + locale => { + return Err(err_protocol!("unsupported locale 0x{:?}", locale)); + } + }) + } + + _ => { + // default to UTF-8 for anything + // else coming in here + Ok(encoding_rs::UTF_8) + } + } + } + // reads a TYPE_INFO from the buffer pub(crate) fn get(buf: &mut Bytes) -> Result { let ty = DataType::get(buf)?; @@ -445,13 +470,35 @@ impl TypeInfo { _ => unreachable!("invalid size {} for float"), }), - DataType::NVarChar => { - s.push_str("nvarchar("); - let _ = itoa::fmt(&mut *s, self.size / 2); - s.push_str(")"); + DataType::VarChar + | DataType::NVarChar + | DataType::BigVarChar + | DataType::Char + | DataType::BigChar + | DataType::NChar => { + // name + s.push_str(match self.ty { + DataType::VarChar => "varchar", + DataType::NVarChar => "nvarchar", + DataType::BigVarChar => "bigvarchar", + DataType::Char => "char", + DataType::BigChar => "bigchar", + DataType::NChar => "nchar", + + _ => unreachable!(), + }); + + // size + if self.size < 8000 && self.size > 0 { + s.push_str("("); + let _ = itoa::fmt(&mut *s, self.size); + s.push_str(")"); + } else { + s.push_str("(max)"); + } } - _ => unimplemented!("unsupported data type {:?}", self.ty), + _ => unimplemented!("fmt: unsupported data type {:?}", self.ty), } } } @@ -511,8 +558,8 @@ impl DataType { impl Collation { pub(crate) fn get(buf: &mut Bytes) -> Collation { - let locale_sort_version = buf.get_u32(); - let locale = locale_sort_version & 0xF_FFFF; + let locale_sort_version = buf.get_u32_le(); + let locale = locale_sort_version & 0xfffff; let flags = CollationFlags::from_bits_truncate(((locale_sort_version >> 20) & 0xFF) as u8); let version = (locale_sort_version >> 28) as u8; let sort = buf.get_u8(); diff --git a/sqlx-core/src/mssql/types/str.rs b/sqlx-core/src/mssql/types/str.rs index a50791e8..6417efc8 100644 --- a/sqlx-core/src/mssql/types/str.rs +++ b/sqlx-core/src/mssql/types/str.rs @@ -1,11 +1,12 @@ use byteorder::{ByteOrder, LittleEndian}; +use bytes::Buf; use crate::database::{Database, HasArguments, HasValueRef}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::mssql::io::MsSqlBufMutExt; -use crate::mssql::protocol::type_info::{DataType, TypeInfo}; +use crate::mssql::protocol::type_info::{Collation, DataType, TypeInfo}; use crate::mssql::{MsSql, MsSqlTypeInfo, MsSqlValueRef}; use crate::types::Type; @@ -15,6 +16,12 @@ impl Type for str { } } +impl Type for String { + fn type_info() -> MsSqlTypeInfo { + >::type_info() + } +} + impl Encode<'_, MsSql> for &'_ str { fn produces(&self) -> MsSqlTypeInfo { MsSqlTypeInfo(TypeInfo::new(DataType::NVarChar, (self.len() * 2) as u32)) @@ -26,3 +33,37 @@ impl Encode<'_, MsSql> for &'_ str { IsNull::No } } + +impl Encode<'_, MsSql> for String { + fn produces(&self) -> MsSqlTypeInfo { + <&str as Encode>::produces(&self.as_str()) + } + + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + <&str as Encode>::encode_by_ref(&self.as_str(), buf) + } +} + +impl Decode<'_, MsSql> for String { + fn accepts(ty: &MsSqlTypeInfo) -> bool { + matches!( + ty.0.ty, + DataType::NVarChar + | DataType::NChar + | DataType::BigVarChar + | DataType::VarChar + | DataType::BigChar + | DataType::Char + ) + } + + fn decode(value: MsSqlValueRef<'_>) -> Result { + Ok(value + .type_info + .0 + .encoding()? + .decode_without_bom_handling(value.as_bytes()?) + .0 + .into_owned()) + } +}