# -*- coding: utf-8 -*-
"""Tests for ``RequestForm`` and ``PasswordLessUserCreationForm``."""
import unittest
from collections import OrderedDict

from nacl import signing
import mock
from django import forms, test
from django.contrib.auth import SESSION_KEY, get_user_model
from django.utils.timezone import now

from ..crypto import HMAC, Ed25519, generate_randomness
from ..forms import PasswordLessUserCreationForm, RequestForm
from ..models import SQRLIdentity, SQRLNut
from ..utils import Base64, Encoder


# Dotted path used as the prefix for ``mock.patch`` targets below.
TESTING_MODULE = 'sqrl.forms'


class TestRequestForm(test.TestCase):
    """Exercise ``RequestForm`` against a freshly built nut/user/identity.

    ``setUp`` generates three key pairs and a set of ``include_*`` flags, then
    calls ``_setup`` to build the database fixtures, the client/server
    payloads, their signatures and the form under test. Individual tests that
    need a different command or a different set of payload fields adjust
    ``self.cmd`` / ``self.include_*`` and call ``_setup`` again.
    """

    def get_key_pair(self):
        """Return a ``(signing_key, verifying_key)`` pair for a new identity.

        The values are taken from PyNaCl's private ``_signing_key`` and
        ``_key`` attributes (presumably the raw key bytes -- confirm against
        the installed PyNaCl version).
        """
        signing_key = signing.SigningKey.generate()
        verifying_key = signing_key.verify_key
        signing_key = signing_key._signing_key
        verifying_key = verifying_key._key
        return signing_key, verifying_key

    def _delete_fixtures(self, *names):
        """Delete the model instances stored under the given attribute names.

        Instances are deleted in the order the names are given. Attributes
        that do not exist yet, or that a test has reset to ``None``, are
        skipped.
        """
        for name in names:
            instance = getattr(self, name, None)
            if instance is not None:
                instance.delete()

    def _setup(self):
        """(Re)build all fixtures from the current ``cmd``/``include_*`` state.

        Any fixtures left from a previous call are deleted first, so the
        method can be called again from within a test after tweaking the
        configuration attributes.
        """
        self._delete_fixtures('nut', 'user', 'identity')

        self.nut = SQRLNut(
            nonce=generate_randomness(),
            transaction_nonce=generate_randomness(),
            session_key=generate_randomness(20),
            is_transaction_complete=False,
            ip_address='127.0.0.1',
            timestamp=now(),
        )
        self.nut.save()

        self.user = get_user_model().objects.create(
            username='test_clean_session',
        )

        self.identity = SQRLIdentity(
            user_id=self.user.pk,
            public_key=Base64.encode(self.public_key),
            server_unlock_key=Base64.encode(self.server_unlock_key),
            verify_unlock_key=Base64.encode(self.verify_unlock_key),
            is_enabled=True,
            is_only_sqrl=False,
        )
        self.identity.save()

        # Server payload: the MAC is computed over the data before the
        # 'mac' key itself is added, then the whole mapping is normalized.
        self.server_data = OrderedDict([
            ('ver', 1),
            ('nut', self.nut.nonce),
            ('tif', '8'),
            ('qry', '/sqrl/auth/?nut=nonce'),
            ('sfn', 'Test Server'),
        ])
        self.server_data['mac'] = HMAC(self.nut, self.server_data).sign_data()
        self.server_data = Encoder.normalize(self.server_data)

        # Client payload: optional keys are controlled by the include_* flags.
        self.client_data = OrderedDict([
            ('ver', 1),
            ('cmd', self.cmd),
            ('opt', ['sqrlonly']),
        ])
        if self.include_idk:
            self.client_data['idk'] = self.public_key
        if self.include_pidk:
            self.client_data['pidk'] = self.previous_public_key
        if self.include_suk:
            self.client_data['suk'] = self.server_unlock_key
        if self.include_vuk:
            self.client_data['vuk'] = self.verify_unlock_key

        # Same client payload with list values joined by '~' and normalized;
        # used as the raw input for ``clean_client``.
        self.payload_client_data = Encoder.normalize(OrderedDict(
            (k, v if not isinstance(v, list) else '~'.join(v))
            for k, v in self.client_data.items()
        ))

        self.data = {
            'client': Encoder.base64_dumps(self.client_data),
            'server': Encoder.base64_dumps(self.server_data),
        }
        # Signatures are computed over the encoded client + server strings.
        self.signable_data = (
            self.data['client'] + self.data['server']
        ).encode('ascii')

        if self.include_ids:
            self.data['ids'] = Ed25519(
                self.public_key,
                self.identity_key,
                self.signable_data,
            ).sign_data()
        if self.include_pids:
            self.data['pids'] = Ed25519(
                self.previous_public_key,
                self.previous_identity_key,
                self.signable_data,
            ).sign_data()
        if self.include_urs:
            self.data['urs'] = Ed25519(
                self.verify_unlock_key,
                self.unlock_key,
                self.signable_data,
            ).sign_data()

        # ``cleaned_data`` mirrors ``data`` but carries the decoded
        # client/server mappings instead of their encoded strings.
        self.cleaned_data = self.data.copy()
        self.cleaned_data.update({
            'client': self.client_data,
            'server': self.server_data,
        })

        self.form = RequestForm(self.nut, data=self.data)

    def setUp(self):
        """Create keys and default flags, then build the fixtures."""
        super(TestRequestForm, self).setUp()

        self.cmd = ['query']
        self.identity_key, self.public_key = self.get_key_pair()
        (self.previous_identity_key,
         self.previous_public_key) = self.get_key_pair()
        self.unlock_key, self.verify_unlock_key = self.get_key_pair()
        self.server_unlock_key = b'hello'

        self.include_idk = True
        self.include_pidk = True
        self.include_suk = True
        self.include_vuk = True
        self.include_ids = True
        self.include_pids = True
        self.include_urs = True

        self._setup()

    def tearDown(self):
        """Delete the fixtures created by the last ``_setup`` call."""
        self._delete_fixtures('user', 'identity', 'nut')
        super(TestRequestForm, self).tearDown()

    # -- construction -------------------------------------------------------

    def test_init(self):
        form = RequestForm(mock.sentinel.nut)

        self.assertEqual(form.nut, mock.sentinel.nut)
        self.assertIsNone(form.session)
        self.assertIsNone(form.identity)
        self.assertIsNone(form.previous_identity)

    # -- clean_client -------------------------------------------------------

    def test_clean_client(self):
        self.form.cleaned_data = {'client': self.payload_client_data}

        self.assertEqual(self.form.clean_client(), dict(self.client_data))

    def test_clean_client_invalid(self):
        self.form.cleaned_data = {
            'client': {
                'ver': '2',
            }
        }

        with self.assertRaises(forms.ValidationError):
            self.form.clean_client()

    # -- clean_server -------------------------------------------------------

    def test_clean_server_not_dict(self):
        # A non-dict value is expected to be returned unchanged.
        self.form.cleaned_data = {'server': mock.sentinel.server_data}

        self.assertEqual(self.form.clean_server(), mock.sentinel.server_data)

    def test_clean_server(self):
        self.form.cleaned_data = {
            'server': self.server_data
        }

        self.assertEqual(
            self.form.clean_server(),
            self.server_data
        )

    def test_clean_server_mac_not_base64(self):
        self.server_data['mac'] = 'hello'
        self.form.cleaned_data = {'server': self.server_data}

        with self.assertRaises(forms.ValidationError):
            self.form.clean_server()

    def test_clean_server_mismatch_nut(self):
        self.server_data['nut'] = self.server_data['nut'][::-1]
        self.form.cleaned_data = {'server': self.server_data}

        with self.assertRaises(forms.ValidationError):
            self.form.clean_server()

    def test_clean_server_mismatch_missing_mac(self):
        del self.server_data['mac']
        self.form.cleaned_data = {'server': self.server_data}

        with self.assertRaises(forms.ValidationError):
            self.form.clean_server()

    def test_clean_server_invalid_mac(self):
        self.server_data['mac'] = self.server_data['mac'][::-1]
        self.form.cleaned_data = {'server': self.server_data}

        with self.assertRaises(forms.ValidationError):
            self.form.clean_server()

    # -- clean_ids / clean_pids --------------------------------------------

    def test_clean_ids(self):
        self.form.cleaned_data = {
            'client': self.client_data,
            'ids': self.data['ids'],
        }

        self.assertEqual(
            self.form.clean_ids(),
            self.data['ids']
        )

    def test_clean_ids_invalid(self):
        # A reversed signature must be rejected.
        self.form.cleaned_data = {
            'client': self.client_data,
            'ids': self.data['ids'][::-1],
        }

        with self.assertRaises(forms.ValidationError):
            self.form.clean_ids()

    def test_clean_pids_valid(self):
        self.form.cleaned_data = {
            'client': self.client_data,
            'pids': self.data['pids'],
        }

        self.assertEqual(
            self.form.clean_pids(),
            self.data['pids']
        )

    def test_clean_pids_invalid(self):
        # A reversed signature must be rejected.
        self.form.cleaned_data = {
            'client': self.client_data,
            'pids': self.data['pids'][::-1],
        }

        with self.assertRaises(forms.ValidationError):
            self.form.clean_pids()

    def test_clean_pids_missing_pids(self):
        # 'pidk' is present in the client data but no 'pids' is supplied.
        self.form.cleaned_data = {
            'client': self.client_data,
        }

        with self.assertRaises(forms.ValidationError):
            self.form.clean_pids()

    def test_clean_pids_missing_pidk(self):
        # 'pids' is supplied but the client data carries no 'pidk'.
        self.client_data.pop('pidk')
        self.form.cleaned_data = {
            'client': self.client_data,
            'pids': self.data['pids'],
        }

        with self.assertRaises(forms.ValidationError):
            self.form.clean_pids()

    # -- _clean_urs ---------------------------------------------------------

    def test_clean_urs(self):
        self.form.cleaned_data = self.cleaned_data
        self.form.identity = self.identity

        self.assertEqual(
            self.form._clean_urs(),
            self.data['urs']
        )

    def test_clean_urs_invalid(self):
        # A reversed signature must be rejected.
        self.form.cleaned_data = {
            'client': self.client_data,
            'urs': self.data['urs'][::-1],
        }
        self.form.identity = self.identity

        with self.assertRaises(forms.ValidationError):
            self.form._clean_urs()

    def test_clean_urs_no_suk(self):
        self.form.cleaned_data = {
            'client': self.client_data,
            'urs': self.data['urs'],
        }
        self.form.identity = self.identity
        self.identity.server_unlock_key = None

        with self.assertRaises(forms.ValidationError):
            self.form._clean_urs()

    def test_clean_urs_no_vuk(self):
        self.form.cleaned_data = {
            'client': self.client_data,
            'urs': self.data['urs'],
        }
        self.form.identity = self.identity
        self.identity.verify_unlock_key = None

        with self.assertRaises(forms.ValidationError):
            self.form._clean_urs()

    # -- _clean_client_cmd --------------------------------------------------

    def test_clean_cmd_query(self):
        self.cmd = ['query']
        self._setup()
        self.form.cleaned_data = {
            'client': self.client_data,
        }

        self.assertIsNone(self.form._clean_client_cmd())

    def test_clean_cmd_ident(self):
        self.cmd = ['ident']
        self._setup()
        self.form.cleaned_data = {
            'client': self.client_data,
        }

        self.assertIsNone(self.form._clean_client_cmd())

    def test_clean_cmd_ident_no_suk_vuk_without_identity(self):
        # 'ident' with neither suk nor vuk and no identity on the form.
        self.cmd = ['ident']
        self.include_suk = False
        self.include_vuk = False
        self._setup()
        self.form.cleaned_data = {
            'client': self.client_data,
        }

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    def test_clean_cmd_ident_suk_vuk_with_identity(self):
        # 'ident' carrying suk/vuk while the form already has an identity.
        self.cmd = ['ident']
        self._setup()
        self.form.identity = self.identity
        self.form.cleaned_data = {
            'client': self.client_data,
        }

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    def test_clean_cmd_ident_with_disable(self):
        self.cmd = ['ident', 'disable']
        self._setup()
        self.form.cleaned_data = {
            'client': self.client_data,
        }

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    def test_clean_cmd_ident_no_urs_with_previous_identity(self):
        # cleaned_data deliberately contains no 'urs' entry here.
        self.cmd = ['ident']
        self._setup()
        self.form.previous_identity = self.identity
        self.form.cleaned_data = {
            'client': self.client_data,
        }

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    def test_clean_cmd_disable(self):
        self.cmd = ['disable']
        self._setup()
        self.form.identity = self.identity
        self.form.cleaned_data = {
            'client': self.client_data,
        }

        self.assertIsNone(self.form._clean_client_cmd())

    def test_clean_cmd_disable_no_identity(self):
        self.cmd = ['disable']
        self._setup()
        self.form.cleaned_data = {
            'client': self.client_data,
        }

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    def test_clean_cmd_disable_with_enable(self):
        self.cmd = ['disable', 'enable']
        self._setup()
        self.form.identity = self.identity
        self.form.cleaned_data = {
            'client': self.client_data,
        }

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    def test_clean_cmd_enable(self):
        self.cmd = ['enable']
        self._setup()
        self.form.identity = self.identity
        self.form.cleaned_data = self.cleaned_data

        self.assertIsNone(self.form._clean_client_cmd())

    def test_clean_cmd_enable_no_urs(self):
        self.cmd = ['enable']
        self.include_urs = False
        self._setup()
        self.form.identity = self.identity
        self.form.cleaned_data = self.cleaned_data

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    def test_clean_cmd_enable_no_identity(self):
        self.cmd = ['enable']
        self._setup()
        self.form.cleaned_data = self.cleaned_data

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    def test_clean_cmd_enable_with_disable(self):
        self.cmd = ['enable', 'disable']
        self._setup()
        self.form.identity = self.identity
        self.form.cleaned_data = self.cleaned_data

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    def test_clean_cmd_remove(self):
        self.cmd = ['remove']
        self._setup()
        self.form.identity = self.identity
        self.form.cleaned_data = self.cleaned_data

        self.assertIsNone(self.form._clean_client_cmd())

    def test_clean_cmd_remove_no_identity(self):
        self.cmd = ['remove']
        self._setup()
        self.form.cleaned_data = self.cleaned_data

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    def test_clean_cmd_remove_no_urs(self):
        self.cmd = ['remove']
        self.include_urs = False
        self._setup()
        self.form.identity = self.identity
        self.form.cleaned_data = self.cleaned_data

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    def test_clean_cmd_remove_with_other_cmd(self):
        self.cmd = ['remove', 'ident']
        self._setup()
        self.form.identity = self.identity
        self.form.cleaned_data = self.cleaned_data

        with self.assertRaises(forms.ValidationError):
            self.form._clean_client_cmd()

    # -- _clean_session -----------------------------------------------------

    def test_clean_session_empty(self):
        self.form.session = {}

        self.assertIsNone(self.form._clean_session())

    def test_clean_session_user_not_found(self):
        # Precondition: no user with this pk exists. Uses a unittest
        # assertion rather than a bare ``assert`` so it survives ``-O``.
        self.assertIsNone(get_user_model().objects.filter(pk=1000).first())

        self.form.session = {
            SESSION_KEY: '1000',
        }

        self.assertIsNone(self.form._clean_session())

    def test_clean_session_user_not_int(self):
        self.form.session = {
            SESSION_KEY: 'aaa',
        }

        self.assertIsNone(self.form._clean_session())

    def test_clean_session_no_sqrl_identity(self):
        self.identity.delete()
        # Reset so that tearDown does not try to delete it a second time.
        self.identity = None

        self.form.session = {
            SESSION_KEY: str(self.user.pk),
        }

        self.assertIsNone(self.form._clean_session())

    def test_clean_session_public_key_not_matches(self):
        # Stored public key no longer matches the one in the client data.
        self.identity.public_key = self.identity.public_key[::-1]
        self.identity.save()

        self.form.session = {
            SESSION_KEY: str(self.user.pk),
        }
        self.form.cleaned_data = self.cleaned_data

        with self.assertRaises(forms.ValidationError):
            self.form._clean_session()

    def test_clean_session_user_code_mismatch(self):
        # Session refers to a different user pk than the identity's owner.
        self.form.identity = self.identity
        self.form.session = {
            SESSION_KEY: str(self.user.pk + 1),
        }
        self.form.cleaned_data = self.cleaned_data

        with self.assertRaises(forms.ValidationError):
            self.form._clean_session()

    # -- clean --------------------------------------------------------------

    def test_clean(self):
        self.form.cleaned_data = self.cleaned_data

        actual = self.form.clean()

        self.assertEqual(actual, self.cleaned_data)

    # Decorators are applied bottom-up, so the mock arguments arrive in the
    # reverse order of the decorator list.
    @mock.patch.object(RequestForm, 'find_identities')
    @mock.patch.object(RequestForm, '_clean_client_cmd')
    @mock.patch.object(RequestForm, '_clean_urs')
    @mock.patch.object(RequestForm, 'find_session')
    @mock.patch.object(RequestForm, '_clean_session')
    @mock.patch.object(forms.Form, 'clean')
    def test_clean_mock(self,
                        mock_super_clean,
                        mock_clean_session,
                        mock_find_session,
                        mock_clean_urs,
                        mock_clean_client_cmd,
                        mock_find_identities):
        mock_super_clean.return_value = mock.sentinel.cleaned_data

        actual = self.form.clean()

        self.assertEqual(actual, mock.sentinel.cleaned_data)
        mock_super_clean.assert_called_once_with()
        mock_clean_session.assert_called_once_with()
        mock_find_session.assert_called_once_with()
        mock_clean_urs.assert_called_once_with()
        mock_clean_client_cmd.assert_called_once_with()
        mock_find_identities.assert_called_once_with()

    # -- find_session / find_identities / _get_identity ---------------------

    @mock.patch(TESTING_MODULE + '.SessionMiddleware')
    def test_find_session(self, mock_session_middleware):
        self.form.find_session()

        session_store = mock_session_middleware.return_value.SessionStore
        self.assertEqual(
            self.form.session,
            session_store.return_value
        )
        session_store.assert_called_once_with(
            self.nut.session_key,
        )

    def test_find_identities(self):
        self.form.cleaned_data = self.cleaned_data

        self.form.find_identities()

        self.assertIsNotNone(self.form.identity)
        self.assertIsInstance(self.form.identity, SQRLIdentity)
        self.assertEqual(
            self.form.identity.public_key,
            self.identity.public_key
        )
        # _setup stores an identity for the current public key only.
        self.assertIsNone(self.form.previous_identity)

    @mock.patch.object(RequestForm, '_get_identity')
    def test_find_identities_mock(self, mock_get_identity):
        self.form.cleaned_data = self.cleaned_data
        # Successive calls return the current, then the previous identity.
        mock_get_identity.side_effect = (
            mock.sentinel.identity,
            mock.sentinel.previous_identity,
        )

        self.form.find_identities()

        self.assertEqual(self.form.identity, mock.sentinel.identity)
        self.assertEqual(
            self.form.previous_identity,
            mock.sentinel.previous_identity
        )
        mock_get_identity.assert_has_calls([
            mock.call(self.public_key),
            mock.call(self.previous_public_key),
        ])

    @mock.patch(TESTING_MODULE + '.SQRLIdentity')
    def test_get_identity(self, mock_sqrl_identity):
        actual = self.form._get_identity(self.public_key)

        self.assertEqual(
            actual,
            mock_sqrl_identity.objects.filter.return_value.first.return_value
        )
        mock_sqrl_identity.objects.filter.assert_called_once_with(
            public_key=Base64.encode(self.public_key)
        )

    def test_get_identity_no_key(self):
        self.assertIsNone(self.form._get_identity(None))


class TestRandomPasswordUserCreationForm(unittest.TestCase):
    """Tests for ``PasswordLessUserCreationForm``."""

    def test_init(self):
        # The password fields exist on the class but not on an instance.
        self.assertIn('password1', PasswordLessUserCreationForm.base_fields)
        self.assertIn('password2', PasswordLessUserCreationForm.base_fields)

        form = PasswordLessUserCreationForm()

        self.assertNotIn('password1', form.fields)
        self.assertNotIn('password2', form.fields)

    def test_save(self):
        form = PasswordLessUserCreationForm({'username': 'test'})

        self.assertTrue(form.is_valid())

        user = form.save()
        # This is a plain unittest.TestCase, so nothing rolls the row back;
        # register the delete up front so a failing assertion cannot leak it.
        self.addCleanup(user.delete)

        self.assertEqual(user.username, 'test')
        self.assertTrue(user.password.startswith('!'))