#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
import sys
import json
import time
import logging
import argparse
import threading
import subprocess
from stat import S_ISREG
from pyalpm import Handle
from pgpy import PGPMessage, PGPKey
from qcloud_cos import CosS3Client, CosConfig, CosClientError, CosServiceError
from tencentcloud.common.credential import Credential
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.cdn.v20180606.cdn_client import CdnClient
from tencentcloud.cdn.v20180606.models import PurgeUrlsCacheRequest, PushUrlsCacheRequest
from http.server import ThreadingHTTPServer, SimpleHTTPRequestHandler
from http import HTTPStatus, HTTPMethod
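
# Process-wide shared state: parsed configuration, allow-listed public keys
# keyed by fingerprint, the repo signing key, Tencent Cloud CDN/COS clients,
# a pyalpm handle for reading package metadata, and a lock serializing all
# work on the upload directory.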
class RepoContext:
config: dict = {}
    gpg: dict[str, PGPKey] = {}
    signer: PGPKey | None = None
    cred: Credential | None = None
    cdn: CdnClient | None = None
    cos_client: CosS3Client | None = None
    cos_config: CosConfig | None = None
pyalpm: Handle = Handle("/", "/var/lib/pacman")
lock: threading.Lock = threading.Lock()
repo_data: RepoContext = RepoContext()
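
# HTTP front end: PUT uploads a signature or package into the staging dir,
# POST /api/update merges a staged package into the COS-hosted repo database,
# GET serves staged files and the status page, DELETE removes staged files.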
class HTTPRequestHandler(SimpleHTTPRequestHandler):
    def write_response(self, response: bytes, length: int, headers: dict, code: HTTPStatus = HTTPStatus.OK):
self.send_response(code)
if length > 0:
self.send_header("Content-Length", str(length))
for key, value in headers.items():
self.send_header(key, value)
self.end_headers()
if length > 0:
self.wfile.write(response)
self.wfile.flush()
    def write_text(self, response: str, code: HTTPStatus = HTTPStatus.OK):
        headers = {'Content-Type': "text/plain"}
        data = response.encode("UTF-8")
        self.write_response(data, len(data), headers, code)
    def write_json(self, response, code: HTTPStatus = HTTPStatus.OK):
        headers = {'Content-Type': "application/json"}
        data = json.dumps(response).encode("UTF-8")
        self.write_response(data, len(data), headers, code)
    def write_html(self, response: str, code: HTTPStatus = HTTPStatus.OK):
        headers = {'Content-Type': "text/html"}
        data = response.encode("UTF-8")
        self.write_response(data, len(data), headers, code)
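
    # Remove staged uploads older than ten minutes and report how much space
    # the remaining files use; doubles as the quota source for check_folder.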
@staticmethod
def cleanup_folder() -> tuple:
global repo_data
size = 0
count = 0
try:
now = time.time()
expires = now - 600
path = repo_data.config["upload"]
with repo_data.lock:
for f in os.listdir(path):
full = os.path.join(path, f)
st = os.stat(full)
if not S_ISREG(st.st_mode):
continue
if st.st_mtime < expires:
os.remove(full)
logging.info("cleanup %s" % f)
continue
size = size + st.st_size
count = count + 1
except BaseException as e:
logging.exception(e)
return size, count
def check_folder(self) -> bool:
global repo_data
size, count = self.cleanup_folder()
if count > repo_data.config["upload_files"] or size >= repo_data.config["upload_size"]:
self.write_text("Too many files in upload\n", HTTPStatus.TOO_MANY_REQUESTS)
return False
return True
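
    # Collect every allow-listed key whose key ID matches a signer of the
    # given PGP message; an empty result means the uploader is not trusted.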
@staticmethod
def check_msg_own(message: PGPMessage) -> list[PGPKey]:
found = []
for sign in message.signatures:
logging.info("signature signer id %s" % sign.signer)
            for key in repo_data.gpg.values():
                if sign.signer in key.fingerprint.keyid:
                    # prefer a human-readable user ID for the log line,
                    # falling back to the bare fingerprint
                    label = key.fingerprint
                    for u in key.userids:
                        label = u.userid
                    logging.info("found signer %s" % label)
                    found.append(key)
return found
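
    # Verify the detached signature(s) against the raw file bytes using each
    # candidate key; a single valid signature from a trusted key suffices.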
@staticmethod
def verify_package(data: bytes, message: PGPMessage, signers: list) -> bool:
success = False
for sign in message.signatures:
for signer in signers:
if signer.verify(data, sign):
logging.info("verify package signature %s successful" % sign.signer)
success = True
return success
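
    # COS transfers are retried up to ten times, since transient SDK errors
    # are common; a final failure raises so callers abort the whole update.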
@staticmethod
def try_download(key, dest, bucket=None):
global repo_data
success = False
if bucket is None:
bucket = repo_data.config["bucket"]
        for _ in range(10):
try:
repo_data.cos_client.download_file(
Bucket=bucket, Key=key, DestFilePath=dest
)
success = True
break
            except (CosClientError, CosServiceError) as e:
logging.exception(e)
if not success:
raise Exception("download %s failed" % key)
@staticmethod
def try_upload(key, src, bucket=None):
global repo_data
success = False
if bucket is None:
bucket = repo_data.config["bucket"]
        for _ in range(10):
try:
repo_data.cos_client.upload_file(
Bucket=bucket, Key=key, LocalFilePath=src
)
success = True
break
            except (CosClientError, CosServiceError) as e:
logging.exception(e)
if not success:
raise Exception("upload %s failed" % key)
@staticmethod
def do_copy(src, dest, bucket=None):
global repo_data
success = False
if bucket is None:
bucket = repo_data.config["bucket"]
try:
repo_data.cos_client.copy_object(
Bucket=bucket, Key=dest,
CopySource={
'Bucket': bucket, 'Key': src,
'Region': repo_data.config["region"]
}
)
success = True
except BaseException:
pass
return success
@staticmethod
def is_exists(key, bucket=None):
global repo_data
if bucket is None:
bucket = repo_data.config["bucket"]
try:
return repo_data.cos_client.object_exists(
Bucket=bucket, Key=key,
)
        except BaseException:
            return False
@staticmethod
def do_delete(key, bucket=None):
global repo_data
success = False
if bucket is None:
bucket = repo_data.config["bucket"]
try:
repo_data.cos_client.delete_object(Bucket=bucket, Key=key)
success = True
except BaseException:
pass
return success
def try_download_db(self, arch, file):
global repo_data
self.try_download(
"arch/%s/%s" % (arch, file),
"%s/%s" % (repo_data.config["workspace"], file)
)
def try_upload_db(self, arch, file):
global repo_data
self.try_upload(
"arch/%s/%s" % (arch, file),
"%s/%s" % (repo_data.config["workspace"], file)
)
def do_copy_db(self, arch, src, dest):
global repo_data
self.do_copy(
"arch/%s/%s" % (arch, src),
"arch/%s/%s" % (arch, dest),
)
@staticmethod
def format_size(size: int) -> str:
        units = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB']
unit = units[0]
for unit in units:
if size < 1024:
break
size /= 1024
return "{:.2f} {}".format(size, unit)
def proc_update_db(self, data):
global repo_data
if not isinstance(data, dict):
self.write_text("Request not json\n", HTTPStatus.BAD_REQUEST)
return
if 'arch' not in data or data["arch"] not in repo_data.config["arch"]:
self.write_text("Bad architecture\n", HTTPStatus.BAD_REQUEST)
return
if 'target' not in data or len(data["target"]) <= 0 or '/' in data["target"] or '\\' in data["target"]:
self.write_text("Bad filename\n", HTTPStatus.BAD_REQUEST)
return
now = time.time()
repo = repo_data.config["repo"]
work = repo_data.config["workspace"] + "/"
db = "%s.db" % repo
db_sig = "%s.sig" % db
db_tar = "%s.tar.gz" % db
db_tar_sig = "%s.sig" % db_tar
files = "%s.files" % repo
files_sig = "%s.sig" % files
files_tar = "%s.tar.gz" % files
files_tar_sig = "%s.sig" % files_tar
all_files = [db, db_sig, db_tar, db_tar_sig, files, files_sig, files_tar, files_tar_sig]
path = "%s/%s" % (repo_data.config["upload"], data["target"])
sign = "%s/%s.sig" % (repo_data.config["upload"], data["target"])
try:
with repo_data.lock:
if not os.path.exists(path) or not os.path.exists(sign):
self.write_text("Target not exists\n", HTTPStatus.GONE)
return
for file in all_files:
target = work + file
if os.path.exists(target):
os.remove(target)
if self.is_exists("arch/%s/%s" % (data["arch"], data["target"])) or \
self.is_exists("arch/%s/%s.sig" % (data["arch"], data["target"])):
self.write_text("Target already exists\n", HTTPStatus.CONFLICT)
return
logging.debug("verifying package")
try:
os.utime(sign, (now, now))
with open(sign, "rb") as f:
binary = f.read()
message = PGPMessage.from_blob(binary)
assert message
except BaseException as e:
logging.exception(e)
self.write_text("Bad signature\n", HTTPStatus.NOT_ACCEPTABLE)
return
found = self.check_msg_own(message)
if len(found) <= 0:
logging.info("package signer are not in allow list")
self.write_text("Package signer not in allow list\n", HTTPStatus.FORBIDDEN)
return
try:
os.utime(sign, (now, now))
with open(path, "rb") as f:
binary = f.read()
if len(binary) <= 0:
raise Exception("read data mismatch")
if not self.verify_package(binary, message, found):
logging.info("verify package signature failed")
self.write_text("Bad package signature\n", HTTPStatus.FORBIDDEN)
return
except BaseException as e:
logging.exception(e)
self.write_text("Verify package failed\n", HTTPStatus.NOT_ACCEPTABLE)
return
pkg = repo_data.pyalpm.load_pkg(path)
logging.info("package name: %s" % pkg.name)
logging.info("package version: %s" % pkg.version)
logging.info("package architecture: %s" % pkg.arch)
logging.info("package packager: %s" % pkg.packager)
logging.info("package size: %s" % self.format_size(pkg.size))
logging.info("package installed size: %s" % self.format_size(pkg.isize))
logging.info("package url: %s" % pkg.url)
name = "%s-%s-%s" % (pkg.name, pkg.version, pkg.arch)
if not any(name + ext == data["target"] for ext in repo_data.config["pkg_exts"]):
self.write_text("Bad package name\n", HTTPStatus.NOT_ACCEPTABLE)
return
if not any(pkg.packager in uid.userid for signer in found for uid in signer.userids):
self.write_text("Packager mismatch with PGP userid\n", HTTPStatus.NOT_ACCEPTABLE)
return
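                # the request arch was already validated above; what still
                # needs checking is the architecture recorded in the package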
if data["arch"] != "any" and data["arch"] not in repo_data.config["arch"]:
self.write_text("Target package architecture unsupported\n", HTTPStatus.NOT_ACCEPTABLE)
return
rst = repo_data.config["restrict_pkg"]
if pkg.name in rst and not any(rst[pkg.name] == fp.fingerprint for fp in found):
self.write_text("Target package name is in restricted list\n", HTTPStatus.FORBIDDEN)
return
logging.info("verify package done")
logging.debug("downloading database")
self.try_download_db(data["arch"], db_tar)
self.try_download_db(data["arch"], files_tar)
logging.info("downloaded database")
logging.debug("updating database")
subprocess.run(
["repo-add", work + db_tar, path],
timeout=30, check=True, cwd=work
)
logging.info("update database done")
logging.debug("signing database")
with repo_data.signer.unlock(repo_data.config["signer_passphrase"]):
with open(work + db_tar, "rb") as f:
cont = f.read()
msg = repo_data.signer.sign(cont)
with open(work + db_tar_sig, "wb") as w:
w.write(bytes(msg))
with open(work + files_tar, "rb") as f:
cont = f.read()
msg = repo_data.signer.sign(cont)
with open(work + files_tar_sig, "wb") as w:
w.write(bytes(msg))
logging.info("sign database done")
logging.debug("uploading package")
self.try_upload("arch/%s/%s" % (data["arch"], data["target"]), path)
os.utime(path, (now, now))
self.try_upload("arch/%s/%s.sig" % (data["arch"], data["target"]), sign)
os.utime(sign, (now, now))
logging.info("uploaded package")
logging.debug("removing old databases")
for file in all_files:
target = "arch/%s/%s" % (data["arch"], file)
target_old = "%s.old" % target
if self.is_exists(target):
if self.is_exists(target_old):
self.do_delete(target_old)
self.do_copy(target, target_old)
self.do_delete(target)
logging.info("removed old databases")
logging.debug("uploading database")
self.try_upload_db(data["arch"], db_tar)
self.try_upload_db(data["arch"], db_tar_sig)
self.try_upload_db(data["arch"], files_tar)
self.try_upload_db(data["arch"], files_tar_sig)
self.do_copy_db(data["arch"], db_tar, db)
self.do_copy_db(data["arch"], db_tar_sig, db_sig)
self.do_copy_db(data["arch"], files_tar, files)
self.do_copy_db(data["arch"], files_tar_sig, files_sig)
logging.info("uploaded database")
if "cdn" in repo_data.config:
logging.debug("purging cdn cache")
domain = repo_data.config["cdn"]
urls = [
"https://%s/arch/%s/%s" % (domain, data["arch"], data["target"]),
"https://%s/arch/%s/%s.sig" % (domain, data["arch"], data["target"]),
]
for file in all_files:
urls.append("https://%s/arch/%s/%s" % (domain, data["arch"], file))
for url in urls:
print("new url: %s" % url)
try:
req = PurgeUrlsCacheRequest()
req.Urls = urls
repo_data.cdn.PurgeUrlsCache(req)
except BaseException as e:
logging.exception(e)
try:
req = PushUrlsCacheRequest()
req.Urls = urls
repo_data.cdn.PushUrlsCache(req)
except BaseException as e:
logging.exception(e)
logging.info("purged cdn cache")
self.write_text("Database updated\n", HTTPStatus.OK)
except BaseException as e:
logging.exception(e)
self.write_text("Error while updating database\n", HTTPStatus.INTERNAL_SERVER_ERROR)
def proc_sign(self, length: int, filename: str):
global repo_data
path = "%s/%s" % (repo_data.config["upload"], filename)
m = repo_data.config["max_sign_file"]
if length > m:
self.write_text(
"Signature too large (maximum %s)\n" % self.format_size(m),
HTTPStatus.REQUEST_ENTITY_TOO_LARGE
)
return
try:
data = self.rfile.read(length)
assert data and len(data) == length
message = PGPMessage.from_blob(data)
assert message
except BaseException as e:
logging.exception(e)
self.write_text("Bad signature\n", HTTPStatus.NOT_ACCEPTABLE)
return
logging.info("process pgp %s" % filename)
if len(self.check_msg_own(message)) <= 0:
logging.info("all signer are not in allow list")
self.write_text("Target signer not in allow list\n", HTTPStatus.FORBIDDEN)
return
with repo_data.lock:
if os.path.exists(path):
os.remove(path)
with open(path, "wb") as f:
f.write(data)
logging.info("saved %s size %s" % (path, self.format_size(length)))
self.write_text("Signature saved\n", HTTPStatus.CREATED)
def proc_pkgs(self, length: int, filename: str):
global repo_data
now = time.time()
path = "%s/%s" % (repo_data.config["upload"], filename)
sign = "%s/%s.sig" % (repo_data.config["upload"], filename)
m = repo_data.config["max_pkg_file"]
if length >= m:
self.write_text(
"Package too large (maximum %s)\n" % self.format_size(m),
HTTPStatus.REQUEST_ENTITY_TOO_LARGE
)
return
with repo_data.lock:
if not os.path.exists(sign):
self.write_text("You need upload signature first\n", HTTPStatus.NOT_ACCEPTABLE)
return
try:
os.utime(sign, (now, now))
with open(sign, "rb") as f:
data = f.read()
message = PGPMessage.from_blob(data)
assert message
except BaseException as e:
logging.exception(e)
self.write_text("Bad signature\n", HTTPStatus.NOT_ACCEPTABLE)
return
found = self.check_msg_own(message)
if len(found) <= 0:
logging.info("package signer are not in allow list")
self.write_text("Package signer not in allow list\n", HTTPStatus.FORBIDDEN)
return
data = self.rfile.read(length)
if len(data) != length:
raise Exception("read data mismatch")
if not self.verify_package(data, message, found):
logging.info("verify package signature failed")
self.write_text("Bad package signature\n", HTTPStatus.FORBIDDEN)
return
if os.path.exists(path):
os.remove(path)
with open(path, "wb") as f:
f.write(data)
logging.info("saved %s size %s bytes" % (path, self.format_size(length)))
self.write_text("File saved\n", HTTPStatus.CREATED)
def proc_get_pkgs(self, filename: str):
path = "%s/%s" % (repo_data.config["upload"], filename)
try:
if not os.path.exists(path):
self.write_text("404 Not Found\n", HTTPStatus.NOT_FOUND)
return
with open(path, "rb") as f:
st = os.fstat(f.fileno())
if not S_ISREG(st.st_mode):
self.write_text("Bad file\n", HTTPStatus.FORBIDDEN)
return
self.send_response(HTTPStatus.OK)
self.send_header("Content-Length", str(st.st_size))
self.send_header("Content-Type", "application/octet-stream")
self.send_header("Content-Disposition", "attachment; filename=\"%s\"" % filename)
self.end_headers()
self.copyfile(f, self.wfile)
except BaseException as e:
logging.exception(e)
self.write_text("Error while delete package\n", HTTPStatus.INTERNAL_SERVER_ERROR)
def proc_get_page(self):
global repo_data
if not isinstance(repo_data.config["update_page"], str):
self.write_text("404 Not Found\n", HTTPStatus.NOT_FOUND)
return
if not os.path.exists(repo_data.config["update_page"]):
self.write_text("404 Not Found\n", HTTPStatus.NOT_FOUND)
return
with open(repo_data.config["update_page"], "r") as f:
self.write_html(f.read())
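
    # Minimal JSON API: /api/info reports repo metadata and upload limits,
    # /api/update (POST only) performs the database merge.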
def proc_api(self, method):
global repo_data
data = None
if method == HTTPMethod.POST:
if 'Content-Length' not in self.headers:
self.write_text("Miss Content-Length\n", HTTPStatus.LENGTH_REQUIRED)
return
length = int(self.headers['Content-Length'])
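            # 0x8000000 bytes = 128 MiB, the hard cap on API request bodies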
if length >= 0x8000000:
self.write_text("File too large\n", HTTPStatus.REQUEST_ENTITY_TOO_LARGE)
return
data = self.rfile.read(length)
if "Content-Type" in self.headers:
match self.headers['Content-Type']:
case "text/plain": data = data.decode("UTF-8")
case "application/json": data = json.loads(data.decode("UTF-8"))
if self.path == "/api/info":
self.write_json({
"repo": repo_data.config["repo"],
"arch": repo_data.config["arch"],
"sign_exts": repo_data.config["sign_exts"],
"pkg_exts": repo_data.config["pkg_exts"],
"max_sign_file": repo_data.config["max_sign_file"],
"max_pkg_file": repo_data.config["max_pkg_file"],
"upload_size": repo_data.config["upload_size"],
"upload_files": repo_data.config["upload_files"],
})
return
if self.path == "/api/update":
if method != HTTPMethod.POST:
self.write_text("Need request json\n", HTTPStatus.BAD_REQUEST)
return
self.proc_update_db(data)
return
self.write_text("404 Not Found\n", HTTPStatus.NOT_FOUND)
def do_DELETE(self):
if not self.path.startswith("/"):
self.write_text("404 Not Found\n", HTTPStatus.NOT_FOUND)
return
filename = self.path[1:]
if len(filename) <= 0 or '/' in filename or '\\' in filename:
self.write_text("Invalid filename\n", HTTPStatus.FORBIDDEN)
return
if not self.check_folder():
return
if not filename.endswith(tuple(repo_data.config["sign_exts"])) and \
not filename.endswith(tuple(repo_data.config["pkg_exts"])):
self.write_text(
"Only %s, %s accepts\n" % (
repo_data.config["sign_exts"],
repo_data.config["pkg_exts"],
), HTTPStatus.FORBIDDEN
)
return
try:
path = "%s/%s" % (repo_data.config["upload"], filename)
with repo_data.lock:
if os.path.exists(path):
os.remove(path)
logging.info("removed %s" % path)
self.write_text("Target deleted\n", HTTPStatus.OK)
else:
self.write_text("Target not found\n", HTTPStatus.OK)
except BaseException as e:
logging.exception(e)
self.write_text("Error while delete package\n", HTTPStatus.INTERNAL_SERVER_ERROR)
def do_POST(self):
if self.path.startswith("/api/"):
self.proc_api(HTTPMethod.POST)
return
self.write_text("404 Not Found\n", HTTPStatus.NOT_FOUND)
def do_GET(self):
if self.path == "/" or self.path == "/index.html":
self.proc_get_page()
return
if self.path.startswith("/api/"):
self.proc_api(HTTPMethod.GET)
return
filename = self.path[1:]
if filename == "index.html":
self.proc_get_page()
return
if not self.check_folder():
return
if filename.endswith(tuple(repo_data.config["pkg_exts"])) or \
filename.endswith(tuple(repo_data.config["sign_exts"])):
if len(filename) <= 0 or '/' in filename or '\\' in filename:
self.write_text("Invalid filename\n", HTTPStatus.FORBIDDEN)
return
            self.proc_get_pkgs(filename)
            return
self.write_text("404 Not Found\n", HTTPStatus.NOT_FOUND)
def do_PUT(self):
if not self.path.startswith("/"):
self.write_text("404 Not Found\n", HTTPStatus.NOT_FOUND)
return
if 'Content-Length' not in self.headers:
self.write_text("Miss Content-Length\n", HTTPStatus.LENGTH_REQUIRED)
return
length = int(self.headers['Content-Length'])
filename = self.path[1:]
if len(filename) <= 0 or '/' in filename or '\\' in filename:
self.write_text("Invalid filename\n", HTTPStatus.FORBIDDEN)
return
logging.info("target file %s" % filename)
if not self.check_folder():
return
try:
if filename.endswith(tuple(repo_data.config["sign_exts"])):
self.proc_sign(length, filename)
elif filename.endswith(tuple(repo_data.config["pkg_exts"])):
self.proc_pkgs(length, filename)
else:
self.write_text(
"Only %s, %s accepts\n" % (
repo_data.config["sign_exts"],
repo_data.config["pkg_exts"],
), HTTPStatus.FORBIDDEN
)
except BaseException as e:
logging.exception(e)
self.write_text("Error while process content\n", HTTPStatus.INTERNAL_SERVER_ERROR)
def load_key_file(filename):
in_key = False
key_text = ""
with open(filename, "rb") as f:
while True:
line = f.readline()
if not line:
break
line = line.decode("UTF-8")
if line == "-----BEGIN PGP PUBLIC KEY BLOCK-----\n":
if in_key:
logging.warning("unexpected key begin")
in_key = True
key_text = str(line)
elif line == "-----END PGP PUBLIC KEY BLOCK-----\n":
if not in_key:
logging.warning("unexpected key end")
key_text = key_text + str(line)
keys = PGPKey.from_blob(key_text)
load_key(keys[0])
in_key = False
key_text = ""
elif line == "\n":
continue
elif len(line) > 0 and in_key:
key_text = key_text + str(line)
    if in_key:
        logging.warning("unexpected EOF inside key block")
def load_key(key: PGPKey):
    if not key.is_public:
        raise Exception("refusing to load a private key into the allow list")
    if key.is_expired:
        logging.warning("key %s is expired" % key.fingerprint)
logging.info("found %s" % key.fingerprint)
for uid in key.userids:
logging.info("uid %s" % uid.userid)
repo_data.gpg[key.fingerprint] = key
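
# Illustrative config.json shape, covering every key this file reads; the
# values are placeholders, not working credentials:
# {
#   "repo": "myrepo", "arch": ["x86_64"],
#   "pkg_exts": [".pkg.tar.zst"], "sign_exts": [".sig"],
#   "upload": "upload", "workspace": "workspace",
#   "upload_files": 16, "upload_size": 1073741824,
#   "max_sign_file": 4096, "max_pkg_file": 268435456,
#   "keyring": "keyring.asc", "signer": "signer.asc",
#   "signer_passphrase": "...", "restrict_pkg": {},
#   "secret_id": "...", "secret_key": "...",
#   "region": "ap-guangzhou", "bucket": "...",
#   "cdn": "repo.example.com", "update_page": "index.html",
#   "bind": "127.0.0.1", "port": 8080
# }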
def main(argv: list) -> int:
global repo_data
prs = argparse.ArgumentParser("Arch Linux Repo Over COS Updater API")
prs.add_argument("-f", "--config-file", help="Set config file", default="config.json")
args = prs.parse_args(argv[1:])
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
if args.config_file:
with open(args.config_file, "r") as f:
repo_data.config = json.load(f)
if isinstance(repo_data.config["keyring"], str):
load_key_file(repo_data.config["keyring"])
elif isinstance(repo_data.config["keyring"], list):
for file in repo_data.config["keyring"]:
load_key_file(file)
else:
raise Exception("no any keyring file found")
if 'signer' not in repo_data.config:
raise Exception("no signer key file found")
repo_data.signer = PGPKey.from_file(repo_data.config["signer"])[0]
    if repo_data.signer.is_public:
        raise Exception("signer is not a private key")
    if not repo_data.signer.is_protected:
        raise Exception("signer private key is not passphrase-protected")
logging.info("loaded %d keys" % len(repo_data.gpg))
repo_data.cred = Credential(
secret_id=repo_data.config["secret_id"],
secret_key=repo_data.config["secret_key"],
)
repo_data.cdn = CdnClient(
repo_data.cred,
repo_data.config["region"]
)
repo_data.cos_config = CosConfig(
Region=repo_data.config["region"],
SecretId=repo_data.config["secret_id"],
SecretKey=repo_data.config["secret_key"],
Scheme="https",
)
repo_data.cos_client = CosS3Client(repo_data.cos_config)
if not os.path.exists(repo_data.config["workspace"]):
os.mkdir(repo_data.config["workspace"])
if not os.path.exists(repo_data.config["upload"]):
os.mkdir(repo_data.config["upload"])
listen = (repo_data.config["bind"], repo_data.config["port"])
with ThreadingHTTPServer(listen, HTTPRequestHandler) as httpd:
print(
f"Serving HTTP on {listen[0]} port {listen[1]} "
f"(http://{listen[0]}:{listen[1]}/) ..."
)
try:
httpd.serve_forever()
except KeyboardInterrupt:
print("\nKeyboard interrupt received, exiting.")
httpd.shutdown()
return 0
if __name__ == '__main__':
sys.exit(main(sys.argv))