From e26f46b5d24c61708b9f5fd70d67c6a8a0c7621c Mon Sep 17 00:00:00 2001 From: itsscb Date: Thu, 28 Sep 2023 00:17:03 +0200 Subject: [PATCH] Add login logic (token) (#42) --- api/account.go | 72 ++++- api/account_test.go | 374 ++++++++++++++--------- api/main_test.go | 6 +- api/middleware.go | 54 ++++ api/middleware_test.go | 114 +++++++ api/server.go | 35 ++- api/session.go | 93 ++++++ api/token.go | 83 +++++ app.env | 5 +- db/migration/000001_init_schema.down.sql | 1 + db/migration/000001_init_schema.up.sql | 21 +- db/mock/store.go | 46 +++ db/query/account.sql | 4 + db/query/payment.sql | 8 +- db/query/session.sql | 16 + db/sqlc/account.sql.go | 67 ++-- db/sqlc/models.go | 21 +- db/sqlc/payment.sql.go | 40 +-- db/sqlc/payment_test.go | 12 +- db/sqlc/querier.go | 5 + db/sqlc/session.sql.go | 82 +++++ db/sqlc/tx_create_account.go | 10 +- db/sqlc/tx_update_account.go | 12 +- go.mod | 9 +- go.sum | 16 + main.go | 5 +- token/maker.go | 14 + token/paseto_maker.go | 57 ++++ token/paseto_maker_test.go | 49 +++ token/payload.go | 46 +++ util/config.go | 19 +- util/password.go | 21 ++ util/password_test.go | 28 ++ 33 files changed, 1206 insertions(+), 239 deletions(-) create mode 100644 api/middleware.go create mode 100644 api/middleware_test.go create mode 100644 api/session.go create mode 100644 api/token.go create mode 100644 db/query/session.sql create mode 100644 db/sqlc/session.sql.go create mode 100644 token/maker.go create mode 100644 token/paseto_maker.go create mode 100644 token/paseto_maker_test.go create mode 100644 token/payload.go create mode 100644 util/password.go create mode 100644 util/password_test.go diff --git a/api/account.go b/api/account.go index 817bd98..eac381f 100644 --- a/api/account.go +++ b/api/account.go @@ -2,11 +2,15 @@ package api import ( "database/sql" + "errors" + "fmt" "net/http" "time" "github.com/gin-gonic/gin" db "github.com/itsscb/df/db/sqlc" + "github.com/itsscb/df/token" + "golang.org/x/exp/slog" ) type createAccountRequest struct { @@ -51,13 +55,6 @@ func (server *Server) createAccount(ctx *gin.Context) { }, } - // if req.PrivacyAccepted { - // arg.PrivacyAcceptedDate = sql.NullTime{ - // Valid: true, - // Time: time.Now(), - // } - // } - account, err := server.store.CreateAccountTx(ctx, arg) if err != nil { ctx.JSON(http.StatusInternalServerError, errorResponse(err)) @@ -90,12 +87,19 @@ func (server *Server) getAccount(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + if account.Email != authPayload.Email { + err := errors.New("account doesn't belong to the authenticated user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + ctx.JSON(http.StatusOK, account) } type listAccountRequest struct { - PageID int32 `form:"pageid" binding:"required,min=1"` - PageSize int32 `form:"pagesize" binding:"required,min=5,max=50"` + PageID int32 `form:"page_id" binding:"required,min=1"` + PageSize int32 `form:"page_size" binding:"required,min=5,max=50"` } func (server *Server) listAccounts(ctx *gin.Context) { @@ -106,6 +110,26 @@ func (server *Server) listAccounts(ctx *gin.Context) { return } + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + slog.Error("auth", "payload", fmt.Sprintf("%#v", authPayload)) + + account, err := server.store.GetAccountByEmail(ctx, authPayload.Email) + if err != nil { + if err == sql.ErrNoRows { + ctx.JSON(http.StatusNotFound, errorResponse(err)) + return + } + + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + if account.PermissionLevel < 1 { + err := errors.New("only for admin users") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + arg := db.ListAccountsParams{ Limit: req.PageSize, Offset: (req.PageID - 1) * req.PageSize, @@ -133,7 +157,20 @@ func (server *Server) updateAccountPrivacy(ctx *gin.Context) { return } - account, err := server.store.UpdateAccountPrivacyTx(ctx, db.UpdateAccountPrivacyTxParams(req)) + account, err := server.store.GetAccount(ctx, req.ID) + if err != nil { + ctx.JSON(http.StatusNotFound, errorResponse(err)) + return + } + + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + if account.Email != authPayload.Email { + err := errors.New("account doesn't belong to the authenticated user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + account, err = server.store.UpdateAccountPrivacyTx(ctx, db.UpdateAccountPrivacyTxParams(req)) if err != nil { ctx.JSON(http.StatusInternalServerError, errorResponse(err)) return @@ -164,6 +201,19 @@ func (server *Server) updateAccount(ctx *gin.Context) { return } + account, err := server.store.GetAccount(ctx, req.ID) + if err != nil { + ctx.JSON(http.StatusNotFound, errorResponse(err)) + return + } + + authPayload := ctx.MustGet(authorizationPayloadKey).(*token.Payload) + if account.Email != authPayload.Email { + err := errors.New("account doesn't belong to the authenticated user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + arg := db.UpdateAccountTxParams{ ID: req.ID, Changer: req.Changer, @@ -209,7 +259,7 @@ func (server *Server) updateAccount(ctx *gin.Context) { }, } - account, err := server.store.UpdateAccountTx(ctx, arg) + account, err = server.store.UpdateAccountTx(ctx, arg) if err != nil { ctx.JSON(http.StatusInternalServerError, errorResponse(err)) return diff --git a/api/account_test.go b/api/account_test.go index 5f82317..295c57a 100644 --- a/api/account_test.go +++ b/api/account_test.go @@ -14,6 +14,7 @@ import ( "github.com/gin-gonic/gin" mockdb "github.com/itsscb/df/db/mock" db "github.com/itsscb/df/db/sqlc" + "github.com/itsscb/df/token" "github.com/itsscb/df/util" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -27,6 +28,7 @@ func TestCreateAccountAPI(t *testing.T) { testCases := []struct { name string body gin.H + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) checkResponse func(recoder *httptest.ResponseRecorder) }{ @@ -46,6 +48,9 @@ func TestCreateAccountAPI(t *testing.T) { "phone": account.Phone.String, "creator": account.Creator, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { arg := db.CreateAccountTxParams{ Passwordhash: account.Passwordhash, @@ -75,11 +80,24 @@ func TestCreateAccountAPI(t *testing.T) { // { // name: "NoAuthorization", // body: gin.H{ - // "currency": account.Currency, + // "passwordhash": account.Passwordhash, + // "privacy_accepted": account.PrivacyAccepted.Bool, + // "firstname": account.Firstname, + // "lastname": account.Lastname, + // "birthday": account.Birthday, + // "email": account.Email, + // "city": account.City, + // "zip": account.Zip, + // "street": account.Street, + // "country": account.Country, + // "phone": account.Phone.String, + // "creator": account.Creator, + // }, + // setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { // }, // buildStubs: func(store *mockdb.MockStore) { // store.EXPECT(). - // CreateAccount(gomock.Any(), gomock.Any()). + // CreateAccountTx(gomock.Any(), gomock.Any()). // Times(0) // }, // checkResponse: func(recorder *httptest.ResponseRecorder) { @@ -91,6 +109,9 @@ func TestCreateAccountAPI(t *testing.T) { body: gin.H{ "email": account.Email, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). CreateAccountTx(gomock.Any(), gomock.Any()). @@ -117,6 +138,9 @@ func TestCreateAccountAPI(t *testing.T) { "phone": account.Phone.String, "creator": account.Creator, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). CreateAccountTx(gomock.Any(), gomock.Any()). @@ -139,7 +163,9 @@ func TestCreateAccountAPI(t *testing.T) { store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) - server := NewServer(config, store) + server, err := NewServer(config, store) + require.NoError(t, err) + recorder := httptest.NewRecorder() // Marshal body data to JSON @@ -150,6 +176,7 @@ func TestCreateAccountAPI(t *testing.T) { request, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data)) require.NoError(t, err) + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(recorder) }) @@ -162,12 +189,16 @@ func TestGetAccountAPI(t *testing.T) { testCases := []struct { name string accountID int64 + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) checkResponse func(t *testing.T, recoder *httptest.ResponseRecorder) }{ { name: "OK", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -179,37 +210,42 @@ func TestGetAccountAPI(t *testing.T) { requireBodyMatchAccount(t, recorder.Body, account) }, }, - // { - // name: "UnauthorizedUser", - // accountID: account.ID, - // setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { - // addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "unauthorized_user", time.Minute) - // }, - // buildStubs: func(store *mockdb.MockStore) { - // store.EXPECT(). - // GetAccount(gomock.Any(), gomock.Eq(account.ID)). - // Times(1). - // Return(account, nil) - // }, - // checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { - // require.Equal(t, http.StatusUnauthorized, recorder.Code) - // }, - // }, - // { - // name: "NoAuthorization", - // accountID: account.ID, - // buildStubs: func(store *mockdb.MockStore) { - // store.EXPECT(). - // GetAccount(gomock.Any(), gomock.Any()). - // Times(0) - // }, - // checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { - // require.Equal(t, http.StatusUnauthorized, recorder.Code) - // }, - // }, + { + name: "UnauthorizedUser", + accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "UnauthorizedUser", time.Minute) + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetAccount(gomock.Any(), gomock.Eq(account.ID)). + Times(1). + Return(account, nil) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "NoAuthorization", + accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetAccount(gomock.Any(), gomock.Any()). + Times(0) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, { name: "NotFound", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -223,6 +259,9 @@ func TestGetAccountAPI(t *testing.T) { { name: "InternalError", accountID: account.ID, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Eq(account.ID)). @@ -236,6 +275,9 @@ func TestGetAccountAPI(t *testing.T) { { name: "InvalidID", accountID: 0, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). GetAccount(gomock.Any(), gomock.Any()). @@ -257,13 +299,15 @@ func TestGetAccountAPI(t *testing.T) { store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) - server := NewServer(config, store) + server, err := NewServer(config, store) + require.NoError(t, err) recorder := httptest.NewRecorder() url := fmt.Sprintf("/accounts/%d", tc.accountID) request, err := http.NewRequest(http.MethodGet, url, nil) require.NoError(t, err) + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(t, recorder) }) @@ -273,67 +317,81 @@ func TestGetAccountAPI(t *testing.T) { func TestUpdateAccountTxAPI(t *testing.T) { account := randomAccount() changer := util.RandomName() - newPassword := util.RandomString(30) - newEmail := util.RandomEmail() + // newPassword := util.RandomString(30) + newLastname := util.RandomName() testCases := []struct { name string body gin.H accountID string + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) checkResponse func(recoder *httptest.ResponseRecorder) }{ + // { + // name: "OK_PasswordHash", + // body: gin.H{ + // "id": account.ID, + // "passwordhash": newPassword, + // "changer": changer, + // }, + // setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + // addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + // }, + // buildStubs: func(store *mockdb.MockStore) { + // var err error + // accountTemp := account + // accountTemp.Passwordhash, err = util.HashPassword(newPassword) + // require.NoError(t, err) + // accountTemp.Changer = changer + // arg := db.UpdateAccountTxParams{ + // ID: account.ID, + // Passwordhash: sql.NullString{ + // Valid: true, + // String: newPassword, + // }, + // Changer: changer, + // } + + // store.EXPECT(). + // UpdateAccountTx(gomock.Any(), gomock.Eq(arg)). + // Times(1). + // Return(accountTemp, nil) + // }, + // checkResponse: func(recorder *httptest.ResponseRecorder) { + // require.Equal(t, http.StatusOK, recorder.Code) + + // accountTemp := account + // accountTemp.Passwordhash = newPassword + // accountTemp.Changer = changer + + // requireBodyMatchAccount(t, recorder.Body, accountTemp) + // }, + // }, { - name: "OK_PasswordHash", + name: "OK_Lastname", body: gin.H{ - "id": account.ID, - "passwordhash": newPassword, - "changer": changer, + "id": account.ID, + "lastname": newLastname, + "changer": changer, + }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) }, buildStubs: func(store *mockdb.MockStore) { - accountTemp := account - accountTemp.Passwordhash = newPassword - accountTemp.Changer = changer arg := db.UpdateAccountTxParams{ ID: account.ID, - Passwordhash: sql.NullString{ + Lastname: sql.NullString{ Valid: true, - String: newPassword, + String: newLastname, }, Changer: changer, } store.EXPECT(). - UpdateAccountTx(gomock.Any(), gomock.Eq(arg)). + GetAccount(gomock.Any(), gomock.Eq(account.ID)). Times(1). - Return(accountTemp, nil) - }, - checkResponse: func(recorder *httptest.ResponseRecorder) { - require.Equal(t, http.StatusOK, recorder.Code) - - accountTemp := account - accountTemp.Passwordhash = newPassword - accountTemp.Changer = changer - - requireBodyMatchAccount(t, recorder.Body, accountTemp) - }, - }, - { - name: "OK_Email", - body: gin.H{ - "id": account.ID, - "email": newEmail, - "changer": changer, - }, - buildStubs: func(store *mockdb.MockStore) { - arg := db.UpdateAccountTxParams{ - ID: account.ID, - Email: sql.NullString{ - Valid: true, - String: newEmail, - }, - Changer: changer, - } + Return(account, nil) store.EXPECT(). UpdateAccountTx(gomock.Any(), gomock.Eq(arg)). @@ -345,58 +403,32 @@ func TestUpdateAccountTxAPI(t *testing.T) { requireBodyMatchAccount(t, recorder.Body, account) }, }, - // { - // name: "OK_PrivacyAccepted", - // body: gin.H{ - // "id": account.ID, - // "privacy_accepted": true, - // "changer": changer, - // }, - // buildStubs: func(store *mockdb.MockStore) { - // accountAccepted := account - // accountAccepted.PrivacyAccepted = sql.NullBool{ - // Valid: true, - // Bool: true, - // } - // accountAccepted.PrivacyAcceptedDate = sql.NullTime{ - // Valid: true, - // Time: timestamp, - // } - - // arg := db.UpdateAccountTxParams{ - // ID: account.ID, - // PrivacyAccepted: sql.NullBool{ - // Valid: true, - // Bool: true, - // }, - // Changer: changer, - // } - - // store.EXPECT(). - // UpdateAccountTx(gomock.Any(), gomock.Eq(arg)). - // Times(1). - // Return(accountAccepted, nil) - // }, - // }, - // { - // name: "NoAuthorization", - // body: gin.H{ - // "currency": account.Currency, - // }, - // buildStubs: func(store *mockdb.MockStore) { - // store.EXPECT(). - // CreateAccount(gomock.Any(), gomock.Any()). - // Times(0) - // }, - // checkResponse: func(recorder *httptest.ResponseRecorder) { - // require.Equal(t, http.StatusUnauthorized, recorder.Code) - // }, - // }, + { + name: "NoAuthorization", + body: gin.H{ + "id": account.ID, + "lastname": newLastname, + "changer": changer, + }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + CreateAccount(gomock.Any(), gomock.Any()). + Times(0) + }, + checkResponse: func(recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, { name: "BadRequest", body: gin.H{ "email": account.Email, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). CreateAccount(gomock.Any(), gomock.Any()). @@ -419,7 +451,9 @@ func TestUpdateAccountTxAPI(t *testing.T) { store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) - server := NewServer(config, store) + server, err := NewServer(config, store) + require.NoError(t, err) + recorder := httptest.NewRecorder() // Marshal body data to JSON @@ -430,6 +464,7 @@ func TestUpdateAccountTxAPI(t *testing.T) { request, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(data)) require.NoError(t, err) + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(recorder) }) @@ -443,6 +478,7 @@ func TestListAccountsAPI(t *testing.T) { for i := 0; i < n; i++ { accounts[i] = randomAccount() } + account := accounts[1] type Query struct { pageID int @@ -452,6 +488,7 @@ func TestListAccountsAPI(t *testing.T) { testCases := []struct { name string query Query + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) checkResponse func(recoder *httptest.ResponseRecorder) }{ @@ -461,12 +498,23 @@ func TestListAccountsAPI(t *testing.T) { pageID: 1, pageSize: n, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { arg := db.ListAccountsParams{ Limit: int32(n), Offset: 0, } + accountAdmin := account + accountAdmin.PermissionLevel = 1 + + store.EXPECT(). + GetAccountByEmail(gomock.Any(), gomock.Eq(account.Email)). + Times(1). + Return(accountAdmin, nil) + store.EXPECT(). ListAccounts(gomock.Any(), gomock.Eq(arg)). Times(1). @@ -477,24 +525,29 @@ func TestListAccountsAPI(t *testing.T) { requireBodyMatchAccounts(t, recorder.Body, accounts) }, }, - // { - // name: "NoAuthorization", - // query: Query{ - // pageID: 1, - // pageSize: n, - // }, - // buildStubs: func(store *mockdb.MockStore) { - // store.EXPECT(). - // ListAccounts(gomock.Any(), gomock.Any()). - // Times(0) - // }, - // checkResponse: func(recorder *httptest.ResponseRecorder) { - // require.Equal(t, http.StatusUnauthorized, recorder.Code) - // }, - // }, + { + name: "NoAuthorization", + query: Query{ + pageID: 1, + pageSize: n, + }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + ListAccounts(gomock.Any(), gomock.Any()). + Times(0) + }, + checkResponse: func(recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, { name: "EmptyQuery", query: Query{}, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). ListAccounts(gomock.Any(), gomock.Any()). @@ -510,6 +563,9 @@ func TestListAccountsAPI(t *testing.T) { pageID: -1, pageSize: n, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). ListAccounts(gomock.Any(), gomock.Any()). @@ -525,6 +581,9 @@ func TestListAccountsAPI(t *testing.T) { pageID: 1, pageSize: 100000, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { store.EXPECT(). ListAccounts(gomock.Any(), gomock.Any()). @@ -546,7 +605,9 @@ func TestListAccountsAPI(t *testing.T) { store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) - server := NewServer(config, store) + server, err := NewServer(config, store) + require.NoError(t, err) + recorder := httptest.NewRecorder() url := "/accounts" @@ -555,10 +616,11 @@ func TestListAccountsAPI(t *testing.T) { // Add query parameters to request URL q := request.URL.Query() - q.Add("pageid", fmt.Sprintf("%d", tc.query.pageID)) - q.Add("pagesize", fmt.Sprintf("%d", tc.query.pageSize)) + q.Add("page_id", fmt.Sprintf("%d", tc.query.pageID)) + q.Add("page_size", fmt.Sprintf("%d", tc.query.pageSize)) request.URL.RawQuery = q.Encode() + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(recorder) }) @@ -572,6 +634,7 @@ func TestUpdateAccountPrivacyTxAPI(t *testing.T) { testCases := []struct { name string body gin.H + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) buildStubs func(store *mockdb.MockStore) checkResponse func(recoder *httptest.ResponseRecorder) }{ @@ -582,6 +645,9 @@ func TestUpdateAccountPrivacyTxAPI(t *testing.T) { "changer": changer, "privacy_accepted": true, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { arg := db.UpdateAccountPrivacyTxParams{ ID: account.ID, @@ -594,6 +660,11 @@ func TestUpdateAccountPrivacyTxAPI(t *testing.T) { account2.PrivacyAccepted.Bool = true account2.Changer = changer + store.EXPECT(). + GetAccount(gomock.Any(), gomock.Eq(account.ID)). + Times(1). + Return(account, nil) + store.EXPECT(). UpdateAccountPrivacyTx(gomock.Any(), gomock.Eq(arg)). Times(1). @@ -621,6 +692,9 @@ func TestUpdateAccountPrivacyTxAPI(t *testing.T) { "changer": changer, "privacy_accepted": false, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { arg := db.UpdateAccountPrivacyTxParams{ ID: account.ID, @@ -635,6 +709,11 @@ func TestUpdateAccountPrivacyTxAPI(t *testing.T) { account2.PrivacyAcceptedDate.Time = time.Time{} account2.Changer = changer + store.EXPECT(). + GetAccount(gomock.Any(), gomock.Eq(account.ID)). + Times(1). + Return(account, nil) + store.EXPECT(). UpdateAccountPrivacyTx(gomock.Any(), gomock.Eq(arg)). Times(1). @@ -656,11 +735,18 @@ func TestUpdateAccountPrivacyTxAPI(t *testing.T) { }, }, { - name: "OK", + name: "InvalidRequest", body: gin.H{ "id": account.ID, }, + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, account.Email, time.Minute) + }, buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetAccount(gomock.Any(), gomock.Any()). + Times(0) + store.EXPECT(). UpdateAccountPrivacyTx(gomock.Any(), gomock.Any()). Times(0) @@ -681,7 +767,9 @@ func TestUpdateAccountPrivacyTxAPI(t *testing.T) { store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) - server := NewServer(config, store) + server, err := NewServer(config, store) + require.NoError(t, err) + recorder := httptest.NewRecorder() // Marshal body data to JSON @@ -692,6 +780,7 @@ func TestUpdateAccountPrivacyTxAPI(t *testing.T) { request, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(data)) require.NoError(t, err) + tc.setupAuth(t, request, server.tokenMaker) server.router.ServeHTTP(recorder, request) tc.checkResponse(recorder) }) @@ -699,9 +788,12 @@ func TestUpdateAccountPrivacyTxAPI(t *testing.T) { } func randomAccount() db.Account { + password := util.RandomString(6) + hashedPassword, _ := util.HashPassword(password) + acc := db.Account{ ID: util.RandomInt(1, 1000), - Passwordhash: util.RandomString(250), + Passwordhash: hashedPassword, Firstname: util.RandomName(), Lastname: util.RandomName(), Email: util.RandomEmail(), diff --git a/api/main_test.go b/api/main_test.go index 7732ff7..1004d18 100644 --- a/api/main_test.go +++ b/api/main_test.go @@ -3,6 +3,7 @@ package api import ( "os" "testing" + "time" "github.com/gin-gonic/gin" "github.com/itsscb/df/util" @@ -12,7 +13,10 @@ var config util.Config func TestMain(m *testing.M) { config = util.Config{ - Environment: "production", + Environment: "production", + TokenSymmetricKey: "12345678901234567890123456789012", + AccessTokenDuration: time.Minute * 1, + RefreshTokenDuration: time.Minute * 2, } gin.SetMode(gin.TestMode) diff --git a/api/middleware.go b/api/middleware.go new file mode 100644 index 0000000..cbdb506 --- /dev/null +++ b/api/middleware.go @@ -0,0 +1,54 @@ +package api + +import ( + "errors" + "fmt" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/itsscb/df/token" +) + +const ( + authorizationHeaderKey = "authorization" + authorizationTypeBearer = "bearer" + authorizationPayloadKey = "authorization_payload" +) + +// AuthMiddleware creates a gin middleware for authorization +func authMiddleware(tokenMaker token.Maker) gin.HandlerFunc { + return func(ctx *gin.Context) { + authorizationHeader := ctx.GetHeader(authorizationHeaderKey) + + if len(authorizationHeader) == 0 { + err := errors.New("authorization header is not provided") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + fields := strings.Fields(authorizationHeader) + if len(fields) < 2 { + err := errors.New("invalid authorization header format") + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + authorizationType := strings.ToLower(fields[0]) + if authorizationType != authorizationTypeBearer { + err := fmt.Errorf("unsupported authorization type %s", authorizationType) + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + accessToken := fields[1] + payload, err := tokenMaker.VerifyToken(accessToken) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + ctx.Set(authorizationPayloadKey, payload) + ctx.Next() + } +} diff --git a/api/middleware_test.go b/api/middleware_test.go new file mode 100644 index 0000000..27adde0 --- /dev/null +++ b/api/middleware_test.go @@ -0,0 +1,114 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + mockdb "github.com/itsscb/df/db/mock" + "github.com/itsscb/df/token" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func addAuthorization( + t *testing.T, + request *http.Request, + tokenMaker token.Maker, + authorizationType string, + email string, + duration time.Duration, +) { + token, payload, err := tokenMaker.CreateToken(email, duration) + require.NoError(t, err) + require.NotEmpty(t, payload) + + authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token) + request.Header.Set(authorizationHeaderKey, authorizationHeader) +} + +func TestAuthMiddleware(t *testing.T) { + testCases := []struct { + name string + setupAuth func(t *testing.T, request *http.Request, tokenMaker token.Maker) + checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder) + }{ + { + name: "OK", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusOK, recorder.Code) + }, + }, + { + name: "NoAuthorization", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "UnsupportedAuthorization", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, "unsupported", "user", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "InvalidAuthorizationFormat", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, "", "user", time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + { + name: "ExpiredToken", + setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) { + addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", -time.Minute) + }, + checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) { + require.Equal(t, http.StatusUnauthorized, recorder.Code) + }, + }, + } + + for i := range testCases { + tc := testCases[i] + + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + store := mockdb.NewMockStore(ctrl) + + server, err := NewServer(config, store) + require.NoError(t, err) + authPath := "/auth" + server.router.GET( + authPath, + authMiddleware(server.tokenMaker), + func(ctx *gin.Context) { + ctx.JSON(http.StatusOK, gin.H{}) + }, + ) + + recorder := httptest.NewRecorder() + request, err := http.NewRequest(http.MethodGet, authPath, nil) + require.NoError(t, err) + + tc.setupAuth(t, request, server.tokenMaker) + server.router.ServeHTTP(recorder, request) + tc.checkResponse(t, recorder) + }) + } +} diff --git a/api/server.go b/api/server.go index 083bbd0..8d9b43d 100644 --- a/api/server.go +++ b/api/server.go @@ -1,26 +1,35 @@ package api import ( + "fmt" "log/slog" "os" "github.com/gin-gonic/gin" db "github.com/itsscb/df/db/sqlc" + "github.com/itsscb/df/token" "github.com/itsscb/df/util" ) // Server serves HTTP requests for df service type Server struct { - store db.Store - router *gin.Engine - config util.Config + store db.Store + router *gin.Engine + config util.Config + tokenMaker token.Maker } // NewServer creates a new HTTP server and sets up routing -func NewServer(config util.Config, store db.Store) *Server { +func NewServer(config util.Config, store db.Store) (*Server, error) { + tokenMaker, err := token.NewPasetoMaker(config.TokenSymmetricKey) + if err != nil { + return nil, fmt.Errorf("cannot create token maker: %w", err) + } + server := &Server{ - store: store, - config: config, + store: store, + config: config, + tokenMaker: tokenMaker, } logLevel := slog.LevelError @@ -45,14 +54,18 @@ func NewServer(config util.Config, store db.Store) *Server { router.Use(Logger()) + router.POST("/accounts/login", server.loginAccount) + router.POST("/tokens/renew_access", server.renewAccessToken) router.POST("/accounts", server.createAccount) - router.PUT("/accounts", server.updateAccount) - router.PUT("/accounts/privacy", server.updateAccountPrivacy) - router.GET("/accounts/:id", server.getAccount) - router.GET("/accounts", server.listAccounts) + + authRoutes := router.Group("/").Use(authMiddleware(server.tokenMaker)) + authRoutes.PUT("/accounts", server.updateAccount) + authRoutes.PUT("/accounts/privacy", server.updateAccountPrivacy) + authRoutes.GET("/accounts/:id", server.getAccount) + authRoutes.GET("/accounts", server.listAccounts) server.router = router - return server + return server, nil } func (server *Server) Start(address string) error { diff --git a/api/session.go b/api/session.go new file mode 100644 index 0000000..1f08ca2 --- /dev/null +++ b/api/session.go @@ -0,0 +1,93 @@ +package api + +import ( + "database/sql" + "errors" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + db "github.com/itsscb/df/db/sqlc" + "github.com/itsscb/df/util" +) + +type loginAccountRequest struct { + Email string `json:"email" binding:"required"` + Password string `json:"password" binding:"required,min=6"` +} + +type loginAccountResponse struct { + SessionID uuid.UUID `json:"session_id"` + AccessToken string `json:"access_token"` + AccessTokenExpiresAt time.Time `json:"access_token_expires_at"` + RefreshToken string `json:"refresh_token"` + RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"` + Email string `json:"email"` +} + +func (server *Server) loginAccount(ctx *gin.Context) { + var req loginAccountRequest + if err := ctx.ShouldBindJSON(&req); err != nil { + ctx.JSON(http.StatusBadRequest, errorResponse(err)) + return + } + + account, err := server.store.GetAccountByEmail(ctx, req.Email) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + ctx.JSON(http.StatusNotFound, errorResponse(err)) + return + } + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + err = util.CheckPassword(req.Password, account.Passwordhash) + if err != nil { + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + accessToken, accessPayload, err := server.tokenMaker.CreateToken( + account.Email, + server.config.AccessTokenDuration, + ) + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + refreshToken, refreshPayload, err := server.tokenMaker.CreateToken( + account.Email, + server.config.RefreshTokenDuration, + ) + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + session, err := server.store.CreateSession(ctx, db.CreateSessionParams{ + ID: refreshPayload.ID, + Email: account.Email, + RefreshToken: refreshToken, + UserAgent: ctx.Request.UserAgent(), + ClientIp: ctx.ClientIP(), + IsBlocked: false, + ExpiresAt: refreshPayload.ExpiredAt, + }) + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + rsp := loginAccountResponse{ + SessionID: session.ID, + AccessToken: accessToken, + AccessTokenExpiresAt: accessPayload.ExpiredAt, + RefreshToken: refreshToken, + RefreshTokenExpiresAt: refreshPayload.ExpiredAt, + Email: account.Email, + } + ctx.JSON(http.StatusOK, rsp) +} diff --git a/api/token.go b/api/token.go new file mode 100644 index 0000000..cd3f11a --- /dev/null +++ b/api/token.go @@ -0,0 +1,83 @@ +package api + +import ( + "database/sql" + "errors" + "fmt" + "net/http" + "time" + + "github.com/gin-gonic/gin" +) + +type renewAccessTokenRequest struct { + RefreshToken string `json:"refresh_token" binding:"required"` +} + +type renewAccessTokenResponse struct { + AccessToken string `json:"access_token"` + AccessTokenExpiresAt time.Time `json:"access_token_expires_at"` +} + +func (server *Server) renewAccessToken(ctx *gin.Context) { + var req renewAccessTokenRequest + if err := ctx.ShouldBindJSON(&req); err != nil { + ctx.JSON(http.StatusBadRequest, errorResponse(err)) + return + } + + refreshPayload, err := server.tokenMaker.VerifyToken(req.RefreshToken) + if err != nil { + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + session, err := server.store.GetSession(ctx, refreshPayload.ID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + ctx.JSON(http.StatusNotFound, errorResponse(err)) + return + } + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + if session.IsBlocked { + err := fmt.Errorf("blocked session") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + if session.Email != refreshPayload.Email { + err := fmt.Errorf("incorrect session user") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + if session.RefreshToken != req.RefreshToken { + err := fmt.Errorf("mismatched session token") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + if time.Now().After(session.ExpiresAt) { + err := fmt.Errorf("expired session") + ctx.JSON(http.StatusUnauthorized, errorResponse(err)) + return + } + + accessToken, accessPayload, err := server.tokenMaker.CreateToken( + refreshPayload.Email, + server.config.AccessTokenDuration, + ) + if err != nil { + ctx.JSON(http.StatusInternalServerError, errorResponse(err)) + return + } + + rsp := renewAccessTokenResponse{ + AccessToken: accessToken, + AccessTokenExpiresAt: accessPayload.ExpiredAt, + } + ctx.JSON(http.StatusOK, rsp) +} diff --git a/app.env b/app.env index 915cf9c..396a54d 100644 --- a/app.env +++ b/app.env @@ -2,4 +2,7 @@ DB_SOURCE=postgresql://root:secret@localhost:5432/df?sslmode=disable DB_DRIVER=postgres SERVER_ADDRESS=0.0.0.0:8080 ENVIRONMENT=development -LOG_OUTPUT=text \ No newline at end of file +LOG_OUTPUT=text +ACCESS_TOKEN_DURATION=15m +REFRESH_TOKEN_DURATION=24h +TOKEN_SYMMETRIC_KEY=12345678901234567890123456789012 \ No newline at end of file diff --git a/db/migration/000001_init_schema.down.sql b/db/migration/000001_init_schema.down.sql index 882dffc..283346c 100644 --- a/db/migration/000001_init_schema.down.sql +++ b/db/migration/000001_init_schema.down.sql @@ -5,6 +5,7 @@ DROP TABLE IF EXISTS "documents"; DROP TABLE IF EXISTS "mails"; DROP TABLE IF EXISTS "persons"; DROP TABLE IF EXISTS "providers"; +DROP TABLE IF EXISTS "sessions"; DROP TABLE IF EXISTS "accounts"; diff --git a/db/migration/000001_init_schema.up.sql b/db/migration/000001_init_schema.up.sql index e6e5cc9..8a3204d 100644 --- a/db/migration/000001_init_schema.up.sql +++ b/db/migration/000001_init_schema.up.sql @@ -14,6 +14,7 @@ CREATE TABLE "mails" ( CREATE TABLE "accounts" ( "id" bigserial UNIQUE PRIMARY KEY NOT NULL, + "permission_level" int NOT NULL DEFAULT 0, "passwordhash" varchar NOT NULL, "firstname" varchar NOT NULL, "lastname" varchar NOT NULL, @@ -26,15 +27,23 @@ CREATE TABLE "accounts" ( "zip" varchar NOT NULL, "street" varchar NOT NULL, "country" varchar NOT NULL, - "token" varchar, - "token_valid" boolean DEFAULT false, - "token_expiration" timestamptz NOT NULL DEFAULT (now()), "creator" varchar NOT NULL, "created" timestamptz NOT NULL DEFAULT (now()), "changer" varchar NOT NULL, "changed" timestamptz NOT NULL DEFAULT (now()) ); +CREATE TABLE "sessions" ( + "id" uuid UNIQUE PRIMARY KEY NOT NULL, + "email" varchar NOT NULL, + "user_agent" varchar NOT NULL, + "client_ip" varchar NOT NULL, + "refresh_token" varchar NOT NULL, + "is_blocked" boolean NOT NULL DEFAULT false, + "expires_at" timestamptz NOT NULL, + "created_at" timestamptz NOT NULL DEFAULT (now()) +); + CREATE TABLE "persons" ( "id" bigserial UNIQUE PRIMARY KEY NOT NULL, "account_id" bigint NOT NULL, @@ -73,8 +82,8 @@ CREATE TABLE "payments" ( "account_id" bigint NOT NULL, "payment_category" varchar NOT NULL, "bankname" varchar, - "iban" varchar, - "bic" varchar, + "IBAN" varchar, + "BIC" varchar, "paypal_account" varchar, "paypal_id" varchar, "payment_system" varchar, @@ -123,6 +132,8 @@ CREATE TABLE "returnsLog" ( "changed" timestamptz NOT NULL DEFAULT (now()) ); +ALTER TABLE "sessions" ADD FOREIGN KEY ("email") REFERENCES "accounts" ("email"); + ALTER TABLE "persons" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id"); ALTER TABLE "documents" ADD FOREIGN KEY ("person_id") REFERENCES "persons" ("id"); diff --git a/db/mock/store.go b/db/mock/store.go index 25a06d5..ba85165 100644 --- a/db/mock/store.go +++ b/db/mock/store.go @@ -12,6 +12,7 @@ import ( context "context" reflect "reflect" + uuid "github.com/google/uuid" db "github.com/itsscb/df/db/sqlc" gomock "go.uber.org/mock/gomock" ) @@ -189,6 +190,21 @@ func (mr *MockStoreMockRecorder) CreateReturnsLog(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateReturnsLog", reflect.TypeOf((*MockStore)(nil).CreateReturnsLog), arg0, arg1) } +// CreateSession mocks base method. +func (m *MockStore) CreateSession(arg0 context.Context, arg1 db.CreateSessionParams) (db.Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSession", arg0, arg1) + ret0, _ := ret[0].(db.Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSession indicates an expected call of CreateSession. +func (mr *MockStoreMockRecorder) CreateSession(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSession", reflect.TypeOf((*MockStore)(nil).CreateSession), arg0, arg1) +} + // DeleteAccount mocks base method. func (m *MockStore) DeleteAccount(arg0 context.Context, arg1 int64) error { m.ctrl.T.Helper() @@ -316,6 +332,21 @@ func (mr *MockStoreMockRecorder) GetAccount(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccount", reflect.TypeOf((*MockStore)(nil).GetAccount), arg0, arg1) } +// GetAccountByEmail mocks base method. +func (m *MockStore) GetAccountByEmail(arg0 context.Context, arg1 string) (db.Account, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccountByEmail", arg0, arg1) + ret0, _ := ret[0].(db.Account) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAccountByEmail indicates an expected call of GetAccountByEmail. +func (mr *MockStoreMockRecorder) GetAccountByEmail(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountByEmail", reflect.TypeOf((*MockStore)(nil).GetAccountByEmail), arg0, arg1) +} + // GetAccountForUpdate mocks base method. func (m *MockStore) GetAccountForUpdate(arg0 context.Context, arg1 int64) (db.Account, error) { m.ctrl.T.Helper() @@ -436,6 +467,21 @@ func (mr *MockStoreMockRecorder) GetReturnsLog(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReturnsLog", reflect.TypeOf((*MockStore)(nil).GetReturnsLog), arg0, arg1) } +// GetSession mocks base method. +func (m *MockStore) GetSession(arg0 context.Context, arg1 uuid.UUID) (db.Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSession", arg0, arg1) + ret0, _ := ret[0].(db.Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSession indicates an expected call of GetSession. +func (mr *MockStoreMockRecorder) GetSession(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSession", reflect.TypeOf((*MockStore)(nil).GetSession), arg0, arg1) +} + // InvalidateDocument mocks base method. func (m *MockStore) InvalidateDocument(arg0 context.Context, arg1 db.InvalidateDocumentParams) (db.Document, error) { m.ctrl.T.Helper() diff --git a/db/query/account.sql b/db/query/account.sql index 49094e8..58780f7 100644 --- a/db/query/account.sql +++ b/db/query/account.sql @@ -2,6 +2,10 @@ SELECT * FROM accounts WHERE "id" = $1 LIMIT 1; +-- name: GetAccountByEmail :one +SELECT * FROM accounts +WHERE "email" = $1 LIMIT 1; + -- name: GetAccountForUpdate :one SELECT * FROM accounts WHERE "id" = $1 LIMIT 1 diff --git a/db/query/payment.sql b/db/query/payment.sql index d39d47b..767216a 100644 --- a/db/query/payment.sql +++ b/db/query/payment.sql @@ -7,8 +7,8 @@ INSERT INTO payments ( "account_id", "payment_category", "bankname", - "iban", - "bic", + "IBAN", + "BIC", "paypal_account", "paypal_id", "payment_system", @@ -31,8 +31,8 @@ SET "account_id" = COALESCE(sqlc.narg(account_id), "account_id"), "payment_category" = COALESCE(sqlc.narg(payment_category), "payment_category"), "bankname" = COALESCE(sqlc.narg(bankname), "bankname"), - "iban" = COALESCE(sqlc.narg(iban), "iban"), - "bic" = COALESCE(sqlc.narg(bic), "bic"), + "IBAN" = COALESCE(sqlc.narg(IBAN), "IBAN"), + "BIC" = COALESCE(sqlc.narg(BIC), "BIC"), "paypal_account" = COALESCE(sqlc.narg(paypal_account), "paypal_account"), "paypal_id" = COALESCE(sqlc.narg(paypal_id), "paypal_id"), "payment_system" = COALESCE(sqlc.narg(payment_system), "payment_system"), diff --git a/db/query/session.sql b/db/query/session.sql new file mode 100644 index 0000000..47d1fde --- /dev/null +++ b/db/query/session.sql @@ -0,0 +1,16 @@ +-- name: CreateSession :one +INSERT INTO sessions ( + id, + email, + refresh_token, + user_agent, + client_ip, + is_blocked, + expires_at +) VALUES ( + $1, $2, $3, $4, $5, $6, $7 +) RETURNING *; + +-- name: GetSession :one +SELECT * FROM sessions +WHERE id = $1 LIMIT 1; \ No newline at end of file diff --git a/db/sqlc/account.sql.go b/db/sqlc/account.sql.go index 76c9b51..0867849 100644 --- a/db/sqlc/account.sql.go +++ b/db/sqlc/account.sql.go @@ -42,7 +42,7 @@ INSERT INTO accounts ( $12, $13, $13 -) RETURNING id, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, token, token_valid, token_expiration, creator, created, changer, changed +) RETURNING id, permission_level, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, creator, created, changer, changed ` type CreateAccountParams struct { @@ -80,6 +80,7 @@ func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (A var i Account err := row.Scan( &i.ID, + &i.PermissionLevel, &i.Passwordhash, &i.Firstname, &i.Lastname, @@ -92,9 +93,6 @@ func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (A &i.Zip, &i.Street, &i.Country, - &i.Token, - &i.TokenValid, - &i.TokenExpiration, &i.Creator, &i.Created, &i.Changer, @@ -114,7 +112,7 @@ func (q *Queries) DeleteAccount(ctx context.Context, id int64) error { } const getAccount = `-- name: GetAccount :one -SELECT id, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, token, token_valid, token_expiration, creator, created, changer, changed FROM accounts +SELECT id, permission_level, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, creator, created, changer, changed FROM accounts WHERE "id" = $1 LIMIT 1 ` @@ -123,6 +121,38 @@ func (q *Queries) GetAccount(ctx context.Context, id int64) (Account, error) { var i Account err := row.Scan( &i.ID, + &i.PermissionLevel, + &i.Passwordhash, + &i.Firstname, + &i.Lastname, + &i.Birthday, + &i.PrivacyAccepted, + &i.PrivacyAcceptedDate, + &i.Email, + &i.Phone, + &i.City, + &i.Zip, + &i.Street, + &i.Country, + &i.Creator, + &i.Created, + &i.Changer, + &i.Changed, + ) + return i, err +} + +const getAccountByEmail = `-- name: GetAccountByEmail :one +SELECT id, permission_level, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, creator, created, changer, changed FROM accounts +WHERE "email" = $1 LIMIT 1 +` + +func (q *Queries) GetAccountByEmail(ctx context.Context, email string) (Account, error) { + row := q.db.QueryRowContext(ctx, getAccountByEmail, email) + var i Account + err := row.Scan( + &i.ID, + &i.PermissionLevel, &i.Passwordhash, &i.Firstname, &i.Lastname, @@ -135,9 +165,6 @@ func (q *Queries) GetAccount(ctx context.Context, id int64) (Account, error) { &i.Zip, &i.Street, &i.Country, - &i.Token, - &i.TokenValid, - &i.TokenExpiration, &i.Creator, &i.Created, &i.Changer, @@ -147,7 +174,7 @@ func (q *Queries) GetAccount(ctx context.Context, id int64) (Account, error) { } const getAccountForUpdate = `-- name: GetAccountForUpdate :one -SELECT id, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, token, token_valid, token_expiration, creator, created, changer, changed FROM accounts +SELECT id, permission_level, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, creator, created, changer, changed FROM accounts WHERE "id" = $1 LIMIT 1 FOR NO KEY UPDATE ` @@ -157,6 +184,7 @@ func (q *Queries) GetAccountForUpdate(ctx context.Context, id int64) (Account, e var i Account err := row.Scan( &i.ID, + &i.PermissionLevel, &i.Passwordhash, &i.Firstname, &i.Lastname, @@ -169,9 +197,6 @@ func (q *Queries) GetAccountForUpdate(ctx context.Context, id int64) (Account, e &i.Zip, &i.Street, &i.Country, - &i.Token, - &i.TokenValid, - &i.TokenExpiration, &i.Creator, &i.Created, &i.Changer, @@ -181,7 +206,7 @@ func (q *Queries) GetAccountForUpdate(ctx context.Context, id int64) (Account, e } const listAccounts = `-- name: ListAccounts :many -SELECT id, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, token, token_valid, token_expiration, creator, created, changer, changed FROM accounts +SELECT id, permission_level, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, creator, created, changer, changed FROM accounts ORDER BY "lastname", "firstname" LIMIT $1 OFFSET $2 @@ -203,6 +228,7 @@ func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]A var i Account if err := rows.Scan( &i.ID, + &i.PermissionLevel, &i.Passwordhash, &i.Firstname, &i.Lastname, @@ -215,9 +241,6 @@ func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]A &i.Zip, &i.Street, &i.Country, - &i.Token, - &i.TokenValid, - &i.TokenExpiration, &i.Creator, &i.Created, &i.Changer, @@ -252,7 +275,7 @@ SET "changer" = $2, "changed" = now() WHERE "id" = $1 -RETURNING id, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, token, token_valid, token_expiration, creator, created, changer, changed +RETURNING id, permission_level, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, creator, created, changer, changed ` type UpdateAccountParams struct { @@ -288,6 +311,7 @@ func (q *Queries) UpdateAccount(ctx context.Context, arg UpdateAccountParams) (A var i Account err := row.Scan( &i.ID, + &i.PermissionLevel, &i.Passwordhash, &i.Firstname, &i.Lastname, @@ -300,9 +324,6 @@ func (q *Queries) UpdateAccount(ctx context.Context, arg UpdateAccountParams) (A &i.Zip, &i.Street, &i.Country, - &i.Token, - &i.TokenValid, - &i.TokenExpiration, &i.Creator, &i.Created, &i.Changer, @@ -319,7 +340,7 @@ SET "changer" = $3, "changed" = now() WHERE "id" = $4 -RETURNING id, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, token, token_valid, token_expiration, creator, created, changer, changed +RETURNING id, permission_level, passwordhash, firstname, lastname, birthday, privacy_accepted, privacy_accepted_date, email, phone, city, zip, street, country, creator, created, changer, changed ` type UpdateAccountPrivacyParams struct { @@ -339,6 +360,7 @@ func (q *Queries) UpdateAccountPrivacy(ctx context.Context, arg UpdateAccountPri var i Account err := row.Scan( &i.ID, + &i.PermissionLevel, &i.Passwordhash, &i.Firstname, &i.Lastname, @@ -351,9 +373,6 @@ func (q *Queries) UpdateAccountPrivacy(ctx context.Context, arg UpdateAccountPri &i.Zip, &i.Street, &i.Country, - &i.Token, - &i.TokenValid, - &i.TokenExpiration, &i.Creator, &i.Created, &i.Changer, diff --git a/db/sqlc/models.go b/db/sqlc/models.go index c45f672..055c21d 100644 --- a/db/sqlc/models.go +++ b/db/sqlc/models.go @@ -7,10 +7,13 @@ package db import ( "database/sql" "time" + + "github.com/google/uuid" ) type Account struct { ID int64 `json:"id"` + PermissionLevel int32 `json:"permission_level"` Passwordhash string `json:"passwordhash"` Firstname string `json:"firstname"` Lastname string `json:"lastname"` @@ -23,9 +26,6 @@ type Account struct { Zip string `json:"zip"` Street string `json:"street"` Country string `json:"country"` - Token sql.NullString `json:"token"` - TokenValid sql.NullBool `json:"token_valid"` - TokenExpiration time.Time `json:"token_expiration"` Creator string `json:"creator"` Created time.Time `json:"created"` Changer string `json:"changer"` @@ -68,8 +68,8 @@ type Payment struct { AccountID int64 `json:"account_id"` PaymentCategory string `json:"payment_category"` Bankname sql.NullString `json:"bankname"` - Iban sql.NullString `json:"iban"` - Bic sql.NullString `json:"bic"` + IBAN sql.NullString `json:"IBAN"` + BIC sql.NullString `json:"BIC"` PaypalAccount sql.NullString `json:"paypal_account"` PaypalID sql.NullString `json:"paypal_id"` PaymentSystem sql.NullString `json:"payment_system"` @@ -133,3 +133,14 @@ type ReturnsLog struct { Changer string `json:"changer"` Changed time.Time `json:"changed"` } + +type Session struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + UserAgent string `json:"user_agent"` + ClientIp string `json:"client_ip"` + RefreshToken string `json:"refresh_token"` + IsBlocked bool `json:"is_blocked"` + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/db/sqlc/payment.sql.go b/db/sqlc/payment.sql.go index 2a62aad..41b260f 100644 --- a/db/sqlc/payment.sql.go +++ b/db/sqlc/payment.sql.go @@ -15,8 +15,8 @@ INSERT INTO payments ( "account_id", "payment_category", "bankname", - "iban", - "bic", + "IBAN", + "BIC", "paypal_account", "paypal_id", "payment_system", @@ -25,15 +25,15 @@ INSERT INTO payments ( "changer" ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 -) RETURNING id, account_id, payment_category, bankname, iban, bic, paypal_account, paypal_id, payment_system, type, creator, created, changer, changed +) RETURNING id, account_id, payment_category, bankname, "IBAN", "BIC", paypal_account, paypal_id, payment_system, type, creator, created, changer, changed ` type CreatePaymentParams struct { AccountID int64 `json:"account_id"` PaymentCategory string `json:"payment_category"` Bankname sql.NullString `json:"bankname"` - Iban sql.NullString `json:"iban"` - Bic sql.NullString `json:"bic"` + IBAN sql.NullString `json:"IBAN"` + BIC sql.NullString `json:"BIC"` PaypalAccount sql.NullString `json:"paypal_account"` PaypalID sql.NullString `json:"paypal_id"` PaymentSystem sql.NullString `json:"payment_system"` @@ -47,8 +47,8 @@ func (q *Queries) CreatePayment(ctx context.Context, arg CreatePaymentParams) (P arg.AccountID, arg.PaymentCategory, arg.Bankname, - arg.Iban, - arg.Bic, + arg.IBAN, + arg.BIC, arg.PaypalAccount, arg.PaypalID, arg.PaymentSystem, @@ -62,8 +62,8 @@ func (q *Queries) CreatePayment(ctx context.Context, arg CreatePaymentParams) (P &i.AccountID, &i.PaymentCategory, &i.Bankname, - &i.Iban, - &i.Bic, + &i.IBAN, + &i.BIC, &i.PaypalAccount, &i.PaypalID, &i.PaymentSystem, @@ -87,7 +87,7 @@ func (q *Queries) DeletePayment(ctx context.Context, id int64) error { } const getPayment = `-- name: GetPayment :one -SELECT id, account_id, payment_category, bankname, iban, bic, paypal_account, paypal_id, payment_system, type, creator, created, changer, changed FROM payments +SELECT id, account_id, payment_category, bankname, "IBAN", "BIC", paypal_account, paypal_id, payment_system, type, creator, created, changer, changed FROM payments WHERE "id" = $1 LIMIT 1 ` @@ -99,8 +99,8 @@ func (q *Queries) GetPayment(ctx context.Context, id int64) (Payment, error) { &i.AccountID, &i.PaymentCategory, &i.Bankname, - &i.Iban, - &i.Bic, + &i.IBAN, + &i.BIC, &i.PaypalAccount, &i.PaypalID, &i.PaymentSystem, @@ -114,7 +114,7 @@ func (q *Queries) GetPayment(ctx context.Context, id int64) (Payment, error) { } const listPayments = `-- name: ListPayments :many -SELECT id, account_id, payment_category, bankname, iban, bic, paypal_account, paypal_id, payment_system, type, creator, created, changer, changed FROM payments +SELECT id, account_id, payment_category, bankname, "IBAN", "BIC", paypal_account, paypal_id, payment_system, type, creator, created, changer, changed FROM payments ORDER BY "payment_category" LIMIT $1 OFFSET $2 @@ -139,8 +139,8 @@ func (q *Queries) ListPayments(ctx context.Context, arg ListPaymentsParams) ([]P &i.AccountID, &i.PaymentCategory, &i.Bankname, - &i.Iban, - &i.Bic, + &i.IBAN, + &i.BIC, &i.PaypalAccount, &i.PaypalID, &i.PaymentSystem, @@ -169,8 +169,8 @@ SET "account_id" = COALESCE($3, "account_id"), "payment_category" = COALESCE($4, "payment_category"), "bankname" = COALESCE($5, "bankname"), - "iban" = COALESCE($6, "iban"), - "bic" = COALESCE($7, "bic"), + "IBAN" = COALESCE($6, "IBAN"), + "BIC" = COALESCE($7, "BIC"), "paypal_account" = COALESCE($8, "paypal_account"), "paypal_id" = COALESCE($9, "paypal_id"), "payment_system" = COALESCE($10, "payment_system"), @@ -178,7 +178,7 @@ SET "changer" = $2, "changed" = now() WHERE "id" = $1 -RETURNING id, account_id, payment_category, bankname, iban, bic, paypal_account, paypal_id, payment_system, type, creator, created, changer, changed +RETURNING id, account_id, payment_category, bankname, "IBAN", "BIC", paypal_account, paypal_id, payment_system, type, creator, created, changer, changed ` type UpdatePaymentParams struct { @@ -215,8 +215,8 @@ func (q *Queries) UpdatePayment(ctx context.Context, arg UpdatePaymentParams) (P &i.AccountID, &i.PaymentCategory, &i.Bankname, - &i.Iban, - &i.Bic, + &i.IBAN, + &i.BIC, &i.PaypalAccount, &i.PaypalID, &i.PaymentSystem, diff --git a/db/sqlc/payment_test.go b/db/sqlc/payment_test.go index 06f3806..ec14caa 100644 --- a/db/sqlc/payment_test.go +++ b/db/sqlc/payment_test.go @@ -23,11 +23,11 @@ func createRandomPayment(t *testing.T) Payment { Valid: true, String: util.RandomName(), }, - Iban: sql.NullString{ + IBAN: sql.NullString{ Valid: true, String: util.RandomName(), }, - Bic: sql.NullString{ + BIC: sql.NullString{ Valid: true, String: util.RandomName(), }, @@ -55,8 +55,8 @@ func createRandomPayment(t *testing.T) Payment { require.Equal(t, arg.PaymentCategory, person.PaymentCategory) require.Equal(t, arg.Bankname, person.Bankname) require.Equal(t, arg.AccountID, person.AccountID) - require.Equal(t, arg.Iban, person.Iban) - require.Equal(t, arg.Bic, person.Bic) + require.Equal(t, arg.IBAN, person.IBAN) + require.Equal(t, arg.BIC, person.BIC) require.Equal(t, arg.PaypalAccount, person.PaypalAccount) require.Equal(t, arg.PaymentSystem, person.PaymentSystem) require.Equal(t, arg.PaypalID, person.PaypalID) @@ -84,8 +84,8 @@ func TestGetPayment(t *testing.T) { require.Equal(t, newperson.PaymentCategory, person.PaymentCategory) require.Equal(t, newperson.Bankname, person.Bankname) require.Equal(t, newperson.AccountID, person.AccountID) - require.Equal(t, newperson.Iban, person.Iban) - require.Equal(t, newperson.Bic, person.Bic) + require.Equal(t, newperson.IBAN, person.IBAN) + require.Equal(t, newperson.BIC, person.BIC) require.Equal(t, newperson.PaypalAccount, person.PaypalAccount) require.Equal(t, newperson.PaymentSystem, person.PaymentSystem) require.Equal(t, newperson.PaypalID, person.PaypalID) diff --git a/db/sqlc/querier.go b/db/sqlc/querier.go index 2dee7af..55ce165 100644 --- a/db/sqlc/querier.go +++ b/db/sqlc/querier.go @@ -6,6 +6,8 @@ package db import ( "context" + + "github.com/google/uuid" ) type Querier interface { @@ -18,6 +20,7 @@ type Querier interface { CreateProvider(ctx context.Context, arg CreateProviderParams) (Provider, error) CreateReturn(ctx context.Context, arg CreateReturnParams) (Return, error) CreateReturnsLog(ctx context.Context, arg CreateReturnsLogParams) (ReturnsLog, error) + CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) DeleteAccount(ctx context.Context, id int64) error DeleteDocument(ctx context.Context, id int64) error // -- name: UpdateMail :one @@ -40,6 +43,7 @@ type Querier interface { DeleteReturn(ctx context.Context, id int64) error DeleteReturnsLog(ctx context.Context, id int64) error GetAccount(ctx context.Context, id int64) (Account, error) + GetAccountByEmail(ctx context.Context, email string) (Account, error) GetAccountForUpdate(ctx context.Context, id int64) (Account, error) GetDocument(ctx context.Context, id int64) (Document, error) GetMail(ctx context.Context, id int64) (Mail, error) @@ -48,6 +52,7 @@ type Querier interface { GetProvider(ctx context.Context, id int64) (Provider, error) GetReturn(ctx context.Context, id int64) (Return, error) GetReturnsLog(ctx context.Context, id int64) (ReturnsLog, error) + GetSession(ctx context.Context, id uuid.UUID) (Session, error) InvalidateDocument(ctx context.Context, arg InvalidateDocumentParams) (Document, error) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) ListDocuments(ctx context.Context, arg ListDocumentsParams) ([]Document, error) diff --git a/db/sqlc/session.sql.go b/db/sqlc/session.sql.go new file mode 100644 index 0000000..daace93 --- /dev/null +++ b/db/sqlc/session.sql.go @@ -0,0 +1,82 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.21.0 +// source: session.sql + +package db + +import ( + "context" + "time" + + "github.com/google/uuid" +) + +const createSession = `-- name: CreateSession :one +INSERT INTO sessions ( + id, + email, + refresh_token, + user_agent, + client_ip, + is_blocked, + expires_at +) VALUES ( + $1, $2, $3, $4, $5, $6, $7 +) RETURNING id, email, user_agent, client_ip, refresh_token, is_blocked, expires_at, created_at +` + +type CreateSessionParams struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + RefreshToken string `json:"refresh_token"` + UserAgent string `json:"user_agent"` + ClientIp string `json:"client_ip"` + IsBlocked bool `json:"is_blocked"` + ExpiresAt time.Time `json:"expires_at"` +} + +func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) { + row := q.db.QueryRowContext(ctx, createSession, + arg.ID, + arg.Email, + arg.RefreshToken, + arg.UserAgent, + arg.ClientIp, + arg.IsBlocked, + arg.ExpiresAt, + ) + var i Session + err := row.Scan( + &i.ID, + &i.Email, + &i.UserAgent, + &i.ClientIp, + &i.RefreshToken, + &i.IsBlocked, + &i.ExpiresAt, + &i.CreatedAt, + ) + return i, err +} + +const getSession = `-- name: GetSession :one +SELECT id, email, user_agent, client_ip, refresh_token, is_blocked, expires_at, created_at FROM sessions +WHERE id = $1 LIMIT 1 +` + +func (q *Queries) GetSession(ctx context.Context, id uuid.UUID) (Session, error) { + row := q.db.QueryRowContext(ctx, getSession, id) + var i Session + err := row.Scan( + &i.ID, + &i.Email, + &i.UserAgent, + &i.ClientIp, + &i.RefreshToken, + &i.IsBlocked, + &i.ExpiresAt, + &i.CreatedAt, + ) + return i, err +} diff --git a/db/sqlc/tx_create_account.go b/db/sqlc/tx_create_account.go index 1b2c270..e3f6791 100644 --- a/db/sqlc/tx_create_account.go +++ b/db/sqlc/tx_create_account.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "time" + + "github.com/itsscb/df/util" ) type CreateAccountTxParams struct { @@ -28,6 +30,7 @@ type CreateAccountTxResult struct { func (store *SQLStore) CreateAccountTx(ctx context.Context, arg CreateAccountTxParams) (Account, error) { var result CreateAccountTxResult + var err error if arg.PrivacyAccepted.Bool && arg.PrivacyAccepted.Valid && !arg.PrivacyAcceptedDate.Valid { arg.PrivacyAcceptedDate = sql.NullTime{ @@ -36,7 +39,12 @@ func (store *SQLStore) CreateAccountTx(ctx context.Context, arg CreateAccountTxP } } - err := store.execTx(ctx, func(q *Queries) error { + arg.Passwordhash, err = util.HashPassword(arg.Passwordhash) + if err != nil { + return Account{}, nil + } + + err = store.execTx(ctx, func(q *Queries) error { var err error result.Account, err = q.CreateAccount(ctx, CreateAccountParams(arg)) diff --git a/db/sqlc/tx_update_account.go b/db/sqlc/tx_update_account.go index 48dbff9..82ef979 100644 --- a/db/sqlc/tx_update_account.go +++ b/db/sqlc/tx_update_account.go @@ -3,6 +3,8 @@ package db import ( "context" "database/sql" + + "github.com/itsscb/df/util" ) type UpdateAccountTxParams struct { @@ -26,8 +28,16 @@ type UpdateAccountTxResult struct { func (store *SQLStore) UpdateAccountTx(ctx context.Context, arg UpdateAccountTxParams) (Account, error) { var result UpdateAccountTxResult + var err error - err := store.execTx(ctx, func(q *Queries) error { + if arg.Passwordhash.Valid { + arg.Passwordhash.String, err = util.HashPassword(arg.Passwordhash.String) + if err != nil { + return Account{}, nil + } + } + + err = store.execTx(ctx, func(q *Queries) error { var err error result.Account, err = q.UpdateAccount(ctx, UpdateAccountParams(arg)) return err diff --git a/go.mod b/go.mod index 3e83533..0ffa4a0 100644 --- a/go.mod +++ b/go.mod @@ -5,14 +5,21 @@ go 1.21 toolchain go1.21.1 require ( + github.com/aead/chacha20poly1305 v0.0.0-20201124145622-1a5aba2a8b29 github.com/gin-gonic/gin v1.9.1 + github.com/google/uuid v1.1.2 github.com/lib/pq v1.10.9 + github.com/o1egl/paseto v1.0.0 github.com/spf13/viper v1.16.0 github.com/stretchr/testify v1.8.4 go.uber.org/mock v0.3.0 + golang.org/x/crypto v0.13.0 + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 ) require ( + github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect + github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 // indirect github.com/bytedance/sonic v1.10.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect github.com/chenzhuoyu/iasm v0.9.0 // indirect @@ -34,6 +41,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/afero v1.9.5 // indirect github.com/spf13/cast v1.5.1 // indirect @@ -43,7 +51,6 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.5.0 // indirect - golang.org/x/crypto v0.13.0 // indirect golang.org/x/net v0.15.0 // indirect golang.org/x/sys v0.12.0 // indirect golang.org/x/text v0.13.0 // indirect diff --git a/go.sum b/go.sum index 99ce712..2633f41 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,13 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= +github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= +github.com/aead/chacha20poly1305 v0.0.0-20170617001512-233f39982aeb/go.mod h1:UzH9IX1MMqOcwhoNOIjmTQeAxrFgzs50j4golQtXXxU= +github.com/aead/chacha20poly1305 v0.0.0-20201124145622-1a5aba2a8b29 h1:1DcvRPZOdbQRg5nAHt2jrc5QbV0AGuhDdfQI6gXjiFE= +github.com/aead/chacha20poly1305 v0.0.0-20201124145622-1a5aba2a8b29/go.mod h1:UzH9IX1MMqOcwhoNOIjmTQeAxrFgzs50j4golQtXXxU= +github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= +github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635/go.mod h1:lmLxL+FV291OopO93Bwf9fQLQeLyt33VJRUg5VJ30us= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM= github.com/bytedance/sonic v1.10.1 h1:7a1wuFXL1cMy7a3f7/VFcEtriuXQnUBhtoVfOZiaysc= @@ -143,6 +150,7 @@ github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= @@ -185,8 +193,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/o1egl/paseto v1.0.0 h1:bwpvPu2au176w4IBlhbyUv/S5VPptERIA99Oap5qUd0= +github.com/o1egl/paseto v1.0.0/go.mod h1:5HxsZPmw/3RI2pAwGo1HhOOwSdvBpcuVzO7uDkm+CLU= github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -240,6 +252,7 @@ go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.5.0 h1:jpGode6huXQxcskEIpOCvrU+tzo81b6+oFLUYXWtH/Y= golang.org/x/arch v0.5.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20181025213731-e84da0312774/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -259,6 +272,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -336,6 +351,7 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/main.go b/main.go index 4a0e1b2..7f2f854 100644 --- a/main.go +++ b/main.go @@ -21,7 +21,10 @@ func main() { } store := db.NewStore(conn) - server := api.NewServer(config, store) + server, err := api.NewServer(config, store) + if err != nil { + log.Fatalf("could not start server: %s", err) + } err = server.Start(config.ServerAddress) if err != nil { diff --git a/token/maker.go b/token/maker.go new file mode 100644 index 0000000..d2e7577 --- /dev/null +++ b/token/maker.go @@ -0,0 +1,14 @@ +package token + +import ( + "time" +) + +// Maker is an interface for managing tokens +type Maker interface { + // CreateToken creates a new token for a specific username and duration + CreateToken(email string, duration time.Duration) (string, *Payload, error) + + // VerifyToken checks if the token is valid or not + VerifyToken(token string) (*Payload, error) +} diff --git a/token/paseto_maker.go b/token/paseto_maker.go new file mode 100644 index 0000000..b00c3f9 --- /dev/null +++ b/token/paseto_maker.go @@ -0,0 +1,57 @@ +package token + +import ( + "fmt" + "time" + + "github.com/aead/chacha20poly1305" + "github.com/o1egl/paseto" +) + +// PasetoMaker is a PASETO token maker +type PasetoMaker struct { + paseto *paseto.V2 + symmetricKey []byte +} + +// NewPasetoMaker creates a new PasetoMaker +func NewPasetoMaker(symmetricKey string) (Maker, error) { + if len(symmetricKey) != chacha20poly1305.KeySize { + return nil, fmt.Errorf("invalid key size: must be exactly %d characters", chacha20poly1305.KeySize) + } + + maker := &PasetoMaker{ + paseto: paseto.NewV2(), + symmetricKey: []byte(symmetricKey), + } + + return maker, nil +} + +// CreateToken creates a new token for a specific username and duration +func (maker *PasetoMaker) CreateToken(email string, duration time.Duration) (string, *Payload, error) { + payload, err := NewPayload(email, duration) + if err != nil { + return "", payload, err + } + + token, err := maker.paseto.Encrypt(maker.symmetricKey, payload, nil) + return token, payload, err +} + +// VerifyToken checks if the token is valid or not +func (maker *PasetoMaker) VerifyToken(token string) (*Payload, error) { + payload := &Payload{} + + err := maker.paseto.Decrypt(token, maker.symmetricKey, payload, nil) + if err != nil { + return nil, ErrInvalidToken + } + + err = payload.Valid() + if err != nil { + return nil, err + } + + return payload, nil +} diff --git a/token/paseto_maker_test.go b/token/paseto_maker_test.go new file mode 100644 index 0000000..8247c65 --- /dev/null +++ b/token/paseto_maker_test.go @@ -0,0 +1,49 @@ +package token + +import ( + "testing" + "time" + + "github.com/itsscb/df/util" + "github.com/stretchr/testify/require" +) + +func TestPasetoMaker(t *testing.T) { + maker, err := NewPasetoMaker(util.RandomString(32)) + require.NoError(t, err) + + email := util.RandomEmail() + duration := time.Minute + + issuedAt := time.Now() + expiredAt := issuedAt.Add(duration) + + token, payload, err := maker.CreateToken(email, duration) + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotEmpty(t, payload) + + payload, err = maker.VerifyToken(token) + require.NoError(t, err) + require.NotEmpty(t, token) + + require.NotZero(t, payload.ID) + require.Equal(t, email, payload.Email) + require.WithinDuration(t, issuedAt, payload.IssuedAt, time.Second) + require.WithinDuration(t, expiredAt, payload.ExpiredAt, time.Second) +} + +func TestExpiredPasetoToken(t *testing.T) { + maker, err := NewPasetoMaker(util.RandomString(32)) + require.NoError(t, err) + + token, payload, err := maker.CreateToken(util.RandomEmail(), -time.Minute) + require.NoError(t, err) + require.NotEmpty(t, token) + require.NotEmpty(t, payload) + + payload, err = maker.VerifyToken(token) + require.Error(t, err) + require.EqualError(t, err, ErrExpiredToken.Error()) + require.Nil(t, payload) +} diff --git a/token/payload.go b/token/payload.go new file mode 100644 index 0000000..6a34c07 --- /dev/null +++ b/token/payload.go @@ -0,0 +1,46 @@ +package token + +import ( + "errors" + "time" + + "github.com/google/uuid" +) + +// Different types of error returned by the VerifyToken function +var ( + ErrInvalidToken = errors.New("token is invalid") + ErrExpiredToken = errors.New("token has expired") +) + +// Payload contains the payload data of the token +type Payload struct { + ID uuid.UUID `json:"id"` + Email string `json:"account_id"` + IssuedAt time.Time `json:"issued_at"` + ExpiredAt time.Time `json:"expired_at"` +} + +// NewPayload creates a new token payload with a specific accountID and duration +func NewPayload(email string, duration time.Duration) (*Payload, error) { + tokenID, err := uuid.NewRandom() + if err != nil { + return nil, err + } + + payload := &Payload{ + ID: tokenID, + Email: email, + IssuedAt: time.Now(), + ExpiredAt: time.Now().Add(duration), + } + return payload, nil +} + +// Valid checks if the token payload is valid or not +func (payload *Payload) Valid() error { + if time.Now().After(payload.ExpiredAt) { + return ErrExpiredToken + } + return nil +} diff --git a/util/config.go b/util/config.go index 5dd256a..a4da802 100644 --- a/util/config.go +++ b/util/config.go @@ -1,13 +1,20 @@ package util -import "github.com/spf13/viper" +import ( + "time" + + "github.com/spf13/viper" +) type Config struct { - DBSource string `mapstructure:"DB_SOURCE"` - DBDriver string `mapstructure:"DB_DRIVER"` - ServerAddress string `mapstructure:"SERVER_ADDRESS"` - Environment string `mapstructure:"ENVIRONMENT"` - LogOutput string `mapstructure:"LOG_OUTPUT"` + DBSource string `mapstructure:"DB_SOURCE"` + DBDriver string `mapstructure:"DB_DRIVER"` + ServerAddress string `mapstructure:"SERVER_ADDRESS"` + Environment string `mapstructure:"ENVIRONMENT"` + LogOutput string `mapstructure:"LOG_OUTPUT"` + TokenSymmetricKey string `mapstructure:"TOKEN_SYMMETRIC_KEY"` + AccessTokenDuration time.Duration `mapstructure:"ACCESS_TOKEN_DURATION"` + RefreshTokenDuration time.Duration `mapstructure:"REFRESH_TOKEN_DURATION"` } func LoadConfig(path string) (config Config, err error) { diff --git a/util/password.go b/util/password.go new file mode 100644 index 0000000..8237e44 --- /dev/null +++ b/util/password.go @@ -0,0 +1,21 @@ +package util + +import ( + "fmt" + + "golang.org/x/crypto/bcrypt" +) + +// HashPassword returns the bcrypt hash of the password +func HashPassword(password string) (string, error) { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", fmt.Errorf("failed to hash password: %w", err) + } + return string(hashedPassword), nil +} + +// CheckPassword checks if the provided password is correct or not +func CheckPassword(password string, hashedPassword string) error { + return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) +} diff --git a/util/password_test.go b/util/password_test.go new file mode 100644 index 0000000..1e87e9f --- /dev/null +++ b/util/password_test.go @@ -0,0 +1,28 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" +) + +func TestPassword(t *testing.T) { + password := RandomString(6) + + hashedPassword1, err := HashPassword(password) + require.NoError(t, err) + require.NotEmpty(t, hashedPassword1) + + err = CheckPassword(password, hashedPassword1) + require.NoError(t, err) + + wrongPassword := RandomString(6) + err = CheckPassword(wrongPassword, hashedPassword1) + require.EqualError(t, err, bcrypt.ErrMismatchedHashAndPassword.Error()) + + hashedPassword2, err := HashPassword(password) + require.NoError(t, err) + require.NotEmpty(t, hashedPassword2) + require.NotEqual(t, hashedPassword1, hashedPassword2) +}