mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-02 15:25:32 +00:00
feat(mssql): implement parameterized queries
This commit is contained in:
parent
9a701313bc
commit
c64122c03f
@ -1,26 +0,0 @@
|
||||
use std::{
|
||||
ascii::escape_default,
|
||||
fmt::{self, Debug},
|
||||
str::from_utf8,
|
||||
};
|
||||
|
||||
// Wrapper type for byte slices that will debug print
|
||||
// as a binary string
|
||||
pub struct ByteStr<'a>(pub &'a [u8]);
|
||||
|
||||
impl Debug for ByteStr<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "b\"")?;
|
||||
|
||||
for &b in self.0 {
|
||||
let part: Vec<u8> = escape_default(b).collect();
|
||||
let s = from_utf8(&part).unwrap();
|
||||
|
||||
write!(f, "{}", s)?;
|
||||
}
|
||||
|
||||
write!(f, "\"")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -1,152 +0,0 @@
|
||||
use std::io;
|
||||
use std::net::Shutdown;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use crate::runtime::{AsyncRead, AsyncWrite, TcpStream};
|
||||
|
||||
use self::Inner::*;
|
||||
|
||||
pub struct MaybeTlsStream {
|
||||
inner: Inner,
|
||||
}
|
||||
|
||||
enum Inner {
|
||||
NotTls(TcpStream),
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
UnixStream(crate::runtime::UnixStream),
|
||||
#[cfg(feature = "tls")]
|
||||
Tls(async_native_tls::TlsStream<TcpStream>),
|
||||
#[cfg(feature = "tls")]
|
||||
Upgrading,
|
||||
}
|
||||
|
||||
impl MaybeTlsStream {
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
pub async fn connect_uds<S: AsRef<std::ffi::OsStr>>(p: S) -> crate::Result<Self> {
|
||||
let conn = crate::runtime::UnixStream::connect(p.as_ref()).await?;
|
||||
Ok(Self {
|
||||
inner: Inner::UnixStream(conn),
|
||||
})
|
||||
}
|
||||
pub async fn connect(host: &str, port: u16) -> crate::Result<Self> {
|
||||
let conn = TcpStream::connect((host, port)).await?;
|
||||
Ok(Self {
|
||||
inner: Inner::NotTls(conn),
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn is_tls(&self) -> bool {
|
||||
match self.inner {
|
||||
Inner::NotTls(_) => false,
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
Inner::UnixStream(_) => false,
|
||||
#[cfg(feature = "tls")]
|
||||
Inner::Tls(_) => true,
|
||||
#[cfg(feature = "tls")]
|
||||
Inner::Upgrading => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
|
||||
pub async fn upgrade(
|
||||
&mut self,
|
||||
host: &str,
|
||||
connector: async_native_tls::TlsConnector,
|
||||
) -> crate::Result<()> {
|
||||
let conn = match std::mem::replace(&mut self.inner, Upgrading) {
|
||||
NotTls(conn) => conn,
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
UnixStream(_) => {
|
||||
return Err(tls_err!("TLS is not supported with unix domain sockets").into())
|
||||
}
|
||||
Tls(_) => return Err(tls_err!("connection already upgraded").into()),
|
||||
Upgrading => return Err(tls_err!("connection already failed to upgrade").into()),
|
||||
};
|
||||
|
||||
self.inner = Tls(connector.connect(host, conn).await?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
|
||||
match self.inner {
|
||||
NotTls(ref conn) => conn.shutdown(how),
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
UnixStream(ref conn) => conn.shutdown(how),
|
||||
#[cfg(feature = "tls")]
|
||||
Tls(ref conn) => conn.get_ref().shutdown(how),
|
||||
#[cfg(feature = "tls")]
|
||||
// connection already closed
|
||||
Upgrading => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! forward_pin (
|
||||
($self:ident.$method:ident($($arg:ident),*)) => (
|
||||
match &mut $self.inner {
|
||||
NotTls(ref mut conn) => Pin::new(conn).$method($($arg),*),
|
||||
#[cfg(all(feature = "postgres", unix))]
|
||||
UnixStream(ref mut conn) => Pin::new(conn).$method($($arg),*),
|
||||
#[cfg(feature = "tls")]
|
||||
Tls(ref mut conn) => Pin::new(conn).$method($($arg),*),
|
||||
#[cfg(feature = "tls")]
|
||||
Upgrading => Err(io::Error::new(io::ErrorKind::Other, "connection broken; TLS upgrade failed")).into(),
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
impl AsyncRead for MaybeTlsStream {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
forward_pin!(self.poll_read(cx, buf))
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime-async-std")]
|
||||
fn poll_read_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
bufs: &mut [std::io::IoSliceMut],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
forward_pin!(self.poll_read_vectored(cx, bufs))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for MaybeTlsStream {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
forward_pin!(self.poll_write(cx, buf))
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
forward_pin!(self.poll_flush(cx))
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime-async-std")]
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
forward_pin!(self.poll_close(cx))
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime-tokio")]
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
forward_pin!(self.poll_shutdown(cx))
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime-async-std")]
|
||||
fn poll_write_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
bufs: &[std::io::IoSlice],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
forward_pin!(self.poll_write_vectored(cx, bufs))
|
||||
}
|
||||
}
|
@ -1,21 +1,88 @@
|
||||
use crate::arguments::Arguments;
|
||||
use crate::encode::Encode;
|
||||
use crate::mssql::database::MsSql;
|
||||
use crate::mssql::io::MsSqlBufMutExt;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MsSqlArguments {}
|
||||
pub struct MsSqlArguments {
|
||||
// next ordinal to be used when formatting a positional parameter name
|
||||
pub(crate) ordinal: usize,
|
||||
// temporary string buffer used to format parameter names
|
||||
name: String,
|
||||
pub(crate) data: Vec<u8>,
|
||||
pub(crate) declarations: String,
|
||||
}
|
||||
|
||||
impl MsSqlArguments {
|
||||
pub(crate) fn add_named<'q, T: Encode<'q, MsSql>>(&mut self, name: &str, value: T) {
|
||||
let ty = value.produces();
|
||||
|
||||
let mut ty_name = String::new();
|
||||
ty.0.fmt(&mut ty_name);
|
||||
|
||||
self.data.put_b_varchar(name); // [ParamName]
|
||||
self.data.push(0); // [StatusFlags]
|
||||
|
||||
ty.0.put(&mut self.data); // [TYPE_INFO]
|
||||
ty.0.put_value(&mut self.data, value); // [ParamLenData]
|
||||
}
|
||||
|
||||
pub(crate) fn add_unnamed<'q, T: Encode<'q, MsSql>>(&mut self, value: T) {
|
||||
self.add_named("", value);
|
||||
}
|
||||
|
||||
pub(crate) fn append(&mut self, arguments: &mut MsSqlArguments) {
|
||||
self.ordinal += arguments.ordinal;
|
||||
self.data.append(&mut arguments.data);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'q> Arguments<'q> for MsSqlArguments {
|
||||
type Database = MsSql;
|
||||
|
||||
fn reserve(&mut self, additional: usize, size: usize) {
|
||||
unimplemented!()
|
||||
fn reserve(&mut self, _additional: usize, size: usize) {
|
||||
self.data.reserve(size + 10); // est. 4 chars for name, 1 for status, 1 for TYPE_INFO
|
||||
}
|
||||
|
||||
fn add<T>(&mut self, value: T)
|
||||
where
|
||||
T: 'q + Encode<'q, Self::Database>,
|
||||
{
|
||||
unimplemented!()
|
||||
let ty = value.produces();
|
||||
|
||||
// produce an ordinal parameter name
|
||||
// @p1, @p2, ... @pN
|
||||
|
||||
self.name.clear();
|
||||
self.name.push_str("@p");
|
||||
|
||||
self.ordinal += 1;
|
||||
let _ = itoa::fmt(&mut self.name, self.ordinal);
|
||||
|
||||
let MsSqlArguments {
|
||||
ref name,
|
||||
ref mut declarations,
|
||||
ref mut data,
|
||||
..
|
||||
} = self;
|
||||
|
||||
// add this to our variable declaration list
|
||||
// @p1 int, @p2 nvarchar(10), ...
|
||||
|
||||
if !declarations.is_empty() {
|
||||
declarations.push_str(",");
|
||||
}
|
||||
|
||||
declarations.push_str(name);
|
||||
declarations.push(' ');
|
||||
ty.0.fmt(declarations);
|
||||
|
||||
// write out the parameter
|
||||
|
||||
data.put_b_varchar(name); // [ParamName]
|
||||
data.push(0); // [StatusFlags]
|
||||
|
||||
ty.0.put(data); // [TYPE_INFO]
|
||||
ty.0.put_value(data, value); // [ParamLenData]
|
||||
}
|
||||
}
|
||||
|
@ -31,14 +31,7 @@ impl MsSqlConnection {
|
||||
stream.flush().await?;
|
||||
|
||||
let (_, packet) = stream.recv_packet().await?;
|
||||
let pl = PreLogin::decode(packet)?;
|
||||
|
||||
log::trace!(
|
||||
"acknowledged PRELOGIN from MSSQL v{}.{}.{}",
|
||||
pl.version.major,
|
||||
pl.version.minor,
|
||||
pl.version.build
|
||||
);
|
||||
let _ = PreLogin::decode(packet)?;
|
||||
|
||||
// LOGIN7 defines the authentication rules for use between client and server
|
||||
|
||||
@ -70,12 +63,9 @@ impl MsSqlConnection {
|
||||
// all messages are mostly informational (ENVCHANGE, INFO, LOGINACK)
|
||||
|
||||
match stream.recv_message().await? {
|
||||
Message::LoginAck(ack) => {
|
||||
log::trace!(
|
||||
"established connection to {} {}",
|
||||
ack.program_name,
|
||||
ack.program_version
|
||||
);
|
||||
Message::LoginAck(_) => {
|
||||
// indicates that the login was successful
|
||||
// no action is needed, we are just going to keep waiting till we hit <Done>
|
||||
}
|
||||
|
||||
Message::Done(_) => {
|
||||
|
@ -9,13 +9,38 @@ use crate::error::Error;
|
||||
use crate::executor::{Execute, Executor};
|
||||
use crate::mssql::protocol::message::Message;
|
||||
use crate::mssql::protocol::packet::PacketType;
|
||||
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
|
||||
use crate::mssql::protocol::sql_batch::SqlBatch;
|
||||
use crate::mssql::{MsSql, MsSqlConnection, MsSqlRow};
|
||||
use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow};
|
||||
|
||||
impl MsSqlConnection {
|
||||
async fn run(&mut self, query: &str) -> Result<(), Error> {
|
||||
async fn run(&mut self, query: &str, arguments: Option<MsSqlArguments>) -> Result<(), Error> {
|
||||
if let Some(mut arguments) = arguments {
|
||||
let proc = Either::Right(Procedure::ExecuteSql);
|
||||
let mut proc_args = MsSqlArguments::default();
|
||||
|
||||
// SQL
|
||||
proc_args.add_unnamed(query);
|
||||
|
||||
// Declarations
|
||||
// NAME TYPE, NAME TYPE, ...
|
||||
proc_args.add_unnamed(&*arguments.declarations);
|
||||
|
||||
// Add the list of SQL parameters _after_ our RPC parameters
|
||||
proc_args.append(&mut arguments);
|
||||
|
||||
self.stream.write_packet(
|
||||
PacketType::Rpc,
|
||||
RpcRequest {
|
||||
arguments: &proc_args,
|
||||
procedure: proc,
|
||||
options: OptionFlags::empty(),
|
||||
},
|
||||
);
|
||||
} else {
|
||||
self.stream
|
||||
.write_packet(PacketType::SqlBatch, SqlBatch { sql: query });
|
||||
}
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
@ -35,10 +60,10 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection {
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
let s = query.query();
|
||||
// TODO: let arguments = query.take_arguments();
|
||||
let arguments = query.take_arguments();
|
||||
|
||||
Box::pin(try_stream! {
|
||||
self.run(s).await?;
|
||||
self.run(s, arguments).await?;
|
||||
|
||||
loop {
|
||||
match self.stream.recv_message().await? {
|
||||
|
@ -1,9 +1,12 @@
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
use std::net::Shutdown;
|
||||
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_util::{future::ready, FutureExt, TryFutureExt};
|
||||
|
||||
use crate::connection::{Connect, Connection};
|
||||
use crate::error::{BoxDynError, Error};
|
||||
use crate::executor::Executor;
|
||||
use crate::mssql::connection::stream::MsSqlStream;
|
||||
use crate::mssql::{MsSql, MsSqlConnectOptions};
|
||||
|
||||
@ -25,23 +28,28 @@ impl Connection for MsSqlConnection {
|
||||
type Database = MsSql;
|
||||
|
||||
fn close(self) -> BoxFuture<'static, Result<(), Error>> {
|
||||
unimplemented!()
|
||||
// NOTE: there does not seem to be a clean shutdown packet to send to MSSQL
|
||||
ready(self.stream.shutdown(Shutdown::Both).map_err(Into::into)).boxed()
|
||||
}
|
||||
|
||||
fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
|
||||
unimplemented!()
|
||||
// NOTE: we do not use `SELECT 1` as that *could* interact with any ongoing transactions
|
||||
self.execute("/* SQLx ping */").map_ok(|_| ()).boxed()
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn get_ref(&self) -> &MsSqlConnection {
|
||||
unimplemented!()
|
||||
self
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn get_mut(&mut self) -> &mut MsSqlConnection {
|
||||
unimplemented!()
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -8,12 +8,13 @@ use crate::io::{BufStream, Encode};
|
||||
use crate::mssql::protocol::col_meta_data::{ColMetaData, ColumnData};
|
||||
use crate::mssql::protocol::done::Done;
|
||||
use crate::mssql::protocol::env_change::EnvChange;
|
||||
use crate::mssql::protocol::error::Error as ProtocolError;
|
||||
use crate::mssql::protocol::info::Info;
|
||||
use crate::mssql::protocol::login_ack::LoginAck;
|
||||
use crate::mssql::protocol::message::{Message, MessageType};
|
||||
use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status};
|
||||
use crate::mssql::protocol::row::Row;
|
||||
use crate::mssql::MsSqlConnectOptions;
|
||||
use crate::mssql::{MsSqlConnectOptions, MsSqlDatabaseError};
|
||||
use crate::net::MaybeTlsStream;
|
||||
|
||||
pub(crate) struct MsSqlStream {
|
||||
@ -104,13 +105,20 @@ impl MsSqlStream {
|
||||
break;
|
||||
};
|
||||
|
||||
return Ok(match MessageType::get(buf)? {
|
||||
let ty = MessageType::get(buf)?;
|
||||
|
||||
return Ok(match ty {
|
||||
MessageType::EnvChange => Message::EnvChange(EnvChange::get(buf)?),
|
||||
MessageType::Info => Message::Info(Info::get(buf)?),
|
||||
MessageType::Row => Message::Row(Row::get(buf, &self.columns)?),
|
||||
MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?),
|
||||
MessageType::Done => Message::Done(Done::get(buf)?),
|
||||
|
||||
MessageType::Error => {
|
||||
let err = ProtocolError::get(buf)?;
|
||||
return Err(MsSqlDatabaseError(err).into());
|
||||
}
|
||||
|
||||
MessageType::ColMetaData => {
|
||||
// NOTE: there isn't anything to return as the data gets
|
||||
// consumed by the stream for use in subsequent Row decoding
|
||||
|
@ -1,42 +1,52 @@
|
||||
use crate::error::DatabaseError;
|
||||
use std::error::Error;
|
||||
use std::error::Error as StdError;
|
||||
use std::fmt::{self, Debug, Display, Formatter};
|
||||
|
||||
use crate::error::DatabaseError;
|
||||
use crate::mssql::protocol::error::Error;
|
||||
|
||||
/// An error returned from the MSSQL database.
|
||||
pub struct MsSqlDatabaseError {}
|
||||
pub struct MsSqlDatabaseError(pub(crate) Error);
|
||||
|
||||
impl Debug for MsSqlDatabaseError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
unimplemented!()
|
||||
f.debug_struct("MsSqlDatabaseError")
|
||||
.field("message", &self.0.message)
|
||||
.field("number", &self.0.number)
|
||||
.field("state", &self.0.state)
|
||||
.field("class", &self.0.class)
|
||||
.field("server", &self.0.server)
|
||||
.field("procedure", &self.0.procedure)
|
||||
.field("line", &self.0.line)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for MsSqlDatabaseError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
unimplemented!()
|
||||
f.pad(self.message())
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for MsSqlDatabaseError {}
|
||||
impl StdError for MsSqlDatabaseError {}
|
||||
|
||||
impl DatabaseError for MsSqlDatabaseError {
|
||||
#[inline]
|
||||
fn message(&self) -> &str {
|
||||
unimplemented!()
|
||||
&self.0.message
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn as_error(&self) -> &(dyn Error + Send + Sync + 'static) {
|
||||
fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static) {
|
||||
self
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn as_error_mut(&mut self) -> &mut (dyn Error + Send + Sync + 'static) {
|
||||
fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static) {
|
||||
self
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
fn into_error(self: Box<Self>) -> Box<dyn Error + Send + Sync + 'static> {
|
||||
fn into_error(self: Box<Self>) -> Box<dyn StdError + Send + Sync + 'static> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
pub trait MsSqlBufMutExt {
|
||||
fn put_b_varchar(&mut self, s: &str);
|
||||
fn put_utf16_str(&mut self, s: &str);
|
||||
}
|
||||
|
||||
@ -9,4 +10,9 @@ impl MsSqlBufMutExt for Vec<u8> {
|
||||
self.extend_from_slice(&ch.to_le_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
fn put_b_varchar(&mut self, s: &str) {
|
||||
self.extend(&(s.len() as u8).to_le_bytes());
|
||||
self.put_utf16_str(s);
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ use crate::io::Decode;
|
||||
use crate::mssql::io::MsSqlBufExt;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) enum EnvChange {
|
||||
Database(String),
|
||||
Language(String),
|
||||
|
54
sqlx-core/src/mssql/protocol/error.rs
Normal file
54
sqlx-core/src/mssql/protocol/error.rs
Normal file
@ -0,0 +1,54 @@
|
||||
use crate::mssql::io::MsSqlBufExt;
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Error {
|
||||
// The error number
|
||||
pub(crate) number: i32,
|
||||
|
||||
// The error state, used as a modifier to the error number.
|
||||
pub(crate) state: u8,
|
||||
|
||||
// The class (severity) of the error. A class of less than 10 indicates
|
||||
// an informational message.
|
||||
pub(crate) class: u8,
|
||||
|
||||
// The message text length and message text using US_VARCHAR format.
|
||||
pub(crate) message: String,
|
||||
|
||||
// The server name length and server name using B_VARCHAR format
|
||||
pub(crate) server: String,
|
||||
|
||||
// The stored procedure name length and the stored procedure name using B_VARCHAR format
|
||||
pub(crate) procedure: String,
|
||||
|
||||
// The line number in the SQL batch or stored procedure that caused the error. Line numbers
|
||||
// begin at 1. If the line number is not applicable to the message, the
|
||||
// value of LineNumber is 0.
|
||||
pub(crate) line: i32,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, crate::error::Error> {
|
||||
let len = buf.get_u16_le();
|
||||
let mut data = buf.split_to(len as usize);
|
||||
|
||||
let number = data.get_i32_le();
|
||||
let state = data.get_u8();
|
||||
let class = data.get_u8();
|
||||
let message = data.get_us_varchar()?;
|
||||
let server = data.get_b_varchar()?;
|
||||
let procedure = data.get_b_varchar()?;
|
||||
let line = data.get_i32_le();
|
||||
|
||||
Ok(Self {
|
||||
number,
|
||||
state,
|
||||
class,
|
||||
message,
|
||||
server,
|
||||
procedure,
|
||||
line,
|
||||
})
|
||||
}
|
||||
}
|
45
sqlx-core/src/mssql/protocol/header.rs
Normal file
45
sqlx-core/src/mssql/protocol/header.rs
Normal file
@ -0,0 +1,45 @@
|
||||
use crate::io::Encode;
|
||||
|
||||
pub(crate) struct AllHeaders<'a>(pub(crate) &'a [Header]);
|
||||
|
||||
impl Encode<'_> for AllHeaders<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
let offset = buf.len();
|
||||
buf.resize(buf.len() + 4, 0);
|
||||
|
||||
for header in self.0 {
|
||||
header.encode_with(buf, ());
|
||||
}
|
||||
|
||||
let len = buf.len() - offset;
|
||||
buf[offset..(offset + 4)].copy_from_slice(&(len as u32).to_le_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum Header {
|
||||
TransactionDescriptor {
|
||||
// number of requests currently active on the connection
|
||||
outstanding_request_count: u32,
|
||||
|
||||
// for each connection, a number that uniquely identifies the transaction with which the
|
||||
// request is associated; initially generated by the server when a new transaction is
|
||||
// created and returned to the client as part of the ENVCHANGE token stream
|
||||
transaction_descriptor: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl Encode<'_> for Header {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
match self {
|
||||
Header::TransactionDescriptor {
|
||||
outstanding_request_count,
|
||||
transaction_descriptor,
|
||||
} => {
|
||||
buf.extend(&18_u32.to_le_bytes()); // [HeaderLength] 4 + 2 + 8 + 4
|
||||
buf.extend(&2_u16.to_le_bytes()); // [HeaderType]
|
||||
buf.extend(&transaction_descriptor.to_le_bytes());
|
||||
buf.extend(&outstanding_request_count.to_le_bytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,9 +1,9 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::mssql::protocol::col_meta_data::ColMetaData;
|
||||
use crate::mssql::protocol::done::Done;
|
||||
use crate::mssql::protocol::env_change::EnvChange;
|
||||
use crate::mssql::protocol::error::Error;
|
||||
use crate::mssql::protocol::info::Info;
|
||||
use crate::mssql::protocol::login_ack::LoginAck;
|
||||
use crate::mssql::protocol::row::Row;
|
||||
@ -15,7 +15,6 @@ pub(crate) enum Message {
|
||||
EnvChange(EnvChange),
|
||||
Done(Done),
|
||||
Row(Row),
|
||||
ColMetaData(ColMetaData),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -25,13 +24,15 @@ pub(crate) enum MessageType {
|
||||
EnvChange,
|
||||
Done,
|
||||
Row,
|
||||
Error,
|
||||
ColMetaData,
|
||||
}
|
||||
|
||||
impl MessageType {
|
||||
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
|
||||
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, crate::error::Error> {
|
||||
Ok(match buf.get_u8() {
|
||||
0x81 => MessageType::ColMetaData,
|
||||
0xaa => MessageType::Error,
|
||||
0xab => MessageType::Info,
|
||||
0xad => MessageType::LoginAck,
|
||||
0xd1 => MessageType::Row,
|
||||
|
@ -1,6 +1,8 @@
|
||||
pub(crate) mod col_meta_data;
|
||||
pub(crate) mod done;
|
||||
pub(crate) mod env_change;
|
||||
pub(crate) mod error;
|
||||
pub(crate) mod header;
|
||||
pub(crate) mod info;
|
||||
pub(crate) mod login;
|
||||
pub(crate) mod login_ack;
|
||||
@ -8,5 +10,6 @@ pub(crate) mod message;
|
||||
pub(crate) mod packet;
|
||||
pub(crate) mod pre_login;
|
||||
pub(crate) mod row;
|
||||
pub(crate) mod rpc;
|
||||
pub(crate) mod sql_batch;
|
||||
pub(crate) mod type_info;
|
||||
|
@ -21,7 +21,7 @@ pub(crate) struct PreLogin<'a> {
|
||||
}
|
||||
|
||||
impl<'de> Decode<'de> for PreLogin<'de> {
|
||||
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
let mut version = None;
|
||||
let mut encryption = None;
|
||||
|
||||
|
@ -25,7 +25,7 @@ impl Row {
|
||||
if column.type_info.is_null() {
|
||||
values.push(None);
|
||||
} else {
|
||||
values.push(Some(buf.split_to(column.type_info.size())));
|
||||
values.push(Some(column.type_info.get_value(buf)));
|
||||
}
|
||||
}
|
||||
|
||||
|
91
sqlx-core/src/mssql/protocol/rpc.rs
Normal file
91
sqlx-core/src/mssql/protocol/rpc.rs
Normal file
@ -0,0 +1,91 @@
|
||||
use bitflags::bitflags;
|
||||
use either::Either;
|
||||
|
||||
use crate::io::Encode;
|
||||
use crate::mssql::io::MsSqlBufMutExt;
|
||||
use crate::mssql::protocol::header::{AllHeaders, Header};
|
||||
use crate::mssql::MsSqlArguments;
|
||||
|
||||
pub(crate) struct RpcRequest<'a> {
|
||||
// the procedure can be encoded as a u16 of a built-in or the name for a custom one
|
||||
pub(crate) procedure: Either<&'a str, Procedure>,
|
||||
pub(crate) options: OptionFlags,
|
||||
pub(crate) arguments: &'a MsSqlArguments,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
#[repr(u16)]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) enum Procedure {
|
||||
Cursor = 1,
|
||||
CursorOpen = 2,
|
||||
CursorPrepare = 3,
|
||||
CursorExecute = 4,
|
||||
CursorPrepareExecute = 5,
|
||||
CursorUnprepare = 6,
|
||||
CursorFetch = 7,
|
||||
CursorOption = 8,
|
||||
CursorClose = 9,
|
||||
ExecuteSql = 10,
|
||||
Prepare = 11,
|
||||
Execute = 12,
|
||||
PrepareExecute = 13,
|
||||
PrepareExecuteRpc = 14,
|
||||
Unprepare = 15,
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
pub(crate) struct OptionFlags: u16 {
|
||||
const WITH_RECOMPILE = 1;
|
||||
|
||||
// The server sends NoMetaData only if fNoMetadata is set to 1 in the request
|
||||
const NO_META_DATA = 2;
|
||||
|
||||
// 1 if the metadata has not changed from the previous call and the server SHOULD reuse
|
||||
// its cached metadata (the metadata MUST still be sent).
|
||||
const REUSE_META_DATA = 4;
|
||||
}
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
pub(crate) struct StatusFlags: u8 {
|
||||
// if the parameter is passed by reference (OUTPUT parameter) or
|
||||
// 0 if parameter is passed by value
|
||||
const BY_REF_VALUE = 1;
|
||||
|
||||
// 1 if the parameter being passed is to be the default value
|
||||
const DEFAULT_VALUE = 2;
|
||||
|
||||
// 1 if the parameter that is being passed is encrypted. This flag is valid
|
||||
// only when the column encryption feature is negotiated by client and server
|
||||
// and is turned on
|
||||
const ENCRYPTED = 8;
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<'_> for RpcRequest<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
AllHeaders(&[Header::TransactionDescriptor {
|
||||
outstanding_request_count: 1,
|
||||
transaction_descriptor: 0,
|
||||
}])
|
||||
.encode(buf);
|
||||
|
||||
match &self.procedure {
|
||||
Either::Left(name) => {
|
||||
buf.extend(&(name.len() as u16).to_le_bytes());
|
||||
buf.put_utf16_str(name);
|
||||
}
|
||||
|
||||
Either::Right(id) => {
|
||||
buf.extend(&(0xffff_u16).to_le_bytes());
|
||||
buf.extend(&(*id as u16).to_le_bytes());
|
||||
}
|
||||
}
|
||||
|
||||
buf.extend(&self.options.bits.to_le_bytes());
|
||||
buf.extend(&self.arguments.data);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Test serialization of this?
|
@ -1,7 +1,6 @@
|
||||
use crate::io::Encode;
|
||||
use crate::mssql::io::MsSqlBufMutExt;
|
||||
|
||||
const HEADER_TRANSACTION_DESCRIPTOR: u16 = 0x00_02;
|
||||
use crate::mssql::protocol::header::{AllHeaders, Header};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct SqlBatch<'a> {
|
||||
@ -10,24 +9,11 @@ pub(crate) struct SqlBatch<'a> {
|
||||
|
||||
impl Encode<'_> for SqlBatch<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
// ALL_HEADERS -> TotalLength
|
||||
buf.extend(&(4_u32 + 18).to_le_bytes()); // 4 + 18
|
||||
|
||||
// [Header] Transaction Descriptor
|
||||
// SQL_BATCH messages require this header
|
||||
// contains information regarding number of outstanding requests for MARS
|
||||
buf.extend(&18_u32.to_le_bytes()); // 4 + 2 + 8 + 4
|
||||
buf.extend(&HEADER_TRANSACTION_DESCRIPTOR.to_le_bytes());
|
||||
|
||||
// [TransactionDescriptor] a number that uniquely identifies the current transaction
|
||||
// TODO: use this once we support transactions, it will be given to us from the
|
||||
// server ENVCHANGE event
|
||||
buf.extend(&0_u64.to_le_bytes());
|
||||
|
||||
// [OutstandingRequestCount] Number of active requests to MSSQL from the
|
||||
// same connection
|
||||
// NOTE: Long-term when we support MARS we need to connect this value correctly
|
||||
buf.extend(&(1_u32.to_le_bytes()));
|
||||
AllHeaders(&[Header::TransactionDescriptor {
|
||||
outstanding_request_count: 1,
|
||||
transaction_descriptor: 0,
|
||||
}])
|
||||
.encode(buf);
|
||||
|
||||
// SQLText
|
||||
buf.put_utf16_str(self.sql);
|
||||
|
@ -1,9 +1,36 @@
|
||||
use crate::error::Error;
|
||||
use std::borrow::Cow;
|
||||
|
||||
use bitflags::bitflags;
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::encode::Encode;
|
||||
use crate::error::Error;
|
||||
use crate::mssql::MsSql;
|
||||
use url::quirks::set_search;
|
||||
|
||||
bitflags! {
|
||||
pub(crate) struct CollationFlags: u8 {
|
||||
const IGNORE_CASE = (1 << 0);
|
||||
const IGNORE_ACCENT = (1 << 1);
|
||||
const IGNORE_WIDTH = (1 << 2);
|
||||
const IGNORE_KANA = (1 << 3);
|
||||
const BINARY = (1 << 4);
|
||||
const BINARY2 = (1 << 5);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum DataType {
|
||||
// Fixed-length data types
|
||||
pub(crate) struct Collation {
|
||||
pub(crate) locale: u32,
|
||||
pub(crate) flags: CollationFlags,
|
||||
pub(crate) sort: u8,
|
||||
pub(crate) version: u8,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
#[repr(u8)]
|
||||
pub(crate) enum DataType {
|
||||
// fixed-length data types
|
||||
// https://docs.microsoft.com/en-us/openspecs/sql_server_protocols/ms-sstds/d33ef17b-7e53-4380-ad11-2ba42c8dda8d
|
||||
Null = 0x1f,
|
||||
TinyInt = 0x30,
|
||||
@ -17,39 +44,414 @@ pub enum DataType {
|
||||
Float = 0x3e,
|
||||
SmallMoney = 0x7a,
|
||||
BigInt = 0x7f,
|
||||
|
||||
// variable-length data types
|
||||
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/ce3183a6-9d89-47e8-a02f-de5a1a1303de
|
||||
|
||||
// byte length
|
||||
Guid = 0x24,
|
||||
IntN = 0x26,
|
||||
Decimal = 0x37, // legacy
|
||||
Numeric = 0x3f, // legacy
|
||||
BitN = 0x68,
|
||||
DecimalN = 0x6a,
|
||||
NumericN = 0x6c,
|
||||
FloatN = 0x6d,
|
||||
MoneyN = 0x6e,
|
||||
DateTimeN = 0x6f,
|
||||
DateN = 0x28,
|
||||
TimeN = 0x29,
|
||||
DateTime2N = 0x2a,
|
||||
DateTimeOffsetN = 0x2b,
|
||||
Char = 0x2f, // legacy
|
||||
VarChar = 0x27, // legacy
|
||||
Binary = 0x2d, // legacy
|
||||
VarBinary = 0x25, // legacy
|
||||
|
||||
// short length
|
||||
BigVarBinary = 0xa5,
|
||||
BigVarChar = 0xa7,
|
||||
BigBinary = 0xad,
|
||||
BigChar = 0xaf,
|
||||
NVarChar = 0xe7,
|
||||
NChar = 0xef,
|
||||
Xml = 0xf1,
|
||||
UserDefined = 0xf0,
|
||||
|
||||
// long length
|
||||
Text = 0x23,
|
||||
Image = 0x22,
|
||||
NText = 0x63,
|
||||
Variant = 0x62,
|
||||
}
|
||||
|
||||
// http://msdn.microsoft.com/en-us/library/dd358284.aspx
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
pub(crate) struct TypeInfo {
|
||||
pub(crate) ty: DataType,
|
||||
pub(crate) size: u32,
|
||||
pub(crate) scale: u8,
|
||||
pub(crate) precision: u8,
|
||||
pub(crate) collation: Option<Collation>,
|
||||
}
|
||||
|
||||
impl TypeInfo {
|
||||
pub(crate) const fn new(ty: DataType, size: u32) -> Self {
|
||||
Self {
|
||||
ty,
|
||||
size,
|
||||
scale: 0,
|
||||
precision: 0,
|
||||
collation: None,
|
||||
}
|
||||
}
|
||||
|
||||
// reads a TYPE_INFO from the buffer
|
||||
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
|
||||
let ty = DataType::get(buf)?;
|
||||
|
||||
Ok(Self { ty })
|
||||
Ok(match ty {
|
||||
DataType::Null => Self::new(ty, 0),
|
||||
|
||||
DataType::TinyInt | DataType::Bit => Self::new(ty, 1),
|
||||
|
||||
DataType::SmallInt => Self::new(ty, 2),
|
||||
|
||||
DataType::Int | DataType::SmallDateTime | DataType::Real | DataType::SmallMoney => {
|
||||
Self::new(ty, 4)
|
||||
}
|
||||
|
||||
DataType::BigInt | DataType::Money | DataType::DateTime | DataType::Float => {
|
||||
Self::new(ty, 8)
|
||||
}
|
||||
|
||||
DataType::DateN => Self::new(ty, 3),
|
||||
|
||||
DataType::TimeN | DataType::DateTime2N | DataType::DateTimeOffsetN => {
|
||||
let scale = buf.get_u8();
|
||||
|
||||
let mut size = match scale {
|
||||
0 | 1 | 2 => 3,
|
||||
3 | 4 => 4,
|
||||
5 | 6 | 7 => 5,
|
||||
|
||||
scale => {
|
||||
return Err(err_protocol!("invalid scale {} for type {:?}", scale, ty));
|
||||
}
|
||||
};
|
||||
|
||||
match ty {
|
||||
DataType::DateTime2N => {
|
||||
size += 3;
|
||||
}
|
||||
|
||||
DataType::DateTimeOffsetN => {
|
||||
size += 5;
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Self {
|
||||
scale,
|
||||
size,
|
||||
ty,
|
||||
precision: 0,
|
||||
collation: None,
|
||||
}
|
||||
}
|
||||
|
||||
DataType::Guid
|
||||
| DataType::IntN
|
||||
| DataType::BitN
|
||||
| DataType::FloatN
|
||||
| DataType::MoneyN
|
||||
| DataType::DateTimeN
|
||||
| DataType::Char
|
||||
| DataType::VarChar
|
||||
| DataType::Binary
|
||||
| DataType::VarBinary => Self::new(ty, buf.get_u8() as u32),
|
||||
|
||||
DataType::Decimal | DataType::Numeric | DataType::DecimalN | DataType::NumericN => {
|
||||
let size = buf.get_u8() as u32;
|
||||
let precision = buf.get_u8();
|
||||
let scale = buf.get_u8();
|
||||
|
||||
Self {
|
||||
size,
|
||||
precision,
|
||||
scale,
|
||||
ty,
|
||||
collation: None,
|
||||
}
|
||||
}
|
||||
|
||||
DataType::BigVarBinary | DataType::BigBinary => Self::new(ty, buf.get_u16_le() as u32),
|
||||
|
||||
DataType::BigVarChar | DataType::BigChar | DataType::NVarChar | DataType::NChar => {
|
||||
let size = buf.get_u16_le() as u32;
|
||||
let collation = Collation::get(buf);
|
||||
|
||||
Self {
|
||||
ty,
|
||||
size,
|
||||
collation: Some(collation),
|
||||
scale: 0,
|
||||
precision: 0,
|
||||
}
|
||||
}
|
||||
|
||||
_ => {
|
||||
return Err(err_protocol!("unsupported data type {:?}", ty));
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// writes a TYPE_INFO to the buffer
|
||||
pub(crate) fn put(&self, buf: &mut Vec<u8>) {
|
||||
buf.push(self.ty as u8);
|
||||
|
||||
match self.ty {
|
||||
DataType::Null
|
||||
| DataType::TinyInt
|
||||
| DataType::Bit
|
||||
| DataType::SmallInt
|
||||
| DataType::Int
|
||||
| DataType::SmallDateTime
|
||||
| DataType::Real
|
||||
| DataType::SmallMoney
|
||||
| DataType::BigInt
|
||||
| DataType::Money
|
||||
| DataType::DateTime
|
||||
| DataType::Float
|
||||
| DataType::DateN => {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
DataType::TimeN | DataType::DateTime2N | DataType::DateTimeOffsetN => {
|
||||
buf.push(self.scale);
|
||||
}
|
||||
|
||||
DataType::Guid
|
||||
| DataType::IntN
|
||||
| DataType::BitN
|
||||
| DataType::FloatN
|
||||
| DataType::MoneyN
|
||||
| DataType::DateTimeN
|
||||
| DataType::Char
|
||||
| DataType::VarChar
|
||||
| DataType::Binary
|
||||
| DataType::VarBinary => {
|
||||
buf.push(self.size as u8);
|
||||
}
|
||||
|
||||
DataType::Decimal | DataType::Numeric | DataType::DecimalN | DataType::NumericN => {
|
||||
buf.push(self.size as u8);
|
||||
buf.push(self.precision);
|
||||
buf.push(self.scale);
|
||||
}
|
||||
|
||||
DataType::BigVarBinary | DataType::BigBinary => {
|
||||
buf.extend(&(self.size as u16).to_le_bytes());
|
||||
}
|
||||
|
||||
DataType::BigVarChar | DataType::BigChar | DataType::NVarChar | DataType::NChar => {
|
||||
buf.extend(&(self.size as u16).to_le_bytes());
|
||||
|
||||
if let Some(collation) = &self.collation {
|
||||
collation.put(buf);
|
||||
} else {
|
||||
buf.extend(&0_u32.to_le_bytes());
|
||||
buf.push(0);
|
||||
}
|
||||
}
|
||||
|
||||
_ => {
|
||||
unimplemented!("unsupported data type {:?}", self.ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_null(&self) -> bool {
|
||||
matches!(self.ty, DataType::Null)
|
||||
}
|
||||
|
||||
pub(crate) fn size(&self) -> usize {
|
||||
pub(crate) fn get_value(&self, buf: &mut Bytes) -> Bytes {
|
||||
let size = match self.ty {
|
||||
DataType::Null
|
||||
| DataType::TinyInt
|
||||
| DataType::Bit
|
||||
| DataType::SmallInt
|
||||
| DataType::Int
|
||||
| DataType::SmallDateTime
|
||||
| DataType::Real
|
||||
| DataType::Money
|
||||
| DataType::DateTime
|
||||
| DataType::Float
|
||||
| DataType::SmallMoney
|
||||
| DataType::BigInt => self.size as usize,
|
||||
|
||||
DataType::Guid
|
||||
| DataType::IntN
|
||||
| DataType::Decimal
|
||||
| DataType::Numeric
|
||||
| DataType::BitN
|
||||
| DataType::DecimalN
|
||||
| DataType::NumericN
|
||||
| DataType::FloatN
|
||||
| DataType::MoneyN
|
||||
| DataType::DateTimeN
|
||||
| DataType::DateN
|
||||
| DataType::TimeN
|
||||
| DataType::DateTime2N
|
||||
| DataType::DateTimeOffsetN
|
||||
| DataType::Char
|
||||
| DataType::VarChar
|
||||
| DataType::Binary
|
||||
| DataType::VarBinary => buf.get_u8() as usize,
|
||||
|
||||
DataType::BigVarBinary
|
||||
| DataType::BigVarChar
|
||||
| DataType::BigBinary
|
||||
| DataType::BigChar
|
||||
| DataType::NVarChar
|
||||
| DataType::NChar
|
||||
| DataType::Xml
|
||||
| DataType::UserDefined => buf.get_u16_le() as usize,
|
||||
|
||||
DataType::Text | DataType::Image | DataType::NText | DataType::Variant => {
|
||||
buf.get_u32_le() as usize
|
||||
}
|
||||
};
|
||||
|
||||
buf.split_to(size)
|
||||
}
|
||||
|
||||
pub(crate) fn put_value<'q, T: Encode<'q, MsSql>>(&self, buf: &mut Vec<u8>, value: T) {
|
||||
match self.ty {
|
||||
DataType::Null => 0,
|
||||
DataType::TinyInt => 1,
|
||||
DataType::Bit => 1,
|
||||
DataType::SmallInt => 2,
|
||||
DataType::Int => 4,
|
||||
DataType::SmallDateTime => 4,
|
||||
DataType::Real => 4,
|
||||
DataType::Money => 4,
|
||||
DataType::DateTime => 8,
|
||||
DataType::Float => 8,
|
||||
DataType::SmallMoney => 4,
|
||||
DataType::BigInt => 8,
|
||||
DataType::Null
|
||||
| DataType::TinyInt
|
||||
| DataType::Bit
|
||||
| DataType::SmallInt
|
||||
| DataType::Int
|
||||
| DataType::SmallDateTime
|
||||
| DataType::Real
|
||||
| DataType::Money
|
||||
| DataType::DateTime
|
||||
| DataType::Float
|
||||
| DataType::SmallMoney
|
||||
| DataType::BigInt => {
|
||||
self.put_fixed_value(buf, value);
|
||||
}
|
||||
|
||||
DataType::Guid
|
||||
| DataType::IntN
|
||||
| DataType::Decimal
|
||||
| DataType::Numeric
|
||||
| DataType::BitN
|
||||
| DataType::DecimalN
|
||||
| DataType::NumericN
|
||||
| DataType::FloatN
|
||||
| DataType::MoneyN
|
||||
| DataType::DateTimeN
|
||||
| DataType::DateN
|
||||
| DataType::TimeN
|
||||
| DataType::DateTime2N
|
||||
| DataType::DateTimeOffsetN
|
||||
| DataType::Char
|
||||
| DataType::VarChar
|
||||
| DataType::Binary
|
||||
| DataType::VarBinary => {
|
||||
self.put_byte_len_value(buf, value);
|
||||
}
|
||||
|
||||
DataType::BigVarBinary
|
||||
| DataType::BigVarChar
|
||||
| DataType::BigBinary
|
||||
| DataType::BigChar
|
||||
| DataType::NVarChar
|
||||
| DataType::NChar
|
||||
| DataType::Xml
|
||||
| DataType::UserDefined => {
|
||||
self.put_short_len_value(buf, value);
|
||||
}
|
||||
|
||||
DataType::Text | DataType::Image | DataType::NText | DataType::Variant => {
|
||||
self.put_long_len_value(buf, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn put_fixed_value<'q, T: Encode<'q, MsSql>>(&self, buf: &mut Vec<u8>, value: T) {
|
||||
let _ = value.encode(buf);
|
||||
}
|
||||
|
||||
pub(crate) fn put_byte_len_value<'q, T: Encode<'q, MsSql>>(&self, buf: &mut Vec<u8>, value: T) {
|
||||
let offset = buf.len();
|
||||
buf.push(0);
|
||||
|
||||
let _ = value.encode(buf);
|
||||
|
||||
buf[offset] = (buf.len() - offset - 1) as u8;
|
||||
}
|
||||
|
||||
pub(crate) fn put_short_len_value<'q, T: Encode<'q, MsSql>>(
|
||||
&self,
|
||||
buf: &mut Vec<u8>,
|
||||
value: T,
|
||||
) {
|
||||
let offset = buf.len();
|
||||
buf.extend(&0_u16.to_le_bytes());
|
||||
|
||||
let _ = value.encode(buf);
|
||||
|
||||
let size = (buf.len() - offset - 2) as u16;
|
||||
buf[offset..(offset + 2)].copy_from_slice(&size.to_le_bytes());
|
||||
}
|
||||
|
||||
pub(crate) fn put_long_len_value<'q, T: Encode<'q, MsSql>>(&self, buf: &mut Vec<u8>, value: T) {
|
||||
let offset = buf.len();
|
||||
buf.extend(&0_u32.to_le_bytes());
|
||||
|
||||
let _ = value.encode(buf);
|
||||
|
||||
let size = (buf.len() - offset - 4) as u32;
|
||||
buf[offset..(offset + 4)].copy_from_slice(&size.to_le_bytes());
|
||||
}
|
||||
|
||||
pub(crate) fn fmt(&self, s: &mut String) {
|
||||
match self.ty {
|
||||
DataType::Null => s.push_str("nvarchar(1)"),
|
||||
DataType::TinyInt => s.push_str("tinyint"),
|
||||
DataType::SmallInt => s.push_str("smallint"),
|
||||
DataType::Int => s.push_str("int"),
|
||||
DataType::BigInt => s.push_str("bigint"),
|
||||
DataType::Real => s.push_str("real"),
|
||||
DataType::Float => s.push_str("float"),
|
||||
|
||||
DataType::IntN => s.push_str(match self.size {
|
||||
1 => "tinyint",
|
||||
2 => "smallint",
|
||||
4 => "int",
|
||||
8 => "bigint",
|
||||
|
||||
_ => unreachable!("invalid size {} for int"),
|
||||
}),
|
||||
|
||||
DataType::FloatN => s.push_str(match self.size {
|
||||
4 => "real",
|
||||
8 => "float",
|
||||
|
||||
_ => unreachable!("invalid size {} for float"),
|
||||
}),
|
||||
|
||||
DataType::NVarChar => {
|
||||
s.push_str("nvarchar(");
|
||||
let _ = itoa::fmt(&mut *s, self.size / 2);
|
||||
s.push_str(")");
|
||||
}
|
||||
|
||||
_ => unimplemented!("unsupported data type {:?}", self.ty),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -69,6 +471,36 @@ impl DataType {
|
||||
0x3e => DataType::Float,
|
||||
0x7a => DataType::SmallMoney,
|
||||
0x7f => DataType::BigInt,
|
||||
0x24 => DataType::Guid,
|
||||
0x26 => DataType::IntN,
|
||||
0x37 => DataType::Decimal,
|
||||
0x3f => DataType::Numeric,
|
||||
0x68 => DataType::BitN,
|
||||
0x6a => DataType::DecimalN,
|
||||
0x6c => DataType::NumericN,
|
||||
0x6d => DataType::FloatN,
|
||||
0x6e => DataType::MoneyN,
|
||||
0x6f => DataType::DateTimeN,
|
||||
0x28 => DataType::DateN,
|
||||
0x29 => DataType::TimeN,
|
||||
0x2a => DataType::DateTime2N,
|
||||
0x2b => DataType::DateTimeOffsetN,
|
||||
0x2f => DataType::Char,
|
||||
0x27 => DataType::VarChar,
|
||||
0x2d => DataType::Binary,
|
||||
0x25 => DataType::VarBinary,
|
||||
0xa5 => DataType::BigVarBinary,
|
||||
0xa7 => DataType::BigVarChar,
|
||||
0xad => DataType::BigBinary,
|
||||
0xaf => DataType::BigChar,
|
||||
0xe7 => DataType::NVarChar,
|
||||
0xef => DataType::NChar,
|
||||
0xf1 => DataType::Xml,
|
||||
0xf0 => DataType::UserDefined,
|
||||
0x23 => DataType::Text,
|
||||
0x22 => DataType::Image,
|
||||
0x63 => DataType::NText,
|
||||
0x62 => DataType::Variant,
|
||||
|
||||
ty => {
|
||||
return Err(err_protocol!("unknown data type 0x{:02x}", ty));
|
||||
@ -76,3 +508,28 @@ impl DataType {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Collation {
|
||||
pub(crate) fn get(buf: &mut Bytes) -> Collation {
|
||||
let locale_sort_version = buf.get_u32();
|
||||
let locale = locale_sort_version & 0xF_FFFF;
|
||||
let flags = CollationFlags::from_bits_truncate(((locale_sort_version >> 20) & 0xFF) as u8);
|
||||
let version = (locale_sort_version >> 28) as u8;
|
||||
let sort = buf.get_u8();
|
||||
|
||||
Collation {
|
||||
locale,
|
||||
flags,
|
||||
sort,
|
||||
version,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn put(&self, buf: &mut Vec<u8>) {
|
||||
let locale_sort_version =
|
||||
self.locale | ((self.flags.bits() as u32) << 20) | ((self.version as u32) << 28);
|
||||
|
||||
buf.extend(&locale_sort_version.to_le_bytes());
|
||||
buf.push(self.sort);
|
||||
}
|
||||
}
|
||||
|
@ -10,6 +10,9 @@ impl TypeInfo for MsSqlTypeInfo {}
|
||||
|
||||
impl Display for MsSqlTypeInfo {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
unimplemented!()
|
||||
let mut buf = String::new();
|
||||
self.0.fmt(&mut buf);
|
||||
|
||||
f.pad(&*buf)
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,8 @@
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
use crate::database::{Database, HasValueRef};
|
||||
use crate::database::{Database, HasArguments, HasValueRef};
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::mssql::protocol::type_info::{DataType, TypeInfo};
|
||||
use crate::mssql::{MsSql, MsSqlTypeInfo, MsSqlValueRef};
|
||||
@ -9,11 +10,23 @@ use crate::types::Type;
|
||||
|
||||
impl Type<MsSql> for i32 {
|
||||
fn type_info() -> MsSqlTypeInfo {
|
||||
MsSqlTypeInfo(TypeInfo { ty: DataType::Int })
|
||||
MsSqlTypeInfo(TypeInfo::new(DataType::IntN, 4))
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<'_, MsSql> for i32 {
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.extend(&self.to_le_bytes());
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_, MsSql> for i32 {
|
||||
fn accepts(ty: &MsSqlTypeInfo) -> bool {
|
||||
matches!(ty.0.ty, DataType::Int | DataType::IntN) && ty.0.size == 4
|
||||
}
|
||||
|
||||
fn decode(value: MsSqlValueRef<'_>) -> Result<Self, BoxDynError> {
|
||||
Ok(LittleEndian::read_i32(value.as_bytes()?))
|
||||
}
|
||||
|
@ -1 +1,2 @@
|
||||
mod int;
|
||||
mod str;
|
||||
|
28
sqlx-core/src/mssql/types/str.rs
Normal file
28
sqlx-core/src/mssql/types/str.rs
Normal file
@ -0,0 +1,28 @@
|
||||
use byteorder::{ByteOrder, LittleEndian};
|
||||
|
||||
use crate::database::{Database, HasArguments, HasValueRef};
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::mssql::io::MsSqlBufMutExt;
|
||||
use crate::mssql::protocol::type_info::{DataType, TypeInfo};
|
||||
use crate::mssql::{MsSql, MsSqlTypeInfo, MsSqlValueRef};
|
||||
use crate::types::Type;
|
||||
|
||||
impl Type<MsSql> for str {
|
||||
fn type_info() -> MsSqlTypeInfo {
|
||||
MsSqlTypeInfo(TypeInfo::new(DataType::NVarChar, 0))
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<'_, MsSql> for &'_ str {
|
||||
fn produces(&self) -> MsSqlTypeInfo {
|
||||
MsSqlTypeInfo(TypeInfo::new(DataType::NVarChar, (self.len() * 2) as u32))
|
||||
}
|
||||
|
||||
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
|
||||
buf.put_utf16_str(self);
|
||||
|
||||
IsNull::No
|
||||
}
|
||||
}
|
@ -1 +0,0 @@
|
||||
|
@ -15,7 +15,7 @@ async fn it_connects() -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn it_can_select_1() -> anyhow::Result<()> {
|
||||
async fn it_can_select_expression() -> anyhow::Result<()> {
|
||||
let mut conn = new::<MsSql>().await?;
|
||||
|
||||
let row: MsSqlRow = conn.fetch_one("SELECT 4").await?;
|
||||
@ -25,3 +25,18 @@ async fn it_can_select_1() -> anyhow::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn it_maths() -> anyhow::Result<()> {
|
||||
let mut conn = new::<MsSql>().await?;
|
||||
|
||||
let value = sqlx::query("SELECT 1 + @p1")
|
||||
.bind(5_i32)
|
||||
.try_map(|row: MsSqlRow| row.try_get::<i32, _>(0))
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert_eq!(6_i32, value);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user