307 lines
10 KiB
Python
307 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Database for tracking file expiry in Sharey
|
|
Simple SQLite database that stores file paths with their expiry times, plus URL-shortener redirects (short code -> target URL)
|
|
"""
|
|
import sqlite3
|
|
import os
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
class ExpiryDatabase:
    """Manages file expiry tracking database.

    Backed by a single SQLite file holding two tables:

    * ``file_expiry`` -- one row per tracked file path with its expiry time
      and the storage backend it lives on.
    * ``url_redirects`` -- URL-shortener entries (short code -> target URL)
      with an optional expiry, a click counter and an active flag.

    Every method opens its own short-lived connection and always closes it,
    even when a query fails. Apart from ``init_database`` (which lets errors
    propagate), methods report failures by printing a message and returning
    a falsy value (``False``, ``None``, ``[]`` or ``{}``).

    Timestamps generated here are TEXT of the form
    ``YYYY-MM-DDTHH:MM:SS[.ffffff]Z`` and expiry checks compare them as
    strings. NOTE(review): that comparison is only correct if callers pass
    ``expires_at`` in the same UTC 'Z' format -- confirm at call sites.
    """

    def __init__(self, db_path="expiry.db"):
        """Remember the database location and make sure its tables exist.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        self.db_path = db_path
        self.init_database()

    def _connect(self):
        """Open a connection to ``self.db_path`` that a ``with`` block closes.

        ``sqlite3.Connection`` used as a context manager only commits or rolls
        back; it never closes. Wrapping it in ``contextlib.closing`` guarantees
        the connection is released on every exit path, including exceptions.
        Callers still commit explicitly; anything uncommitted is discarded on
        close.
        """
        from contextlib import closing  # local import: file imports untouched
        return closing(sqlite3.connect(self.db_path))

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as ISO-8601 text with a ``Z`` suffix.

        Produces exactly what the deprecated
        ``datetime.utcnow().isoformat() + 'Z'`` produced (naive ISO form,
        microseconds included when non-zero), so new values stay comparable
        with rows written by earlier versions.
        """
        from datetime import timezone
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + 'Z'

    def init_database(self):
        """Initialize the expiry database with required tables.

        Creates ``file_expiry`` and ``url_redirects`` if they do not exist;
        existing tables are left as they are. Errors propagate to the caller,
        but the connection is closed either way.
        """
        with self._connect() as conn:
            cursor = conn.cursor()

            # Create expiry table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS file_expiry (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    file_path TEXT UNIQUE NOT NULL,
                    expires_at TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    storage_backend TEXT NOT NULL
                )
            ''')

            # Create URL shortener table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS url_redirects (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    short_code TEXT UNIQUE NOT NULL,
                    target_url TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    expires_at TEXT,
                    click_count INTEGER DEFAULT 0,
                    is_active INTEGER DEFAULT 1,
                    created_by_ip TEXT,
                    notes TEXT
                )
            ''')

            conn.commit()

        print(f"✅ Expiry database initialized: {self.db_path}")

    def add_file(self, file_path: str, expires_at: str, storage_backend: str = "b2"):
        """Add a file with expiry time to the database.

        Uses ``INSERT OR REPLACE``: an existing row for the same ``file_path``
        is replaced, which resets its ``expires_at``, ``created_at`` and
        ``storage_backend``.

        Args:
            file_path: Path/key of the stored file (unique in the table).
            expires_at: Expiry timestamp text; see the class note on format.
            storage_backend: Label of the backend holding the file.

        Returns:
            True on success, False if the write failed (error is printed).
        """
        try:
            created_at = self._utc_now_iso()
            with self._connect() as conn:
                conn.execute('''
                    INSERT OR REPLACE INTO file_expiry
                    (file_path, expires_at, created_at, storage_backend)
                    VALUES (?, ?, ?, ?)
                ''', (file_path, expires_at, created_at, storage_backend))
                conn.commit()

            print(f"📝 Added expiry tracking: {file_path} expires at {expires_at}")
            return True
        except Exception as e:
            print(f"❌ Failed to add expiry tracking for {file_path}: {e}")
            return False

    def get_expired_files(self) -> list:
        """Get list of files that have expired.

        A file counts as expired when its ``expires_at`` text sorts at or
        before the current UTC timestamp.

        Returns:
            Dicts with ``file_path``, ``expires_at`` and ``storage_backend``,
            oldest expiry first; ``[]`` on error (error is printed).
        """
        try:
            now = self._utc_now_iso()
            with self._connect() as conn:
                rows = conn.execute('''
                    SELECT file_path, expires_at, storage_backend
                    FROM file_expiry
                    WHERE expires_at <= ?
                    ORDER BY expires_at ASC
                ''', (now,)).fetchall()

            return [
                {
                    'file_path': path,
                    'expires_at': expires,
                    'storage_backend': backend,
                }
                for path, expires, backend in rows
            ]
        except Exception as e:
            print(f"❌ Failed to get expired files: {e}")
            return []

    def remove_file(self, file_path: str):
        """Remove a file from expiry tracking (after deletion).

        Returns:
            True if the DELETE ran (even when no row matched), False if it
            failed (error is printed).
        """
        try:
            with self._connect() as conn:
                conn.execute('DELETE FROM file_expiry WHERE file_path = ?', (file_path,))
                conn.commit()

            print(f"🗑️ Removed from expiry tracking: {file_path}")
            return True
        except Exception as e:
            print(f"❌ Failed to remove expiry tracking for {file_path}: {e}")
            return False

    def get_all_files(self) -> list:
        """Get all files in expiry tracking.

        Returns:
            Dicts with ``file_path``, ``expires_at``, ``created_at`` and
            ``storage_backend``, ordered by expiry; ``[]`` on error.
        """
        try:
            with self._connect() as conn:
                rows = conn.execute('''
                    SELECT file_path, expires_at, created_at, storage_backend
                    FROM file_expiry
                    ORDER BY expires_at ASC
                ''').fetchall()

            return [
                {
                    'file_path': path,
                    'expires_at': expires,
                    'created_at': created,
                    'storage_backend': backend,
                }
                for path, expires, created, backend in rows
            ]
        except Exception as e:
            print(f"❌ Failed to get all files: {e}")
            return []

    def get_stats(self) -> dict:
        """Get database statistics.

        Returns:
            ``{'total_files', 'expired_files', 'active_files', 'by_backend'}``,
            where ``by_backend`` maps backend label -> file count; ``{}`` on
            error (error is printed).
        """
        try:
            now = self._utc_now_iso()
            with self._connect() as conn:
                cursor = conn.cursor()

                # Total files
                cursor.execute('SELECT COUNT(*) FROM file_expiry')
                total_files = cursor.fetchone()[0]

                # Expired files (same criterion as get_expired_files)
                cursor.execute('SELECT COUNT(*) FROM file_expiry WHERE expires_at <= ?', (now,))
                expired_files = cursor.fetchone()[0]

                # Files by storage backend
                cursor.execute('SELECT storage_backend, COUNT(*) FROM file_expiry GROUP BY storage_backend')
                by_backend = dict(cursor.fetchall())

            return {
                'total_files': total_files,
                'expired_files': expired_files,
                'active_files': total_files - expired_files,
                'by_backend': by_backend,
            }
        except Exception as e:
            print(f"❌ Failed to get stats: {e}")
            return {}

    # URL Shortener Methods

    def add_redirect(self, short_code, target_url, expires_at=None, created_by_ip=None, notes=None):
        """Add a URL redirect.

        Args:
            short_code: Unique code for the short URL.
            target_url: Destination URL.
            expires_at: Optional expiry timestamp text (``None`` = never).
            created_by_ip: Optional creator IP, kept for admin purposes.
            notes: Optional free-form notes.

        Returns:
            True on success; False if ``short_code`` already exists or the
            insert failed (a message is printed in both cases).
        """
        try:
            created_at = self._utc_now_iso()
            with self._connect() as conn:
                conn.execute('''
                    INSERT INTO url_redirects
                    (short_code, target_url, created_at, expires_at, created_by_ip, notes)
                    VALUES (?, ?, ?, ?, ?, ?)
                ''', (short_code, target_url, created_at, expires_at, created_by_ip, notes))
                conn.commit()

            print(f"✅ Added redirect: {short_code} -> {target_url}")
            return True
        except sqlite3.IntegrityError:
            # UNIQUE(short_code) violated; the connection is already closed here.
            print(f"❌ Short code already exists: {short_code}")
            return False
        except Exception as e:
            print(f"❌ Failed to add redirect: {e}")
            return False

    def get_redirect(self, short_code):
        """Get redirect target URL and increment click count.

        Only active, non-expired redirects resolve. The click counter is
        incremented only when a target is actually returned.

        Returns:
            The target URL, or ``None`` if the code is unknown, disabled,
            expired, or a database error occurred (error is printed).
        """
        try:
            with self._connect() as conn:
                # Check if redirect exists and is active
                row = conn.execute('''
                    SELECT target_url, expires_at
                    FROM url_redirects
                    WHERE short_code = ? AND is_active = 1
                ''', (short_code,)).fetchone()

                if row is None:
                    return None

                target_url, expires_at = row

                # A NULL/empty expires_at means the redirect never expires.
                if expires_at and expires_at <= self._utc_now_iso():
                    return None

                # Increment click count
                conn.execute('''
                    UPDATE url_redirects
                    SET click_count = click_count + 1
                    WHERE short_code = ?
                ''', (short_code,))
                conn.commit()

            return target_url
        except Exception as e:
            print(f"❌ Failed to get redirect: {e}")
            return None

    def disable_redirect(self, short_code):
        """Disable a redirect (for abuse/takedown).

        Sets ``is_active = 0``; the row is kept so it still appears in
        ``list_redirects``.

        Returns:
            True if a row matched ``short_code``, False otherwise or on error.
        """
        try:
            with self._connect() as conn:
                cursor = conn.execute('''
                    UPDATE url_redirects
                    SET is_active = 0
                    WHERE short_code = ?
                ''', (short_code,))
                rows_affected = cursor.rowcount
                conn.commit()

            if rows_affected > 0:
                print(f"✅ Disabled redirect: {short_code}")
                return True
            print(f"❌ Redirect not found: {short_code}")
            return False
        except Exception as e:
            print(f"❌ Failed to disable redirect: {e}")
            return False

    def list_redirects(self, limit=100):
        """List recent redirects for admin purposes.

        Args:
            limit: Maximum number of rows, newest ``created_at`` first.

        Returns:
            Dicts with ``short_code``, ``target_url``, ``created_at``,
            ``expires_at``, ``click_count``, ``is_active`` (bool) and
            ``created_by_ip``; ``[]`` on error (error is printed).
        """
        try:
            with self._connect() as conn:
                rows = conn.execute('''
                    SELECT short_code, target_url, created_at, expires_at,
                           click_count, is_active, created_by_ip
                    FROM url_redirects
                    ORDER BY created_at DESC
                    LIMIT ?
                ''', (limit,)).fetchall()

            return [
                {
                    'short_code': code,
                    'target_url': target,
                    'created_at': created,
                    'expires_at': expires,
                    'click_count': clicks,
                    'is_active': bool(active),
                    'created_by_ip': ip,
                }
                for code, target, created, expires, clicks, active, ip in rows
            ]
        except Exception as e:
            print(f"❌ Failed to list redirects: {e}")
            return []
|