Fixing panics for disabled statement cache

This commit is contained in:
Julius de Bruijn 2020-07-10 11:11:43 +02:00 committed by Ryan Leckey
parent a27244b3c9
commit 0c9bea4ab2
9 changed files with 93 additions and 32 deletions

1
Cargo.lock generated
View File

@ -2429,6 +2429,7 @@ dependencies = [
"time 0.2.16",
"tokio 0.2.21",
"trybuild",
"url 2.1.1",
]
[[package]]

View File

@ -92,6 +92,7 @@ sqlx-test = { path = "./sqlx-test" }
paste = "0.1.16"
serde = { version = "1.0.111", features = [ "derive" ] }
serde_json = "1.0.53"
url = "2.1.1"
#
# Any

View File

@ -63,4 +63,9 @@ impl<T> StatementCache<T> {
pub fn capacity(&self) -> usize {
self.inner.capacity()
}
/// Returns true if the cache capacity is more than 0.
pub fn is_enabled(&self) -> bool {
self.capacity() > 0
}
}

View File

@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{borrow::Cow, sync::Arc};
use bytes::Bytes;
use either::Either;
@ -26,9 +26,10 @@ use crate::mysql::{
use crate::statement::StatementInfo;
impl MySqlConnection {
async fn prepare(&mut self, query: &str) -> Result<&mut MySqlStatement, Error> {
async fn prepare<'a>(&'a mut self, query: &str) -> Result<Cow<'a, MySqlStatement>, Error> {
if self.cache_statement.contains_key(query) {
return Ok(self.cache_statement.get_mut(query).unwrap());
let stmt = self.cache_statement.get_mut(query).unwrap();
return Ok(Cow::Borrowed(&*stmt));
}
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html
@ -80,16 +81,21 @@ impl MySqlConnection {
nullable,
};
// in case of the cache being full, close the least recently used statement
if let Some(statement) = self.cache_statement.insert(query, statement) {
self.stream
.send_packet(StmtClose {
statement: statement.id,
})
.await?;
}
if self.cache_statement.is_enabled() {
// in case of the cache being full, close the least recently used statement
if let Some(statement) = self.cache_statement.insert(query, statement) {
self.stream
.send_packet(StmtClose {
statement: statement.id,
})
.await?;
}
Ok(self.cache_statement.get_mut(query).unwrap())
let stmt = self.cache_statement.get_mut(query).unwrap();
Ok(Cow::Borrowed(&*stmt))
} else {
Ok(Cow::Owned(statement))
}
}
async fn recv_result_metadata(&mut self, mut packet: Packet<Bytes>) -> Result<(), Error> {

View File

@ -1,6 +1,6 @@
use super::MySqlColumn;
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct MySqlStatement {
pub(crate) id: u32,
pub(crate) columns: Vec<MySqlColumn>,

View File

@ -3,7 +3,7 @@ use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_core::Stream;
use futures_util::{pin_mut, TryStreamExt};
use std::sync::Arc;
use std::{borrow::Cow, sync::Arc};
use crate::error::Error;
use crate::executor::{Execute, Executor};
@ -155,30 +155,37 @@ impl PgConnection {
self.pending_ready_for_query_count += 1;
}
async fn prepare(
&mut self,
async fn prepare<'a>(
&'a mut self,
query: &str,
arguments: &PgArguments,
) -> Result<&mut PgStatement, Error> {
) -> Result<Cow<'a, PgStatement>, Error> {
let contains = self.cache_statement.contains_key(query);
if contains {
return Ok(self.cache_statement.get_mut(query).unwrap());
return Ok(Cow::Borrowed(
&*self.cache_statement.get_mut(query).unwrap(),
));
}
let statement = prepare(self, query, arguments).await?;
if let Some(statement) = self.cache_statement.insert(query, statement) {
self.stream.write(Close::Statement(statement.id));
self.stream.write(Flush);
if self.cache_statement.is_enabled() {
if let Some(statement) = self.cache_statement.insert(query, statement) {
self.stream.write(Close::Statement(statement.id));
self.stream.write(Flush);
self.stream.flush().await?;
self.stream.flush().await?;
self.wait_for_close_complete(1).await?;
self.recv_ready_for_query().await?;
self.wait_for_close_complete(1).await?;
}
Ok(Cow::Borrowed(
&*self.cache_statement.get_mut(query).unwrap(),
))
} else {
Ok(Cow::Owned(statement))
}
Ok(self.cache_statement.get_mut(query).unwrap())
}
async fn run(

View File

@ -1,6 +1,6 @@
use super::{PgColumn, PgTypeInfo};
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct PgStatement {
pub(crate) id: u32,
pub(crate) columns: Vec<PgColumn>,

View File

@ -1,7 +1,8 @@
use futures::TryStreamExt;
use sqlx::mysql::{MySql, MySqlPool, MySqlPoolOptions, MySqlRow};
use sqlx::{Connection, Done, Executor, Row};
use sqlx_test::new;
use sqlx_test::{new, setup_if_needed};
use std::env;
#[sqlx_macros::test]
async fn it_connects() -> anyhow::Result<()> {
@ -97,6 +98,26 @@ async fn it_executes_with_pool() -> anyhow::Result<()> {
Ok(())
}
#[sqlx_macros::test]
async fn it_works_with_cache_disabled() -> anyhow::Result<()> {
setup_if_needed();
let mut url = url::Url::parse(&env::var("DATABASE_URL")?)?;
url.query_pairs_mut()
.append_pair("statement-cache-capacity", "0");
let mut conn = MySqlConnection::connect(url.as_ref()).await?;
for index in 1..=10_i32 {
let _ = sqlx::query("SELECT ?")
.bind(index)
.execute(&mut conn)
.await?;
}
Ok(())
}
#[sqlx_macros::test]
async fn it_drops_results_in_affected_rows() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

View File

@ -2,9 +2,9 @@ use futures::TryStreamExt;
use sqlx::postgres::{
PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgSeverity,
};
use sqlx::postgres::{PgPoolOptions, PgRow};
use sqlx::{postgres::Postgres, Connection, Done, Executor, Row};
use sqlx_test::new;
use sqlx::postgres::{PgPoolOptions, PgRow, Postgres};
use sqlx::{Connection, Done, Executor, PgPool, Row};
use sqlx_test::{new, setup_if_needed};
use std::env;
use std::thread;
use std::time::Duration;
@ -124,7 +124,7 @@ async fn it_describes_and_inserts_json() -> anyhow::Result<()> {
let _ = conn
.execute(
r#"
CREATE TEMPORARY TABLE json_stuff (obj json);
CREATE TEMPORARY TABLE json_stuff (obj jsonb);
"#,
)
.await?;
@ -142,6 +142,26 @@ CREATE TEMPORARY TABLE json_stuff (obj json);
Ok(())
}
#[sqlx_macros::test]
async fn it_works_with_cache_disabled() -> anyhow::Result<()> {
setup_if_needed();
let mut url = url::Url::parse(&env::var("DATABASE_URL")?)?;
url.query_pairs_mut()
.append_pair("statement-cache-capacity", "0");
let mut conn = PgConnection::connect(url.as_ref()).await?;
for index in 1..=10_i32 {
let _ = sqlx::query("SELECT $1")
.bind(index)
.execute(&mut conn)
.await?;
}
Ok(())
}
#[sqlx_macros::test]
async fn it_executes_with_pool() -> anyhow::Result<()> {
let pool = sqlx_test::pool::<Postgres>().await?;