Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Identity Manager for DNSSEC Awesomeness #513

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
34 changes: 20 additions & 14 deletions api/desecapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,10 @@ def filter_qname(self, qname: str, **kwargs) -> models.query.QuerySet:
).filter(dotted_qname__endswith=F('dotted_name'), **kwargs)

def most_specific_zone(self, fqdn: str) -> Tuple[Domain, str]:
domain = self.filter_qname(fqdn).order_by('-name_length').first()
try:
domain = self.filter_qname(fqdn).order_by('-name_length')[0]
except IndexError:
raise Domain.DoesNotExist
subname = fqdn[:-len(domain.name)].rstrip('.')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.removesuffix()

return domain, subname

Expand Down Expand Up @@ -999,7 +1002,7 @@ class Identity(models.Model):
created = models.DateTimeField(auto_now_add=True)
owner = models.ForeignKey(User, on_delete=models.PROTECT, related_name='identities')
default_ttl = models.PositiveIntegerField(default=300)
rrs = models.ManyToManyField(to=RR)
rrs = models.ManyToManyField(to=RR, related_name='identities')
scheduled_removal = models.DateTimeField(null=True)

class Meta:
Expand All @@ -1010,14 +1013,14 @@ def get_rrs(self) -> List[RR]:

def save(self, *args, **kwargs):
for rr in self.get_rrs():
self.rrs.add(rr)
rr.rrset.save()
rr.save()
self.rrs.add(rr)
return super().save(*args, **kwargs)

def delete(self, using=None, keep_parents=False):
for rr in self.rrs.all(): # TODO use one query
if len(rr.identities) == 1:
if len(rr.identities.all()) == 1:
rr.delete()
return super().delete(using, keep_parents)

Expand Down Expand Up @@ -1076,7 +1079,7 @@ def __init__(self, *args, **kwargs):
if 'not_valid_after' not in kwargs:
self.scheduled_removal = self.not_valid_after

def get_record_contents(self) -> List[str]:
def get_record_content(self) -> str:
# choose hash function
if self.tlsa_matching_type == self.MatchingType.SHA256:
hash_function = hazmat.primitives.hashes.SHA256()
Expand All @@ -1100,7 +1103,7 @@ def get_record_contents(self) -> List[str]:
hash = h.finalize().hex()

# create TLSA record content
return [f"{self.tlsa_certificate_usage} {self.tlsa_selector} {self.tlsa_matching_type} {hash}"]
return f"{self.tlsa_certificate_usage} {self.tlsa_selector} {self.tlsa_matching_type} {hash}"

@property
def _cert(self) -> x509.Certificate:
Expand Down Expand Up @@ -1145,14 +1148,17 @@ def subject_names_clean(self) -> Set[str]:
return clean

def get_rrs(self) -> List[RR]:
return [
self.get_or_create_rr(
fqdn=f"_{self.port:n}._{self.protocol}.{qname}",
content=content,
)
for qname in self.subject_names_clean
for content in self.get_record_contents()
]
rrs = []
content = self.get_record_content()
for qname in self.subject_names_clean:
try:
rrs.append(self.get_or_create_rr(
fqdn=f"_{self.port:n}._{self.protocol}.{qname}",
content=content,
))
except Domain.DoesNotExist:
pass
return rrs

@property
def not_valid_before(self):
Expand Down
40 changes: 35 additions & 5 deletions api/desecapi/tests/test_identities.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ def test_generated_rrs_many_rrsets(self):

id = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user, protocol=models.TLSIdentity.Protocol.SCTP)

self.assertEqual(
id.domains_subnames(),
{(domain, '_443._sctp'), (domain, '_443._sctp.desec'), (domain, '_443._sctp.dedyn')},
)
self.assertEqual(id.subject_names, SUBJECT_NAMES)

rrs = id.get_rrs()
self.assertEqual(len(rrs), 3)
Expand All @@ -69,7 +66,6 @@ def test_generated_rrs_one_rrset(self):
domain.save()

id = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user, port=123)
self.assertEqual(id.domains_subnames(), {(domain, '_123._tcp')})

rrs = id.get_rrs()
self.assertEqual(len(rrs), 1)
Expand Down Expand Up @@ -115,3 +111,37 @@ def test_create_delete_rrs(self):
id.delete()
rrset = models.RRset.objects.get(domain__name='desec.example.dedyn.io', type='TLSA', subname='_123._tcp')
self.assertEqual(len(rrset.records.all()), 1)

def test_duplicate_record(self):
def count_tlsa_records():
return models.RRset.objects.get(
domain__name='desec.example.dedyn.io',
type='TLSA', subname='_443._tcp'
).records.count()

domain = models.Domain(name='desec.example.dedyn.io', owner=self.user)
domain.save()

# insert first cert, insert second, delete first, delete second
id1 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user)
id2 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user)
id1.save()
self.assertEqual(count_tlsa_records(), 1)
id2.save()
self.assertEqual(count_tlsa_records(), 1)
id1.delete()
self.assertEqual(count_tlsa_records(), 1)
id2.delete()
self.assertEqual(count_tlsa_records(), 0)

# insert first cert, insert second, delete second, delete first
id1 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user)
id2 = models.TLSIdentity(certificate=CERTIFICATE, owner=self.user)
id1.save()
self.assertEqual(count_tlsa_records(), 1)
id2.save()
self.assertEqual(count_tlsa_records(), 1)
id2.delete()
self.assertEqual(count_tlsa_records(), 1)
id1.delete()
self.assertEqual(count_tlsa_records(), 0)