feat(mssql): implement parameterized queries

This commit is contained in:
Ryan Leckey 2020-06-06 01:39:04 -07:00
parent 9a701313bc
commit c64122c03f
25 changed files with 899 additions and 266 deletions

View File

@ -1,26 +0,0 @@
use std::{
ascii::escape_default,
fmt::{self, Debug},
str::from_utf8,
};
// Wrapper type for byte slices that will debug print
// as a binary string
pub struct ByteStr<'a>(pub &'a [u8]);
impl Debug for ByteStr<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "b\"")?;
for &b in self.0 {
let part: Vec<u8> = escape_default(b).collect();
let s = from_utf8(&part).unwrap();
write!(f, "{}", s)?;
}
write!(f, "\"")?;
Ok(())
}
}

View File

@ -1,152 +0,0 @@
use std::io;
use std::net::Shutdown;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::runtime::{AsyncRead, AsyncWrite, TcpStream};
use self::Inner::*;
pub struct MaybeTlsStream {
inner: Inner,
}
enum Inner {
NotTls(TcpStream),
#[cfg(all(feature = "postgres", unix))]
UnixStream(crate::runtime::UnixStream),
#[cfg(feature = "tls")]
Tls(async_native_tls::TlsStream<TcpStream>),
#[cfg(feature = "tls")]
Upgrading,
}
impl MaybeTlsStream {
#[cfg(all(feature = "postgres", unix))]
pub async fn connect_uds<S: AsRef<std::ffi::OsStr>>(p: S) -> crate::Result<Self> {
let conn = crate::runtime::UnixStream::connect(p.as_ref()).await?;
Ok(Self {
inner: Inner::UnixStream(conn),
})
}
pub async fn connect(host: &str, port: u16) -> crate::Result<Self> {
let conn = TcpStream::connect((host, port)).await?;
Ok(Self {
inner: Inner::NotTls(conn),
})
}
#[allow(dead_code)]
pub fn is_tls(&self) -> bool {
match self.inner {
Inner::NotTls(_) => false,
#[cfg(all(feature = "postgres", unix))]
Inner::UnixStream(_) => false,
#[cfg(feature = "tls")]
Inner::Tls(_) => true,
#[cfg(feature = "tls")]
Inner::Upgrading => false,
}
}
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub async fn upgrade(
&mut self,
host: &str,
connector: async_native_tls::TlsConnector,
) -> crate::Result<()> {
let conn = match std::mem::replace(&mut self.inner, Upgrading) {
NotTls(conn) => conn,
#[cfg(all(feature = "postgres", unix))]
UnixStream(_) => {
return Err(tls_err!("TLS is not supported with unix domain sockets").into())
}
Tls(_) => return Err(tls_err!("connection already upgraded").into()),
Upgrading => return Err(tls_err!("connection already failed to upgrade").into()),
};
self.inner = Tls(connector.connect(host, conn).await?);
Ok(())
}
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
match self.inner {
NotTls(ref conn) => conn.shutdown(how),
#[cfg(all(feature = "postgres", unix))]
UnixStream(ref conn) => conn.shutdown(how),
#[cfg(feature = "tls")]
Tls(ref conn) => conn.get_ref().shutdown(how),
#[cfg(feature = "tls")]
// connection already closed
Upgrading => Ok(()),
}
}
}
macro_rules! forward_pin (
($self:ident.$method:ident($($arg:ident),*)) => (
match &mut $self.inner {
NotTls(ref mut conn) => Pin::new(conn).$method($($arg),*),
#[cfg(all(feature = "postgres", unix))]
UnixStream(ref mut conn) => Pin::new(conn).$method($($arg),*),
#[cfg(feature = "tls")]
Tls(ref mut conn) => Pin::new(conn).$method($($arg),*),
#[cfg(feature = "tls")]
Upgrading => Err(io::Error::new(io::ErrorKind::Other, "connection broken; TLS upgrade failed")).into(),
}
)
);
impl AsyncRead for MaybeTlsStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_read(cx, buf))
}
#[cfg(feature = "runtime-async-std")]
fn poll_read_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context,
bufs: &mut [std::io::IoSliceMut],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_read_vectored(cx, bufs))
}
}
impl AsyncWrite for MaybeTlsStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_write(cx, buf))
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
forward_pin!(self.poll_flush(cx))
}
#[cfg(feature = "runtime-async-std")]
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
forward_pin!(self.poll_close(cx))
}
#[cfg(feature = "runtime-tokio")]
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
forward_pin!(self.poll_shutdown(cx))
}
#[cfg(feature = "runtime-async-std")]
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context,
bufs: &[std::io::IoSlice],
) -> Poll<io::Result<usize>> {
forward_pin!(self.poll_write_vectored(cx, bufs))
}
}

View File

@ -1,21 +1,88 @@
use crate::arguments::Arguments;
use crate::encode::Encode;
use crate::mssql::database::MsSql;
use crate::mssql::io::MsSqlBufMutExt;
#[derive(Default)]
pub struct MsSqlArguments {}
pub struct MsSqlArguments {
// next ordinal to be used when formatting a positional parameter name
pub(crate) ordinal: usize,
// temporary string buffer used to format parameter names
name: String,
pub(crate) data: Vec<u8>,
pub(crate) declarations: String,
}
impl MsSqlArguments {
pub(crate) fn add_named<'q, T: Encode<'q, MsSql>>(&mut self, name: &str, value: T) {
let ty = value.produces();
let mut ty_name = String::new();
ty.0.fmt(&mut ty_name);
self.data.put_b_varchar(name); // [ParamName]
self.data.push(0); // [StatusFlags]
ty.0.put(&mut self.data); // [TYPE_INFO]
ty.0.put_value(&mut self.data, value); // [ParamLenData]
}
pub(crate) fn add_unnamed<'q, T: Encode<'q, MsSql>>(&mut self, value: T) {
self.add_named("", value);
}
pub(crate) fn append(&mut self, arguments: &mut MsSqlArguments) {
self.ordinal += arguments.ordinal;
self.data.append(&mut arguments.data);
}
}
impl<'q> Arguments<'q> for MsSqlArguments {
type Database = MsSql;
fn reserve(&mut self, additional: usize, size: usize) {
unimplemented!()
fn reserve(&mut self, _additional: usize, size: usize) {
self.data.reserve(size + 10); // est. 4 chars for name, 1 for status, 1 for TYPE_INFO
}
fn add<T>(&mut self, value: T)
where
T: 'q + Encode<'q, Self::Database>,
{
unimplemented!()
let ty = value.produces();
// produce an ordinal parameter name
// @p1, @p2, ... @pN
self.name.clear();
self.name.push_str("@p");
self.ordinal += 1;
let _ = itoa::fmt(&mut self.name, self.ordinal);
let MsSqlArguments {
ref name,
ref mut declarations,
ref mut data,
..
} = self;
// add this to our variable declaration list
// @p1 int, @p2 nvarchar(10), ...
if !declarations.is_empty() {
declarations.push_str(",");
}
declarations.push_str(name);
declarations.push(' ');
ty.0.fmt(declarations);
// write out the parameter
data.put_b_varchar(name); // [ParamName]
data.push(0); // [StatusFlags]
ty.0.put(data); // [TYPE_INFO]
ty.0.put_value(data, value); // [ParamLenData]
}
}

View File

@ -31,14 +31,7 @@ impl MsSqlConnection {
stream.flush().await?;
let (_, packet) = stream.recv_packet().await?;
let pl = PreLogin::decode(packet)?;
log::trace!(
"acknowledged PRELOGIN from MSSQL v{}.{}.{}",
pl.version.major,
pl.version.minor,
pl.version.build
);
let _ = PreLogin::decode(packet)?;
// LOGIN7 defines the authentication rules for use between client and server
@ -70,12 +63,9 @@ impl MsSqlConnection {
// all messages are mostly informational (ENVCHANGE, INFO, LOGINACK)
match stream.recv_message().await? {
Message::LoginAck(ack) => {
log::trace!(
"established connection to {} {}",
ack.program_name,
ack.program_version
);
Message::LoginAck(_) => {
// indicates that the login was successful
// no action is needed, we are just going to keep waiting till we hit <Done>
}
Message::Done(_) => {

View File

@ -9,13 +9,38 @@ use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::mssql::protocol::message::Message;
use crate::mssql::protocol::packet::PacketType;
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
use crate::mssql::protocol::sql_batch::SqlBatch;
use crate::mssql::{MsSql, MsSqlConnection, MsSqlRow};
use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow};
impl MsSqlConnection {
async fn run(&mut self, query: &str) -> Result<(), Error> {
async fn run(&mut self, query: &str, arguments: Option<MsSqlArguments>) -> Result<(), Error> {
if let Some(mut arguments) = arguments {
let proc = Either::Right(Procedure::ExecuteSql);
let mut proc_args = MsSqlArguments::default();
// SQL
proc_args.add_unnamed(query);
// Declarations
// NAME TYPE, NAME TYPE, ...
proc_args.add_unnamed(&*arguments.declarations);
// Add the list of SQL parameters _after_ our RPC parameters
proc_args.append(&mut arguments);
self.stream.write_packet(
PacketType::Rpc,
RpcRequest {
arguments: &proc_args,
procedure: proc,
options: OptionFlags::empty(),
},
);
} else {
self.stream
.write_packet(PacketType::SqlBatch, SqlBatch { sql: query });
}
self.stream.flush().await?;
@ -35,10 +60,10 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection {
E: Execute<'q, Self::Database>,
{
let s = query.query();
// TODO: let arguments = query.take_arguments();
let arguments = query.take_arguments();
Box::pin(try_stream! {
self.run(s).await?;
self.run(s, arguments).await?;
loop {
match self.stream.recv_message().await? {

View File

@ -1,9 +1,12 @@
use std::fmt::{self, Debug, Formatter};
use std::net::Shutdown;
use futures_core::future::BoxFuture;
use futures_util::{future::ready, FutureExt, TryFutureExt};
use crate::connection::{Connect, Connection};
use crate::error::{BoxDynError, Error};
use crate::executor::Executor;
use crate::mssql::connection::stream::MsSqlStream;
use crate::mssql::{MsSql, MsSqlConnectOptions};
@ -25,23 +28,28 @@ impl Connection for MsSqlConnection {
type Database = MsSql;
fn close(self) -> BoxFuture<'static, Result<(), Error>> {
unimplemented!()
// NOTE: there does not seem to be a clean shutdown packet to send to MSSQL
ready(self.stream.shutdown(Shutdown::Both).map_err(Into::into)).boxed()
}
fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
unimplemented!()
// NOTE: we do not use `SELECT 1` as that *could* interact with any ongoing transactions
self.execute("/* SQLx ping */").map_ok(|_| ()).boxed()
}
#[doc(hidden)]
fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
unimplemented!()
}
#[doc(hidden)]
fn get_ref(&self) -> &MsSqlConnection {
unimplemented!()
self
}
#[doc(hidden)]
fn get_mut(&mut self) -> &mut MsSqlConnection {
unimplemented!()
self
}
}

View File

@ -8,12 +8,13 @@ use crate::io::{BufStream, Encode};
use crate::mssql::protocol::col_meta_data::{ColMetaData, ColumnData};
use crate::mssql::protocol::done::Done;
use crate::mssql::protocol::env_change::EnvChange;
use crate::mssql::protocol::error::Error as ProtocolError;
use crate::mssql::protocol::info::Info;
use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::message::{Message, MessageType};
use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status};
use crate::mssql::protocol::row::Row;
use crate::mssql::MsSqlConnectOptions;
use crate::mssql::{MsSqlConnectOptions, MsSqlDatabaseError};
use crate::net::MaybeTlsStream;
pub(crate) struct MsSqlStream {
@ -104,13 +105,20 @@ impl MsSqlStream {
break;
};
return Ok(match MessageType::get(buf)? {
let ty = MessageType::get(buf)?;
return Ok(match ty {
MessageType::EnvChange => Message::EnvChange(EnvChange::get(buf)?),
MessageType::Info => Message::Info(Info::get(buf)?),
MessageType::Row => Message::Row(Row::get(buf, &self.columns)?),
MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?),
MessageType::Done => Message::Done(Done::get(buf)?),
MessageType::Error => {
let err = ProtocolError::get(buf)?;
return Err(MsSqlDatabaseError(err).into());
}
MessageType::ColMetaData => {
// NOTE: there isn't anything to return as the data gets
// consumed by the stream for use in subsequent Row decoding

View File

@ -1,42 +1,52 @@
use crate::error::DatabaseError;
use std::error::Error;
use std::error::Error as StdError;
use std::fmt::{self, Debug, Display, Formatter};
use crate::error::DatabaseError;
use crate::mssql::protocol::error::Error;
/// An error returned from the MSSQL database.
pub struct MsSqlDatabaseError {}
pub struct MsSqlDatabaseError(pub(crate) Error);
impl Debug for MsSqlDatabaseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
unimplemented!()
f.debug_struct("MsSqlDatabaseError")
.field("message", &self.0.message)
.field("number", &self.0.number)
.field("state", &self.0.state)
.field("class", &self.0.class)
.field("server", &self.0.server)
.field("procedure", &self.0.procedure)
.field("line", &self.0.line)
.finish()
}
}
impl Display for MsSqlDatabaseError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
unimplemented!()
f.pad(self.message())
}
}
impl Error for MsSqlDatabaseError {}
impl StdError for MsSqlDatabaseError {}
impl DatabaseError for MsSqlDatabaseError {
#[inline]
fn message(&self) -> &str {
unimplemented!()
&self.0.message
}
#[doc(hidden)]
fn as_error(&self) -> &(dyn Error + Send + Sync + 'static) {
fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static) {
self
}
#[doc(hidden)]
fn as_error_mut(&mut self) -> &mut (dyn Error + Send + Sync + 'static) {
fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static) {
self
}
#[doc(hidden)]
fn into_error(self: Box<Self>) -> Box<dyn Error + Send + Sync + 'static> {
fn into_error(self: Box<Self>) -> Box<dyn StdError + Send + Sync + 'static> {
self
}
}

View File

@ -1,4 +1,5 @@
pub trait MsSqlBufMutExt {
fn put_b_varchar(&mut self, s: &str);
fn put_utf16_str(&mut self, s: &str);
}
@ -9,4 +10,9 @@ impl MsSqlBufMutExt for Vec<u8> {
self.extend_from_slice(&ch.to_le_bytes());
}
}
fn put_b_varchar(&mut self, s: &str) {
self.extend(&(s.len() as u8).to_le_bytes());
self.put_utf16_str(s);
}
}

View File

@ -5,6 +5,7 @@ use crate::io::Decode;
use crate::mssql::io::MsSqlBufExt;
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) enum EnvChange {
Database(String),
Language(String),

View File

@ -0,0 +1,54 @@
use crate::mssql::io::MsSqlBufExt;
use bytes::{Buf, Bytes};
#[derive(Debug)]
pub(crate) struct Error {
// The error number
pub(crate) number: i32,
// The error state, used as a modifier to the error number.
pub(crate) state: u8,
// The class (severity) of the error. A class of less than 10 indicates
// an informational message.
pub(crate) class: u8,
// The message text length and message text using US_VARCHAR format.
pub(crate) message: String,
// The server name length and server name using B_VARCHAR format
pub(crate) server: String,
// The stored procedure name length and the stored procedure name using B_VARCHAR format
pub(crate) procedure: String,
// The line number in the SQL batch or stored procedure that caused the error. Line numbers
// begin at 1. If the line number is not applicable to the message, the
// value of LineNumber is 0.
pub(crate) line: i32,
}
impl Error {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, crate::error::Error> {
let len = buf.get_u16_le();
let mut data = buf.split_to(len as usize);
let number = data.get_i32_le();
let state = data.get_u8();
let class = data.get_u8();
let message = data.get_us_varchar()?;
let server = data.get_b_varchar()?;
let procedure = data.get_b_varchar()?;
let line = data.get_i32_le();
Ok(Self {
number,
state,
class,
message,
server,
procedure,
line,
})
}
}

View File

@ -0,0 +1,45 @@
use crate::io::Encode;
pub(crate) struct AllHeaders<'a>(pub(crate) &'a [Header]);
impl Encode<'_> for AllHeaders<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
let offset = buf.len();
buf.resize(buf.len() + 4, 0);
for header in self.0 {
header.encode_with(buf, ());
}
let len = buf.len() - offset;
buf[offset..(offset + 4)].copy_from_slice(&(len as u32).to_le_bytes());
}
}
pub(crate) enum Header {
TransactionDescriptor {
// number of requests currently active on the connection
outstanding_request_count: u32,
// for each connection, a number that uniquely identifies the transaction with which the
// request is associated; initially generated by the server when a new transaction is
// created and returned to the client as part of the ENVCHANGE token stream
transaction_descriptor: u64,
},
}
impl Encode<'_> for Header {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
match self {
Header::TransactionDescriptor {
outstanding_request_count,
transaction_descriptor,
} => {
buf.extend(&18_u32.to_le_bytes()); // [HeaderLength] 4 + 2 + 8 + 4
buf.extend(&2_u16.to_le_bytes()); // [HeaderType]
buf.extend(&transaction_descriptor.to_le_bytes());
buf.extend(&outstanding_request_count.to_le_bytes());
}
}
}
}

View File

@ -1,9 +1,9 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::mssql::protocol::col_meta_data::ColMetaData;
use crate::mssql::protocol::done::Done;
use crate::mssql::protocol::env_change::EnvChange;
use crate::mssql::protocol::error::Error;
use crate::mssql::protocol::info::Info;
use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::row::Row;
@ -15,7 +15,6 @@ pub(crate) enum Message {
EnvChange(EnvChange),
Done(Done),
Row(Row),
ColMetaData(ColMetaData),
}
#[derive(Debug)]
@ -25,13 +24,15 @@ pub(crate) enum MessageType {
EnvChange,
Done,
Row,
Error,
ColMetaData,
}
impl MessageType {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, crate::error::Error> {
Ok(match buf.get_u8() {
0x81 => MessageType::ColMetaData,
0xaa => MessageType::Error,
0xab => MessageType::Info,
0xad => MessageType::LoginAck,
0xd1 => MessageType::Row,

View File

@ -1,6 +1,8 @@
pub(crate) mod col_meta_data;
pub(crate) mod done;
pub(crate) mod env_change;
pub(crate) mod error;
pub(crate) mod header;
pub(crate) mod info;
pub(crate) mod login;
pub(crate) mod login_ack;
@ -8,5 +10,6 @@ pub(crate) mod message;
pub(crate) mod packet;
pub(crate) mod pre_login;
pub(crate) mod row;
pub(crate) mod rpc;
pub(crate) mod sql_batch;
pub(crate) mod type_info;

View File

@ -21,7 +21,7 @@ pub(crate) struct PreLogin<'a> {
}
impl<'de> Decode<'de> for PreLogin<'de> {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
let mut version = None;
let mut encryption = None;

View File

@ -25,7 +25,7 @@ impl Row {
if column.type_info.is_null() {
values.push(None);
} else {
values.push(Some(buf.split_to(column.type_info.size())));
values.push(Some(column.type_info.get_value(buf)));
}
}

View File

@ -0,0 +1,91 @@
use bitflags::bitflags;
use either::Either;
use crate::io::Encode;
use crate::mssql::io::MsSqlBufMutExt;
use crate::mssql::protocol::header::{AllHeaders, Header};
use crate::mssql::MsSqlArguments;
pub(crate) struct RpcRequest<'a> {
// the procedure can be encoded as a u16 of a built-in or the name for a custom one
pub(crate) procedure: Either<&'a str, Procedure>,
pub(crate) options: OptionFlags,
pub(crate) arguments: &'a MsSqlArguments,
}
#[derive(Debug, Copy, Clone)]
#[repr(u16)]
#[allow(dead_code)]
pub(crate) enum Procedure {
Cursor = 1,
CursorOpen = 2,
CursorPrepare = 3,
CursorExecute = 4,
CursorPrepareExecute = 5,
CursorUnprepare = 6,
CursorFetch = 7,
CursorOption = 8,
CursorClose = 9,
ExecuteSql = 10,
Prepare = 11,
Execute = 12,
PrepareExecute = 13,
PrepareExecuteRpc = 14,
Unprepare = 15,
}
bitflags! {
pub(crate) struct OptionFlags: u16 {
const WITH_RECOMPILE = 1;
// The server sends NoMetaData only if fNoMetadata is set to 1 in the request
const NO_META_DATA = 2;
// 1 if the metadata has not changed from the previous call and the server SHOULD reuse
// its cached metadata (the metadata MUST still be sent).
const REUSE_META_DATA = 4;
}
}
bitflags! {
pub(crate) struct StatusFlags: u8 {
// if the parameter is passed by reference (OUTPUT parameter) or
// 0 if parameter is passed by value
const BY_REF_VALUE = 1;
// 1 if the parameter being passed is to be the default value
const DEFAULT_VALUE = 2;
// 1 if the parameter that is being passed is encrypted. This flag is valid
// only when the column encryption feature is negotiated by client and server
// and is turned on
const ENCRYPTED = 8;
}
}
impl Encode<'_> for RpcRequest<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
AllHeaders(&[Header::TransactionDescriptor {
outstanding_request_count: 1,
transaction_descriptor: 0,
}])
.encode(buf);
match &self.procedure {
Either::Left(name) => {
buf.extend(&(name.len() as u16).to_le_bytes());
buf.put_utf16_str(name);
}
Either::Right(id) => {
buf.extend(&(0xffff_u16).to_le_bytes());
buf.extend(&(*id as u16).to_le_bytes());
}
}
buf.extend(&self.options.bits.to_le_bytes());
buf.extend(&self.arguments.data);
}
}
// TODO: Test serialization of this?

View File

@ -1,7 +1,6 @@
use crate::io::Encode;
use crate::mssql::io::MsSqlBufMutExt;
const HEADER_TRANSACTION_DESCRIPTOR: u16 = 0x00_02;
use crate::mssql::protocol::header::{AllHeaders, Header};
#[derive(Debug)]
pub(crate) struct SqlBatch<'a> {
@ -10,24 +9,11 @@ pub(crate) struct SqlBatch<'a> {
impl Encode<'_> for SqlBatch<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
// ALL_HEADERS -> TotalLength
buf.extend(&(4_u32 + 18).to_le_bytes()); // 4 + 18
// [Header] Transaction Descriptor
// SQL_BATCH messages require this header
// contains information regarding number of outstanding requests for MARS
buf.extend(&18_u32.to_le_bytes()); // 4 + 2 + 8 + 4
buf.extend(&HEADER_TRANSACTION_DESCRIPTOR.to_le_bytes());
// [TransactionDescriptor] a number that uniquely identifies the current transaction
// TODO: use this once we support transactions, it will be given to us from the
// server ENVCHANGE event
buf.extend(&0_u64.to_le_bytes());
// [OutstandingRequestCount] Number of active requests to MSSQL from the
// same connection
// NOTE: Long-term when we support MARS we need to connect this value correctly
buf.extend(&(1_u32.to_le_bytes()));
AllHeaders(&[Header::TransactionDescriptor {
outstanding_request_count: 1,
transaction_descriptor: 0,
}])
.encode(buf);
// SQLText
buf.put_utf16_str(self.sql);

View File

@ -1,9 +1,36 @@
use crate::error::Error;
use std::borrow::Cow;
use bitflags::bitflags;
use bytes::{Buf, Bytes};
use crate::encode::Encode;
use crate::error::Error;
use crate::mssql::MsSql;
use url::quirks::set_search;
bitflags! {
pub(crate) struct CollationFlags: u8 {
const IGNORE_CASE = (1 << 0);
const IGNORE_ACCENT = (1 << 1);
const IGNORE_WIDTH = (1 << 2);
const IGNORE_KANA = (1 << 3);
const BINARY = (1 << 4);
const BINARY2 = (1 << 5);
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum DataType {
// Fixed-length data types
pub(crate) struct Collation {
pub(crate) locale: u32,
pub(crate) flags: CollationFlags,
pub(crate) sort: u8,
pub(crate) version: u8,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[repr(u8)]
pub(crate) enum DataType {
// fixed-length data types
// https://docs.microsoft.com/en-us/openspecs/sql_server_protocols/ms-sstds/d33ef17b-7e53-4380-ad11-2ba42c8dda8d
Null = 0x1f,
TinyInt = 0x30,
@ -17,39 +44,414 @@ pub enum DataType {
Float = 0x3e,
SmallMoney = 0x7a,
BigInt = 0x7f,
// variable-length data types
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/ce3183a6-9d89-47e8-a02f-de5a1a1303de
// byte length
Guid = 0x24,
IntN = 0x26,
Decimal = 0x37, // legacy
Numeric = 0x3f, // legacy
BitN = 0x68,
DecimalN = 0x6a,
NumericN = 0x6c,
FloatN = 0x6d,
MoneyN = 0x6e,
DateTimeN = 0x6f,
DateN = 0x28,
TimeN = 0x29,
DateTime2N = 0x2a,
DateTimeOffsetN = 0x2b,
Char = 0x2f, // legacy
VarChar = 0x27, // legacy
Binary = 0x2d, // legacy
VarBinary = 0x25, // legacy
// short length
BigVarBinary = 0xa5,
BigVarChar = 0xa7,
BigBinary = 0xad,
BigChar = 0xaf,
NVarChar = 0xe7,
NChar = 0xef,
Xml = 0xf1,
UserDefined = 0xf0,
// long length
Text = 0x23,
Image = 0x22,
NText = 0x63,
Variant = 0x62,
}
// http://msdn.microsoft.com/en-us/library/dd358284.aspx
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) struct TypeInfo {
pub(crate) ty: DataType,
pub(crate) size: u32,
pub(crate) scale: u8,
pub(crate) precision: u8,
pub(crate) collation: Option<Collation>,
}
impl TypeInfo {
pub(crate) const fn new(ty: DataType, size: u32) -> Self {
Self {
ty,
size,
scale: 0,
precision: 0,
collation: None,
}
}
// reads a TYPE_INFO from the buffer
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
let ty = DataType::get(buf)?;
Ok(Self { ty })
Ok(match ty {
DataType::Null => Self::new(ty, 0),
DataType::TinyInt | DataType::Bit => Self::new(ty, 1),
DataType::SmallInt => Self::new(ty, 2),
DataType::Int | DataType::SmallDateTime | DataType::Real | DataType::SmallMoney => {
Self::new(ty, 4)
}
DataType::BigInt | DataType::Money | DataType::DateTime | DataType::Float => {
Self::new(ty, 8)
}
DataType::DateN => Self::new(ty, 3),
DataType::TimeN | DataType::DateTime2N | DataType::DateTimeOffsetN => {
let scale = buf.get_u8();
let mut size = match scale {
0 | 1 | 2 => 3,
3 | 4 => 4,
5 | 6 | 7 => 5,
scale => {
return Err(err_protocol!("invalid scale {} for type {:?}", scale, ty));
}
};
match ty {
DataType::DateTime2N => {
size += 3;
}
DataType::DateTimeOffsetN => {
size += 5;
}
_ => {}
}
Self {
scale,
size,
ty,
precision: 0,
collation: None,
}
}
DataType::Guid
| DataType::IntN
| DataType::BitN
| DataType::FloatN
| DataType::MoneyN
| DataType::DateTimeN
| DataType::Char
| DataType::VarChar
| DataType::Binary
| DataType::VarBinary => Self::new(ty, buf.get_u8() as u32),
DataType::Decimal | DataType::Numeric | DataType::DecimalN | DataType::NumericN => {
let size = buf.get_u8() as u32;
let precision = buf.get_u8();
let scale = buf.get_u8();
Self {
size,
precision,
scale,
ty,
collation: None,
}
}
DataType::BigVarBinary | DataType::BigBinary => Self::new(ty, buf.get_u16_le() as u32),
DataType::BigVarChar | DataType::BigChar | DataType::NVarChar | DataType::NChar => {
let size = buf.get_u16_le() as u32;
let collation = Collation::get(buf);
Self {
ty,
size,
collation: Some(collation),
scale: 0,
precision: 0,
}
}
_ => {
return Err(err_protocol!("unsupported data type {:?}", ty));
}
})
}
// writes a TYPE_INFO to the buffer
pub(crate) fn put(&self, buf: &mut Vec<u8>) {
buf.push(self.ty as u8);
match self.ty {
DataType::Null
| DataType::TinyInt
| DataType::Bit
| DataType::SmallInt
| DataType::Int
| DataType::SmallDateTime
| DataType::Real
| DataType::SmallMoney
| DataType::BigInt
| DataType::Money
| DataType::DateTime
| DataType::Float
| DataType::DateN => {
// nothing to do
}
DataType::TimeN | DataType::DateTime2N | DataType::DateTimeOffsetN => {
buf.push(self.scale);
}
DataType::Guid
| DataType::IntN
| DataType::BitN
| DataType::FloatN
| DataType::MoneyN
| DataType::DateTimeN
| DataType::Char
| DataType::VarChar
| DataType::Binary
| DataType::VarBinary => {
buf.push(self.size as u8);
}
DataType::Decimal | DataType::Numeric | DataType::DecimalN | DataType::NumericN => {
buf.push(self.size as u8);
buf.push(self.precision);
buf.push(self.scale);
}
DataType::BigVarBinary | DataType::BigBinary => {
buf.extend(&(self.size as u16).to_le_bytes());
}
DataType::BigVarChar | DataType::BigChar | DataType::NVarChar | DataType::NChar => {
buf.extend(&(self.size as u16).to_le_bytes());
if let Some(collation) = &self.collation {
collation.put(buf);
} else {
buf.extend(&0_u32.to_le_bytes());
buf.push(0);
}
}
_ => {
unimplemented!("unsupported data type {:?}", self.ty);
}
}
}
pub(crate) fn is_null(&self) -> bool {
matches!(self.ty, DataType::Null)
}
pub(crate) fn size(&self) -> usize {
pub(crate) fn get_value(&self, buf: &mut Bytes) -> Bytes {
let size = match self.ty {
DataType::Null
| DataType::TinyInt
| DataType::Bit
| DataType::SmallInt
| DataType::Int
| DataType::SmallDateTime
| DataType::Real
| DataType::Money
| DataType::DateTime
| DataType::Float
| DataType::SmallMoney
| DataType::BigInt => self.size as usize,
DataType::Guid
| DataType::IntN
| DataType::Decimal
| DataType::Numeric
| DataType::BitN
| DataType::DecimalN
| DataType::NumericN
| DataType::FloatN
| DataType::MoneyN
| DataType::DateTimeN
| DataType::DateN
| DataType::TimeN
| DataType::DateTime2N
| DataType::DateTimeOffsetN
| DataType::Char
| DataType::VarChar
| DataType::Binary
| DataType::VarBinary => buf.get_u8() as usize,
DataType::BigVarBinary
| DataType::BigVarChar
| DataType::BigBinary
| DataType::BigChar
| DataType::NVarChar
| DataType::NChar
| DataType::Xml
| DataType::UserDefined => buf.get_u16_le() as usize,
DataType::Text | DataType::Image | DataType::NText | DataType::Variant => {
buf.get_u32_le() as usize
}
};
buf.split_to(size)
}
pub(crate) fn put_value<'q, T: Encode<'q, MsSql>>(&self, buf: &mut Vec<u8>, value: T) {
match self.ty {
DataType::Null => 0,
DataType::TinyInt => 1,
DataType::Bit => 1,
DataType::SmallInt => 2,
DataType::Int => 4,
DataType::SmallDateTime => 4,
DataType::Real => 4,
DataType::Money => 4,
DataType::DateTime => 8,
DataType::Float => 8,
DataType::SmallMoney => 4,
DataType::BigInt => 8,
DataType::Null
| DataType::TinyInt
| DataType::Bit
| DataType::SmallInt
| DataType::Int
| DataType::SmallDateTime
| DataType::Real
| DataType::Money
| DataType::DateTime
| DataType::Float
| DataType::SmallMoney
| DataType::BigInt => {
self.put_fixed_value(buf, value);
}
DataType::Guid
| DataType::IntN
| DataType::Decimal
| DataType::Numeric
| DataType::BitN
| DataType::DecimalN
| DataType::NumericN
| DataType::FloatN
| DataType::MoneyN
| DataType::DateTimeN
| DataType::DateN
| DataType::TimeN
| DataType::DateTime2N
| DataType::DateTimeOffsetN
| DataType::Char
| DataType::VarChar
| DataType::Binary
| DataType::VarBinary => {
self.put_byte_len_value(buf, value);
}
DataType::BigVarBinary
| DataType::BigVarChar
| DataType::BigBinary
| DataType::BigChar
| DataType::NVarChar
| DataType::NChar
| DataType::Xml
| DataType::UserDefined => {
self.put_short_len_value(buf, value);
}
DataType::Text | DataType::Image | DataType::NText | DataType::Variant => {
self.put_long_len_value(buf, value);
}
}
}
pub(crate) fn put_fixed_value<'q, T: Encode<'q, MsSql>>(&self, buf: &mut Vec<u8>, value: T) {
let _ = value.encode(buf);
}
pub(crate) fn put_byte_len_value<'q, T: Encode<'q, MsSql>>(&self, buf: &mut Vec<u8>, value: T) {
let offset = buf.len();
buf.push(0);
let _ = value.encode(buf);
buf[offset] = (buf.len() - offset - 1) as u8;
}
pub(crate) fn put_short_len_value<'q, T: Encode<'q, MsSql>>(
&self,
buf: &mut Vec<u8>,
value: T,
) {
let offset = buf.len();
buf.extend(&0_u16.to_le_bytes());
let _ = value.encode(buf);
let size = (buf.len() - offset - 2) as u16;
buf[offset..(offset + 2)].copy_from_slice(&size.to_le_bytes());
}
pub(crate) fn put_long_len_value<'q, T: Encode<'q, MsSql>>(&self, buf: &mut Vec<u8>, value: T) {
let offset = buf.len();
buf.extend(&0_u32.to_le_bytes());
let _ = value.encode(buf);
let size = (buf.len() - offset - 4) as u32;
buf[offset..(offset + 4)].copy_from_slice(&size.to_le_bytes());
}
pub(crate) fn fmt(&self, s: &mut String) {
match self.ty {
DataType::Null => s.push_str("nvarchar(1)"),
DataType::TinyInt => s.push_str("tinyint"),
DataType::SmallInt => s.push_str("smallint"),
DataType::Int => s.push_str("int"),
DataType::BigInt => s.push_str("bigint"),
DataType::Real => s.push_str("real"),
DataType::Float => s.push_str("float"),
DataType::IntN => s.push_str(match self.size {
1 => "tinyint",
2 => "smallint",
4 => "int",
8 => "bigint",
_ => unreachable!("invalid size {} for int"),
}),
DataType::FloatN => s.push_str(match self.size {
4 => "real",
8 => "float",
_ => unreachable!("invalid size {} for float"),
}),
DataType::NVarChar => {
s.push_str("nvarchar(");
let _ = itoa::fmt(&mut *s, self.size / 2);
s.push_str(")");
}
_ => unimplemented!("unsupported data type {:?}", self.ty),
}
}
}
@ -69,6 +471,36 @@ impl DataType {
0x3e => DataType::Float,
0x7a => DataType::SmallMoney,
0x7f => DataType::BigInt,
0x24 => DataType::Guid,
0x26 => DataType::IntN,
0x37 => DataType::Decimal,
0x3f => DataType::Numeric,
0x68 => DataType::BitN,
0x6a => DataType::DecimalN,
0x6c => DataType::NumericN,
0x6d => DataType::FloatN,
0x6e => DataType::MoneyN,
0x6f => DataType::DateTimeN,
0x28 => DataType::DateN,
0x29 => DataType::TimeN,
0x2a => DataType::DateTime2N,
0x2b => DataType::DateTimeOffsetN,
0x2f => DataType::Char,
0x27 => DataType::VarChar,
0x2d => DataType::Binary,
0x25 => DataType::VarBinary,
0xa5 => DataType::BigVarBinary,
0xa7 => DataType::BigVarChar,
0xad => DataType::BigBinary,
0xaf => DataType::BigChar,
0xe7 => DataType::NVarChar,
0xef => DataType::NChar,
0xf1 => DataType::Xml,
0xf0 => DataType::UserDefined,
0x23 => DataType::Text,
0x22 => DataType::Image,
0x63 => DataType::NText,
0x62 => DataType::Variant,
ty => {
return Err(err_protocol!("unknown data type 0x{:02x}", ty));
@ -76,3 +508,28 @@ impl DataType {
})
}
}
impl Collation {
pub(crate) fn get(buf: &mut Bytes) -> Collation {
let locale_sort_version = buf.get_u32();
let locale = locale_sort_version & 0xF_FFFF;
let flags = CollationFlags::from_bits_truncate(((locale_sort_version >> 20) & 0xFF) as u8);
let version = (locale_sort_version >> 28) as u8;
let sort = buf.get_u8();
Collation {
locale,
flags,
sort,
version,
}
}
pub(crate) fn put(&self, buf: &mut Vec<u8>) {
let locale_sort_version =
self.locale | ((self.flags.bits() as u32) << 20) | ((self.version as u32) << 28);
buf.extend(&locale_sort_version.to_le_bytes());
buf.push(self.sort);
}
}

View File

@ -10,6 +10,9 @@ impl TypeInfo for MsSqlTypeInfo {}
impl Display for MsSqlTypeInfo {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
unimplemented!()
let mut buf = String::new();
self.0.fmt(&mut buf);
f.pad(&*buf)
}
}

View File

@ -1,7 +1,8 @@
use byteorder::{ByteOrder, LittleEndian};
use crate::database::{Database, HasValueRef};
use crate::database::{Database, HasArguments, HasValueRef};
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mssql::protocol::type_info::{DataType, TypeInfo};
use crate::mssql::{MsSql, MsSqlTypeInfo, MsSqlValueRef};
@ -9,11 +10,23 @@ use crate::types::Type;
impl Type<MsSql> for i32 {
fn type_info() -> MsSqlTypeInfo {
MsSqlTypeInfo(TypeInfo { ty: DataType::Int })
MsSqlTypeInfo(TypeInfo::new(DataType::IntN, 4))
}
}
impl Encode<'_, MsSql> for i32 {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend(&self.to_le_bytes());
IsNull::No
}
}
impl Decode<'_, MsSql> for i32 {
fn accepts(ty: &MsSqlTypeInfo) -> bool {
matches!(ty.0.ty, DataType::Int | DataType::IntN) && ty.0.size == 4
}
fn decode(value: MsSqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(LittleEndian::read_i32(value.as_bytes()?))
}

View File

@ -1 +1,2 @@
mod int;
mod str;

View File

@ -0,0 +1,28 @@
use byteorder::{ByteOrder, LittleEndian};
use crate::database::{Database, HasArguments, HasValueRef};
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mssql::io::MsSqlBufMutExt;
use crate::mssql::protocol::type_info::{DataType, TypeInfo};
use crate::mssql::{MsSql, MsSqlTypeInfo, MsSqlValueRef};
use crate::types::Type;
impl Type<MsSql> for str {
fn type_info() -> MsSqlTypeInfo {
MsSqlTypeInfo(TypeInfo::new(DataType::NVarChar, 0))
}
}
impl Encode<'_, MsSql> for &'_ str {
fn produces(&self) -> MsSqlTypeInfo {
MsSqlTypeInfo(TypeInfo::new(DataType::NVarChar, (self.len() * 2) as u32))
}
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.put_utf16_str(self);
IsNull::No
}
}

View File

@ -1 +0,0 @@

View File

@ -15,7 +15,7 @@ async fn it_connects() -> anyhow::Result<()> {
}
#[sqlx_macros::test]
async fn it_can_select_1() -> anyhow::Result<()> {
async fn it_can_select_expression() -> anyhow::Result<()> {
let mut conn = new::<MsSql>().await?;
let row: MsSqlRow = conn.fetch_one("SELECT 4").await?;
@ -25,3 +25,18 @@ async fn it_can_select_1() -> anyhow::Result<()> {
Ok(())
}
#[sqlx_macros::test]
async fn it_maths() -> anyhow::Result<()> {
let mut conn = new::<MsSql>().await?;
let value = sqlx::query("SELECT 1 + @p1")
.bind(5_i32)
.try_map(|row: MsSqlRow| row.try_get::<i32, _>(0))
.fetch_one(&mut conn)
.await?;
assert_eq!(6_i32, value);
Ok(())
}