Files
assets-bot/src/rabbit.py

89 lines
2.4 KiB
Python

import logging
import datetime
from typing import Optional
import pika
from pika import BlockingConnection
from pika.channel import Channel
from pika.credentials import PlainCredentials
from pika.spec import BasicProperties
from src.settings import (
RABBIT_HOST,
RABBIT_PORT,
RABBIT_CREDENTIALS,
RABBIT_TASK_QUEUE,
RABBIT_REPLY_QUEUE,
MESSAGE_TTL,
)
from src.messages import AnalyzeTask, AnalyzeResponse
_logger = logging.getLogger(__name__)
# NullHandler: this module emits log records but leaves handler
# configuration to the application.
_logger.addHandler(logging.NullHandler())

# Passed to pika.ConnectionParameters(connection_attempts=...).
CONNECTION_ATTEMPTS = 3
# Requested as the queues' "x-max-priority" argument when greater than 1.
QUEUE_MAX_PRIORITY = 4
# Per-message expiration as a string of milliseconds (the form the AMQP
# `expiration` property takes). Assumes MESSAGE_TTL is in seconds, hence
# the * 1000. round() rather than int(): int() truncates float products
# such as 1.005 * 1000 == 1004.999... one millisecond short.
RABBIT_MESSAGE_TTL = str(round(MESSAGE_TTL * 1000))
def get_connection() -> BlockingConnection:
    """Open and return a new blocking connection to the configured broker.

    Host, port and credentials come from ``src.settings``;
    ``CONNECTION_ATTEMPTS`` is handed to pika as ``connection_attempts``.
    ``RABBIT_CREDENTIALS`` is unpacked positionally into ``PlainCredentials``
    (presumably ``(username, password)`` -- confirm against settings).
    """
    _logger.info(f"connecting to RabbitMQ at {RABBIT_HOST}:{RABBIT_PORT}")
    auth = PlainCredentials(*RABBIT_CREDENTIALS)
    params = pika.ConnectionParameters(
        host=RABBIT_HOST,
        port=RABBIT_PORT,
        credentials=auth,
        connection_attempts=CONNECTION_ATTEMPTS,
    )
    return pika.BlockingConnection(params)
def get_channel(connection: Optional[BlockingConnection] = None) -> Channel:
    """Open a channel and declare the reply and task queues on it.

    :param connection: connection to open the channel on; when ``None``
        (the default) a new one is created with ``get_connection()``.
    :return: the opened channel. Both ``RABBIT_REPLY_QUEUE`` and
        ``RABBIT_TASK_QUEUE`` are declared durable; when
        ``QUEUE_MAX_PRIORITY > 1`` they are also declared with the
        ``x-max-priority`` argument. ``basic_qos(prefetch_count=1)`` is
        applied before returning.

    NOTE(review): ``BlockingConnection.channel()`` returns pika's
    ``BlockingChannel`` (the class that provides ``consume()``), not
    ``pika.channel.Channel``; the annotation is left unchanged for
    compatibility.
    """
    # Explicit None check: only create a connection when none was given.
    if connection is None:
        connection = get_connection()
    channel = connection.channel()

    queue_params = {"durable": True}
    if QUEUE_MAX_PRIORITY > 1:
        queue_params["arguments"] = {"x-max-priority": QUEUE_MAX_PRIORITY}
    channel.queue_declare(queue=RABBIT_REPLY_QUEUE, **queue_params)
    channel.queue_declare(queue=RABBIT_TASK_QUEUE, **queue_params)
    channel.basic_qos(prefetch_count=1)
    return channel
def send_task(channel: Channel, data: bytes, *, priority: Optional[int] = None):
    """Publish a serialized task to ``RABBIT_TASK_QUEUE``.

    The message goes through the default exchange (``""``) with routing key
    ``RABBIT_TASK_QUEUE``, carries ``RABBIT_MESSAGE_TTL`` as its
    ``expiration`` and names ``RABBIT_REPLY_QUEUE`` in ``reply_to``.

    :param channel: channel to publish on.
    :param data: message body.
    :param priority: optional AMQP message priority. The queues are declared
        with ``x-max-priority`` in ``get_channel()``, so this lets callers
        actually use it. ``None`` (the default) publishes without a
        priority, exactly as before.
    """
    channel.basic_publish(
        exchange="",
        routing_key=RABBIT_TASK_QUEUE,
        body=data,
        properties=BasicProperties(
            expiration=RABBIT_MESSAGE_TTL,
            reply_to=RABBIT_REPLY_QUEUE,
            priority=priority,
        ),
    )
def send_reply(channel: Channel, data: bytes, *, priority: Optional[int] = None):
    """Publish a serialized reply to ``RABBIT_REPLY_QUEUE``.

    The message goes through the default exchange (``""``) with routing key
    ``RABBIT_REPLY_QUEUE`` and carries ``RABBIT_MESSAGE_TTL`` as its
    ``expiration``.

    :param channel: channel to publish on (annotated like ``send_task``).
    :param data: message body.
    :param priority: optional AMQP message priority; ``None`` (the default)
        publishes without a priority, exactly as before.
    """
    channel.basic_publish(
        exchange="",
        routing_key=RABBIT_REPLY_QUEUE,
        body=data,
        properties=BasicProperties(
            expiration=RABBIT_MESSAGE_TTL,
            priority=priority,
        ),
    )
def consume_task(channel, queue: str, timeout=None, auto_ack=True, max_count=None):
    """Yield message bodies consumed from ``queue``.

    :param channel: channel providing ``consume()``, ``cancel()`` and
        ``is_open`` (e.g. the one returned by ``get_channel()``).
    :param queue: name of the queue to consume from.
    :param timeout: forwarded as ``inactivity_timeout``. When no message
        arrives in time, ``consume()`` produces ``(None, None, None)``; that
        ``None`` body is yielded to the caller and consumption continues.
    :param auto_ack: forwarded to ``consume()``. NOTE(review): only the body
        is yielded, so callers have no delivery tag to ack with; with
        ``auto_ack=False`` messages stay unacknowledged until ``cancel()``.
    :param max_count: stop after this many messages have been received.
        ``None`` or ``0`` means no limit. Timeout ``None`` yields do not count.

    The consumer is cancelled when the generator finishes, including when
    the caller stops iterating early (``close()`` / garbage collection).
    """
    received = 0
    try:
        for method, _properties, body in channel.consume(
            queue, auto_ack=auto_ack, inactivity_timeout=timeout
        ):
            yield body
            if method is None:
                # Inactivity tick: there is no delivery to count, and
                # dereferencing `method` here would raise AttributeError.
                continue
            # Count locally: delivery tags are numbered per channel, not per
            # consume() call, so comparing them with max_count breaks on a
            # reused channel.
            received += 1
            if max_count and received >= max_count:
                break
    finally:
        # Skip cancel when the channel was closed by an error, so the
        # original exception is not masked by a second one.
        if channel.is_open:
            requeued_messages = channel.cancel()
            _logger.info(f"Requeued {requeued_messages} messages")