diff --git a/mirrors/tests/__init__.py b/mirrors/tests/__init__.py index fb6c10d..a1d3c2c 100644 --- a/mirrors/tests/__init__.py +++ b/mirrors/tests/__init__.py @@ -1,12 +1,13 @@ from mirrors.models import MirrorUrl, MirrorProtocol, Mirror -def create_mirror_url(): - mirror = Mirror.objects.create(name='mirror1', +def create_mirror_url(name='mirror1', country='US', + protocol='http', url='https://archlinux.org'): + mirror = Mirror.objects.create(name=name, admin_email='admin@archlinux.org') - mirror_protocol = MirrorProtocol.objects.create(protocol='http') - mirror_url = MirrorUrl.objects.create(url='https://archlinux.org', + mirror_protocol = MirrorProtocol.objects.create(protocol=protocol) + mirror_url = MirrorUrl.objects.create(url=url, protocol=mirror_protocol, mirror=mirror, - country='US') + country=country) return mirror_url diff --git a/mirrors/tests/test_mirrorlist.py b/mirrors/tests/test_mirrorlist.py index 5590a96..1ad3d8d 100644 --- a/mirrors/tests/test_mirrorlist.py +++ b/mirrors/tests/test_mirrorlist.py @@ -30,9 +30,21 @@ def test_mirrorlist_all_https(self): # TODO: test 200 case def test_mirrorlist_filter(self): - response = self.client.get('/mirrorlist/?country=all&protocol=http&ip_version=4') + jp_mirror_url = create_mirror_url( + name='jp_mirror', + country='JP', + protocol='https', + url='https://wikipedia.jp') + + # First test that we correctly see the above mirror. + response = self.client.get('/mirrorlist/?country=JP&protocol=https') self.assertEqual(response.status_code, 200) - self.assertIn(self.mirror_url.hostname, response.content) + self.assertIn(jp_mirror_url.hostname, response.content) + + # Now confirm that the US mirror did not show up. + self.assertNotIn(self.mirror_url.hostname, response.content) + + jp_mirror_url.delete() def test_generate(self): response = self.client.get('/mirrorlist/?country=all&protocol=http&ip_version=4') diff --git a/mirrors/views/mirrorlist.py b/mirrors/views/mirrorlist.py index 35d59e8..45c0181 100644 --- a/mirrors/views/mirrorlist.py +++ b/mirrors/views/mirrorlist.py @@ -55,7 +55,8 @@ def as_div(self): @csrf_exempt def generate_mirrorlist(request): if request.method == 'POST' or len(request.GET) > 0: - form = MirrorlistForm(data=request.POST) + data = request.POST if request.method == 'POST' else request.GET + form = MirrorlistForm(data=data) if form.is_valid(): countries = form.cleaned_data['country'] protocols = form.cleaned_data['protocol']