Federation improvements (#849)

* add cli option to reset state for remote domains/users

* prevent access to blocked domain or account

* add missing migrations

---------

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
Henri Dickson 2025-01-16 01:09:27 -05:00 committed by GitHub
parent ce37c25abb
commit 2f35931213
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 107 additions and 4 deletions

View file

@ -28,7 +28,7 @@ def current_user_relationship(context, target_identity: "APIdentity"):
"rejecting": False, "rejecting": False,
"status": "", "status": "",
} }
if target_identity and current_identity: if target_identity and current_identity and not target_identity.restricted:
if current_identity != target_identity: if current_identity != target_identity:
if current_identity.is_blocking( if current_identity.is_blocking(
target_identity target_identity

View file

@ -34,6 +34,11 @@ class Migration(migrations.Migration):
), ),
("state", models.CharField(default="outdated", max_length=100)), ("state", models.CharField(default="outdated", max_length=100)),
("state_changed", models.DateTimeField(auto_now_add=True)), ("state_changed", models.DateTimeField(auto_now_add=True)),
("state_next_attempt", models.DateTimeField(blank=True, null=True)),
(
"state_locked_until",
models.DateTimeField(blank=True, db_index=True, null=True),
),
("nodeinfo", models.JSONField(blank=True, null=True)), ("nodeinfo", models.JSONField(blank=True, null=True)),
("local", models.BooleanField()), ("local", models.BooleanField()),
("blocked", models.BooleanField(default=False)), ("blocked", models.BooleanField(default=False)),

View file

@ -279,6 +279,8 @@ class Domain(models.Model):
# state = StateField(DomainStates) # state = StateField(DomainStates)
state = models.CharField(max_length=100, default="outdated") state = models.CharField(max_length=100, default="outdated")
state_changed = models.DateTimeField(auto_now_add=True) state_changed = models.DateTimeField(auto_now_add=True)
state_next_attempt = models.DateTimeField(blank=True, null=True)
state_locked_until = models.DateTimeField(null=True, blank=True, db_index=True)
# nodeinfo 2.0 detail about the remote server # nodeinfo 2.0 detail about the remote server
nodeinfo = models.JSONField(null=True, blank=True) nodeinfo = models.JSONField(null=True, blank=True)
@ -352,6 +354,24 @@ class Domain(models.Model):
def __str__(self): def __str__(self):
return self.domain return self.domain
def recursively_blocked(self) -> bool:
"""
Checks for blocks on all right subsets of this domain, except the very
last part of the TLD.
Yes, I know this weirdly lets you block ".co.uk" or whatever, but
people can do that if they want I guess.
"""
# Efficient short-circuit
if self.blocked:
return True
# Build domain list
domain_parts = [self.domain]
while "." in domain_parts[-1]:
domain_parts.append(domain_parts[-1].split(".", 1)[1])
# See if any of those are blocked
return Domain.objects.filter(domain__in=domain_parts, blocked=True).exists()
def upload_store(): def upload_store():
return FileSystemStorage( return FileSystemStorage(

View file

@ -141,6 +141,10 @@ class Takahe:
@staticmethod @staticmethod
def fetch_remote_identity(handler: str) -> int | None: def fetch_remote_identity(handler: str) -> int | None:
d = handler.split("@")[-1]
domain = Domain.objects.filter(domain=d).first()
if domain and domain.recursively_blocked:
return
InboxMessage.create_internal({"type": "FetchIdentity", "handle": handler}) InboxMessage.create_internal({"type": "FetchIdentity", "handle": handler})
@staticmethod @staticmethod
@ -670,7 +674,9 @@ class Takahe:
return FediverseHtmlParser(linebreaks_filter(txt)).html return FediverseHtmlParser(linebreaks_filter(txt)).html
@staticmethod @staticmethod
def update_state(obj: Post | PostInteraction | Relay | Identity, state: str): def update_state(
obj: Post | PostInteraction | Relay | Identity | Domain, state: str
):
obj.state = state obj.state = state
obj.state_changed = timezone.now() obj.state_changed = timezone.now()
obj.state_next_attempt = None obj.state_next_attempt = None
@ -696,6 +702,7 @@ class Takahe:
nodeinfo__protocols__contains="neodb", nodeinfo__protocols__contains="neodb",
nodeinfo__metadata__nodeEnvironment="production", nodeinfo__metadata__nodeEnvironment="production",
local=False, local=False,
blocked=False,
).values_list("pk", flat=True) ).values_list("pk", flat=True)
) )
cache.set(cache_key, peers, timeout=1800) cache.set(cache_key, peers, timeout=1800)

View file

@ -1,7 +1,10 @@
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from tqdm import tqdm from tqdm import tqdm
import httpx
from users.models import Preference, User from users.models import Preference, User
from takahe.models import Identity, Domain
from takahe.utils import Takahe
class Command(BaseCommand): class Command(BaseCommand):
@ -16,6 +19,11 @@ class Command(BaseCommand):
action="store_true", action="store_true",
help="check and fix integrity for missing data for user models", help="check and fix integrity for missing data for user models",
) )
parser.add_argument(
"--remote",
action="store_true",
help="reset state for remote domains/users with previous connection issues",
)
parser.add_argument( parser.add_argument(
"--super", action="store", nargs="*", help="list or toggle superuser" "--super", action="store", nargs="*", help="list or toggle superuser"
) )
@ -32,6 +40,8 @@ class Command(BaseCommand):
self.list(self.users) self.list(self.users)
if options["integrity"]: if options["integrity"]:
self.integrity() self.integrity()
if options["remote"]:
self.check_remote()
if options["super"] is not None: if options["super"] is not None:
self.superuser(options["super"]) self.superuser(options["super"])
if options["staff"] is not None: if options["staff"] is not None:
@ -50,6 +60,7 @@ class Command(BaseCommand):
def integrity(self): def integrity(self):
count = 0 count = 0
self.stdout.write("Checking local users")
for user in tqdm(User.objects.filter(is_active=True)): for user in tqdm(User.objects.filter(is_active=True)):
i = user.identity.takahe_identity i = user.identity.takahe_identity
if i.public_key is None: if i.public_key is None:
@ -64,7 +75,60 @@ class Command(BaseCommand):
if self.fix: if self.fix:
Preference.objects.create(user=user) Preference.objects.create(user=user)
count += 1 count += 1
self.stdout.write(f"{count} issues")
def check_remote(self):
headers = {
"Accept": "application/json,application/activity+json,application/ld+json"
}
with httpx.Client(timeout=0.5) as client:
count = 0
self.stdout.write("Checking remote domains")
for d in tqdm(
Domain.objects.filter(
local=False, blocked=False, state="connection_issue"
)
):
try:
response = client.get(
f"https://{d.domain}/.well-known/nodeinfo",
follow_redirects=True,
headers=headers,
)
if response.status_code == 200 and "json" in response.headers.get(
"content-type", ""
):
count += 1
if self.fix:
Takahe.update_state(d, "outdated")
except Exception:
pass
self.stdout.write(f"{count} issues")
count = 0
self.stdout.write("Checking remote identities")
for i in tqdm(
Identity.objects.filter(
public_key__isnull=True,
local=False,
restriction=0,
state="connection_issue",
)
):
try:
response = client.request(
"get",
i.actor_uri,
headers=headers,
follow_redirects=True,
)
if (
response.status_code == 200
and "json" in response.headers.get("content-type", "")
and "@context" in response.text
):
Takahe.update_state(i, "outdated")
except Exception:
pass
self.stdout.write(f"{count} issues")
def superuser(self, v): def superuser(self, v):
if v == []: if v == []:

View file

@ -562,9 +562,12 @@ class Migration(migrations.Migration):
field=models.CharField( field=models.CharField(
choices=[ choices=[
("en", "English"), ("en", "English"),
("da", "Danish"),
("de", "German"),
("fr", "French"),
("it", "Italian"),
("zh-hans", "Simplified Chinese"), ("zh-hans", "Simplified Chinese"),
("zh-hant", "Traditional Chinese"), ("zh-hant", "Traditional Chinese"),
("da", "Danish"),
], ],
default="en", default="en",
max_length=10, max_length=10,

View file

@ -109,6 +109,10 @@ class APIdentity(models.Model):
else: else:
return f"{self.username}@{self.domain_name}" return f"{self.username}@{self.domain_name}"
@property
def restricted(self):
return self.takahe_identity.restriction != 2
@property @property
def following(self): def following(self):
return Takahe.get_following_ids(self.pk) return Takahe.get_following_ids(self.pk)