115 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			115 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package api
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"net/http"
 | |
| 	"net/http/httptest"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/gin-gonic/gin"
 | |
| 	mockdb "github.com/itsscb/df/db/mock"
 | |
| 	"github.com/itsscb/df/token"
 | |
| 	"github.com/stretchr/testify/require"
 | |
| 	"go.uber.org/mock/gomock"
 | |
| )
 | |
| 
 | |
| func addAuthorization(
 | |
| 	t *testing.T,
 | |
| 	request *http.Request,
 | |
| 	tokenMaker token.Maker,
 | |
| 	authorizationType string,
 | |
| 	email string,
 | |
| 	duration time.Duration,
 | |
| ) {
 | |
| 	token, payload, err := tokenMaker.CreateToken(email, duration)
 | |
| 	require.NoError(t, err)
 | |
| 	require.NotEmpty(t, payload)
 | |
| 
 | |
| 	authorizationHeader := fmt.Sprintf("%s %s", authorizationType, token)
 | |
| 	request.Header.Set(authorizationHeaderKey, authorizationHeader)
 | |
| }
 | |
| 
 | |
| func TestAuthMiddleware(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		name          string
 | |
| 		setupAuth     func(t *testing.T, request *http.Request, tokenMaker token.Maker)
 | |
| 		checkResponse func(t *testing.T, recorder *httptest.ResponseRecorder)
 | |
| 	}{
 | |
| 		{
 | |
| 			name: "OK",
 | |
| 			setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
 | |
| 				addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", time.Minute)
 | |
| 			},
 | |
| 			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
 | |
| 				require.Equal(t, http.StatusOK, recorder.Code)
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "NoAuthorization",
 | |
| 			setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
 | |
| 			},
 | |
| 			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
 | |
| 				require.Equal(t, http.StatusUnauthorized, recorder.Code)
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "UnsupportedAuthorization",
 | |
| 			setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
 | |
| 				addAuthorization(t, request, tokenMaker, "unsupported", "user", time.Minute)
 | |
| 			},
 | |
| 			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
 | |
| 				require.Equal(t, http.StatusUnauthorized, recorder.Code)
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "InvalidAuthorizationFormat",
 | |
| 			setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
 | |
| 				addAuthorization(t, request, tokenMaker, "", "user", time.Minute)
 | |
| 			},
 | |
| 			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
 | |
| 				require.Equal(t, http.StatusUnauthorized, recorder.Code)
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			name: "ExpiredToken",
 | |
| 			setupAuth: func(t *testing.T, request *http.Request, tokenMaker token.Maker) {
 | |
| 				addAuthorization(t, request, tokenMaker, authorizationTypeBearer, "user", -time.Minute)
 | |
| 			},
 | |
| 			checkResponse: func(t *testing.T, recorder *httptest.ResponseRecorder) {
 | |
| 				require.Equal(t, http.StatusUnauthorized, recorder.Code)
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	for i := range testCases {
 | |
| 		tc := testCases[i]
 | |
| 
 | |
| 		t.Run(tc.name, func(t *testing.T) {
 | |
| 			ctrl := gomock.NewController(t)
 | |
| 			defer ctrl.Finish()
 | |
| 
 | |
| 			store := mockdb.NewMockStore(ctrl)
 | |
| 
 | |
| 			server, err := NewServer(config, store)
 | |
| 			require.NoError(t, err)
 | |
| 			authPath := "/auth"
 | |
| 			server.router.GET(
 | |
| 				authPath,
 | |
| 				authMiddleware(server.tokenMaker),
 | |
| 				func(ctx *gin.Context) {
 | |
| 					ctx.JSON(http.StatusOK, gin.H{})
 | |
| 				},
 | |
| 			)
 | |
| 
 | |
| 			recorder := httptest.NewRecorder()
 | |
| 			request, err := http.NewRequest(http.MethodGet, authPath, nil)
 | |
| 			require.NoError(t, err)
 | |
| 
 | |
| 			tc.setupAuth(t, request, server.tokenMaker)
 | |
| 			server.router.ServeHTTP(recorder, request)
 | |
| 			tc.checkResponse(t, recorder)
 | |
| 		})
 | |
| 	}
 | |
| }
 |