Files
sharey/expiry_db.py
2025-09-27 17:45:52 +01:00

307 lines
10 KiB
Python

#!/usr/bin/env python3
"""
Database for tracking file expiry in Sharey
Simple SQLite database to store file paths and their expiry times
"""
import os
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
class ExpiryDatabase:
    """Manages file-expiry tracking and URL-shortener redirects in SQLite.

    Every public method opens its own short-lived connection to ``db_path``,
    so an instance holds no open database handles between calls.

    Timestamps are stored as ISO-8601 UTC strings ending in ``Z`` and are
    compared *as text* in SQL (e.g. ``expires_at <= ?``). That comparison is
    only correct when all stored timestamps use the same format, so callers
    should pass ``expires_at`` values formatted like ``_utc_now()``.
    """

    def __init__(self, db_path="expiry.db"):
        """Create the manager and ensure the schema exists.

        Args:
            db_path: Path of the SQLite database file (created if missing).
        """
        self.db_path = db_path
        self.init_database()

    @staticmethod
    def _utc_now() -> str:
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SS.ffffffZ``.

        The fractional part is always present: ``datetime.isoformat()`` omits
        it when the microsecond field is zero, which would break the textual
        ordering the expiry queries rely on.
        """
        return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")

    @contextmanager
    def _connect(self):
        """Yield a connection that commits on success, rolls back on error,
        and is always closed.

        Using a ``sqlite3.Connection`` as a context manager only manages the
        transaction; it does not close the connection, hence the ``finally``.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def init_database(self):
        """Create the ``file_expiry`` and ``url_redirects`` tables if absent.

        Database errors propagate to the caller; the connection is closed
        either way.
        """
        with self._connect() as conn:
            # Create expiry table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS file_expiry (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    file_path TEXT UNIQUE NOT NULL,
                    expires_at TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    storage_backend TEXT NOT NULL
                )
            ''')
            # Create URL shortener table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS url_redirects (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    short_code TEXT UNIQUE NOT NULL,
                    target_url TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    expires_at TEXT,
                    click_count INTEGER DEFAULT 0,
                    is_active INTEGER DEFAULT 1,
                    created_by_ip TEXT,
                    notes TEXT
                )
            ''')
        print(f"✅ Expiry database initialized: {self.db_path}")

    def add_file(self, file_path: str, expires_at: str, storage_backend: str = "b2") -> bool:
        """Track ``file_path`` as expiring at ``expires_at``.

        An existing row for the same path is replaced (``INSERT OR REPLACE``
        on the UNIQUE ``file_path`` column).

        Args:
            file_path: Path/key of the stored file.
            expires_at: Expiry timestamp; compared as text, so it should use
                the same ISO-8601 ``Z`` format as ``_utc_now()``.
            storage_backend: Backend label stored with the row.

        Returns:
            True on success, False if the database write failed.
        """
        try:
            with self._connect() as conn:
                conn.execute('''
                    INSERT OR REPLACE INTO file_expiry
                    (file_path, expires_at, created_at, storage_backend)
                    VALUES (?, ?, ?, ?)
                ''', (file_path, expires_at, self._utc_now(), storage_backend))
        except Exception as e:
            print(f"❌ Failed to add expiry tracking for {file_path}: {e}")
            return False
        print(f"📝 Added expiry tracking: {file_path} expires at {expires_at}")
        return True

    def get_expired_files(self) -> list:
        """Return tracked files whose ``expires_at`` is at or before now.

        Returns:
            A list of dicts with keys ``file_path``, ``expires_at`` and
            ``storage_backend``, oldest expiry first; an empty list on error.
        """
        try:
            with self._connect() as conn:
                rows = conn.execute('''
                    SELECT file_path, expires_at, storage_backend
                    FROM file_expiry
                    WHERE expires_at <= ?
                    ORDER BY expires_at ASC
                ''', (self._utc_now(),)).fetchall()
        except Exception as e:
            print(f"❌ Failed to get expired files: {e}")
            return []
        return [
            {'file_path': path, 'expires_at': expires, 'storage_backend': backend}
            for path, expires, backend in rows
        ]

    def remove_file(self, file_path: str) -> bool:
        """Stop tracking ``file_path`` (typically after it has been deleted).

        Returns:
            True if the DELETE ran (even if no row matched), False on error.
        """
        try:
            with self._connect() as conn:
                conn.execute('DELETE FROM file_expiry WHERE file_path = ?', (file_path,))
        except Exception as e:
            print(f"❌ Failed to remove expiry tracking for {file_path}: {e}")
            return False
        print(f"🗑️ Removed from expiry tracking: {file_path}")
        return True

    def get_all_files(self) -> list:
        """Return every tracked file, ordered by ``expires_at`` ascending.

        Returns:
            A list of dicts with keys ``file_path``, ``expires_at``,
            ``created_at`` and ``storage_backend``; an empty list on error.
        """
        try:
            with self._connect() as conn:
                rows = conn.execute('''
                    SELECT file_path, expires_at, created_at, storage_backend
                    FROM file_expiry
                    ORDER BY expires_at ASC
                ''').fetchall()
        except Exception as e:
            print(f"❌ Failed to get all files: {e}")
            return []
        return [
            {
                'file_path': path,
                'expires_at': expires,
                'created_at': created,
                'storage_backend': backend,
            }
            for path, expires, created, backend in rows
        ]

    def get_stats(self) -> dict:
        """Summarise the expiry table.

        Returns:
            A dict with ``total_files``, ``expired_files``, ``active_files``
            and ``by_backend`` (backend label -> row count); an empty dict on
            error.
        """
        try:
            with self._connect() as conn:
                total_files = conn.execute('SELECT COUNT(*) FROM file_expiry').fetchone()[0]
                expired_files = conn.execute(
                    'SELECT COUNT(*) FROM file_expiry WHERE expires_at <= ?',
                    (self._utc_now(),),
                ).fetchone()[0]
                by_backend = dict(conn.execute(
                    'SELECT storage_backend, COUNT(*) FROM file_expiry GROUP BY storage_backend'
                ).fetchall())
        except Exception as e:
            print(f"❌ Failed to get stats: {e}")
            return {}
        return {
            'total_files': total_files,
            'expired_files': expired_files,
            'active_files': total_files - expired_files,
            'by_backend': by_backend,
        }

    # URL Shortener Methods
    def add_redirect(self, short_code, target_url, expires_at=None, created_by_ip=None, notes=None) -> bool:
        """Register ``short_code`` as a redirect to ``target_url``.

        Args:
            short_code: Unique short code; duplicates are rejected.
            target_url: Destination URL.
            expires_at: Optional expiry timestamp (same text format as
                ``_utc_now()``); ``None`` means the redirect never expires.
            created_by_ip: Optional creator IP, stored for auditing.
            notes: Optional free-form notes.

        Returns:
            True on success, False if the code already exists or on error.
        """
        try:
            with self._connect() as conn:
                conn.execute('''
                    INSERT INTO url_redirects
                    (short_code, target_url, created_at, expires_at, created_by_ip, notes)
                    VALUES (?, ?, ?, ?, ?, ?)
                ''', (short_code, target_url, self._utc_now(), expires_at, created_by_ip, notes))
        except sqlite3.IntegrityError:
            # UNIQUE constraint on short_code.
            print(f"❌ Short code already exists: {short_code}")
            return False
        except Exception as e:
            print(f"❌ Failed to add redirect: {e}")
            return False
        print(f"✅ Added redirect: {short_code} -> {target_url}")
        return True

    def get_redirect(self, short_code):
        """Resolve ``short_code`` and increment its click count.

        Returns:
            The target URL, or None if the code is unknown, disabled,
            expired, or a database error occurred.
        """
        try:
            with self._connect() as conn:
                row = conn.execute('''
                    SELECT target_url, expires_at
                    FROM url_redirects
                    WHERE short_code = ? AND is_active = 1
                ''', (short_code,)).fetchone()
                if row is None:
                    return None
                target_url, expires_at = row
                # A NULL/empty expires_at means "never expires".
                if expires_at and expires_at <= self._utc_now():
                    return None
                conn.execute('''
                    UPDATE url_redirects
                    SET click_count = click_count + 1
                    WHERE short_code = ?
                ''', (short_code,))
                return target_url
        except Exception as e:
            print(f"❌ Failed to get redirect: {e}")
            return None

    def disable_redirect(self, short_code) -> bool:
        """Deactivate a redirect (for abuse/takedown) without deleting it.

        Returns:
            True if a row with ``short_code`` was updated, False if none
            matched or on error.
        """
        try:
            with self._connect() as conn:
                rows_affected = conn.execute('''
                    UPDATE url_redirects
                    SET is_active = 0
                    WHERE short_code = ?
                ''', (short_code,)).rowcount
        except Exception as e:
            print(f"❌ Failed to disable redirect: {e}")
            return False
        if rows_affected > 0:
            print(f"✅ Disabled redirect: {short_code}")
            return True
        print(f"❌ Redirect not found: {short_code}")
        return False

    def list_redirects(self, limit=100) -> list:
        """List the most recently created redirects for admin purposes.

        Args:
            limit: Maximum number of rows to return.

        Returns:
            A list of dicts (newest first) with keys ``short_code``,
            ``target_url``, ``created_at``, ``expires_at``, ``click_count``,
            ``is_active`` (bool) and ``created_by_ip``; an empty list on error.
        """
        try:
            with self._connect() as conn:
                rows = conn.execute('''
                    SELECT short_code, target_url, created_at, expires_at,
                           click_count, is_active, created_by_ip
                    FROM url_redirects
                    ORDER BY created_at DESC
                    LIMIT ?
                ''', (limit,)).fetchall()
        except Exception as e:
            print(f"❌ Failed to list redirects: {e}")
            return []
        return [
            {
                'short_code': code,
                'target_url': url,
                'created_at': created,
                'expires_at': expires,
                'click_count': clicks,
                'is_active': bool(active),
                'created_by_ip': ip,
            }
            for code, url, created, expires, clicks, active, ip in rows
        ]