// main.rs
// TODO: Add another/better rate-limit handler?
// main.rs // TODO: Add other/better rate limit handler? mod commands; use commands::Command; use dotenv::dotenv; use serde::{Deserialize, Serialize}; use serenity::{ async_trait, cache::Cache, http, model::{channel::Message, gateway::Ready, guild::Guild, id::GuildId}, prelude::*, Client, }; use std::{net::Shutdown, sync::Arc}; use std::{env, io::Read}; use strum::IntoEnumIterator; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio::sync::broadcast;
use crate::{database::Database, extra::ExtraDiscord}; mod databasespec;
/// Capacity of every `tokio::sync::broadcast` channel between the website
/// bridge and the bot. A receiver that falls further behind than this loses
/// the oldest queued messages.
const CHANNEL_BUFFER_SIZE: usize = 32;
/// Storage backend: `pgdatabase.rs` when the `pgdatabase` feature is
/// enabled, `jsondatabase.rs` otherwise. Both are textually included so that
/// either file provides the `database::Database` type used below.
mod database {
    #[cfg(feature = "pgdatabase")]
    include!("pgdatabase.rs");

    #[cfg(not(feature = "pgdatabase"))]
    include!("jsondatabase.rs");
}
#[cfg(feature = "extra")]
mod extra;

/// No-op stand-in for the optional `extra` module, compiled when the `extra`
/// feature is disabled. It provides the same entry points the event handler
/// calls (`start`, `poststart_hook`, `on_message`), all of which do nothing.
#[cfg(not(feature = "extra"))]
mod extra {
    use serenity::client::Context;
    use serenity::model::channel::Message;
    use serenity::model::prelude::Ready;

    /// Placeholder handle; carries no state.
    pub struct ExtraDiscord {}

    impl ExtraDiscord {
        /// Creates the empty handle.
        pub fn start() -> Self {
            ExtraDiscord {}
        }

        /// Post-`ready` hook; intentionally empty in this stub.
        pub fn poststart_hook(_ctx: Context, _ready: Ready) {}

        /// Fallback message hook; intentionally empty in this stub.
        pub async fn on_message(&self, _ctx: Context, _msg: Message) {}
    }
}
/// Authentication status of the website stream.
///
/// A single value is shared (behind an `Arc<RwLock<_>>`) between
/// `handle_stream`, which writes it after each website message, and the TCP
/// connection reader, which closes the socket once it observes `Terminate`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
enum AuthenticatedStreamState {
    /// A correct stream password has been received.
    Authenticated,
    /// A message arrived without prior authentication; the connection should close.
    Terminate,
    /// No authentication attempt has been processed yet.
    #[default]
    None,
}
struct AppState { database: Database, authenticated_stream: Arc<RwLock<AuthenticatedStreamState>>, website_tx: tokio::sync::broadcast::Sender<Vec<u8>>, website_rx: tokio::sync::broadcast::Receiver<Vec<u8>>, } impl Default for AppState { fn default() -> Self { let (website_tx, website_rx) = broadcast::channel::<Vec<u8>>(CHANNEL_BUFFER_SIZE); Self { database: Default::default(), website_tx, website_rx, authenticated_stream: Arc::new(RwLock::new(AuthenticatedStreamState::default())), } } } // impl AppState { // pub fn poststart_hook(_ctx: Context, _ready: Ready) {
// } // pub async fn on_message(&self, _ctx: serenity::prelude::Context, _msg: serenity::model::channel::Message) {} // } #[derive(Serialize, Deserialize)] struct WebsiteUserRequest { id: u64, }
/// Reply broadcast to the website when the requested user is a member of the
/// main guild.
#[derive(Deserialize, Serialize)]
struct WebsiteUserPresentResponse {
    /// Discord user id of the member that was found.
    id: u64,
}
/// Serenity event handler: owns the optional `extra` hooks and the shared
/// [`AppState`].
struct Handler {
    /// Hooks provided by the `extra` module (a no-op stub without the feature).
    extra: extra::ExtraDiscord,
    /// Application state shared with commands and the website stream tasks.
    state: Arc<tokio::sync::RwLock<AppState>>,
}

impl Handler {
    /// Starts the `extra` hooks and wraps a default [`AppState`] for sharing.
    fn new() -> Self {
        let extra = extra::ExtraDiscord::start();
        let state = Arc::new(tokio::sync::RwLock::new(AppState::default()));
        Handler { extra, state }
    }
}
/// Website message carrying the stream password, checked by `handle_stream`.
#[derive(Serialize, Deserialize)]
struct AuthenticateRequest {
    /// Plain-text password compared against the expected stream password.
    password: String,
}
async fn handle_stream( arc_state: Arc<RwLock<AppState>>, ctx: Context, cache: Arc<Cache>, mut rx: tokio::sync::broadcast::Receiver<Vec<u8>>, mut tx: tokio::sync::broadcast::Sender<Vec<u8>>, ) { let mut state = arc_state.write().await;
let all_guilds = ctx.cache.guilds(); let main_guild_option = all_guilds .iter() .find(|guild_id: &&GuildId| u64::from(**guild_id) == 1479002261811236984 as u64); //println!("In handle stream"); while let Ok(bytes) = rx.recv().await { //println!("Got some data"); match serde_json::from_slice::<AuthenticateRequest>(&bytes) { Ok(request) => { if request.password == "test".to_string(){ println!("Password is correct for stream"); let mut authenticated_stream = state.authenticated_stream.write().await; *authenticated_stream = AuthenticatedStreamState::Authenticated; } else { println!("Password is incorrect for stream"); } }, Err(_) => {}, } let mut authenticated_stream = state.authenticated_stream.write().await; if !matches!(*authenticated_stream, AuthenticatedStreamState::Authenticated) { println!("Terminating"); *authenticated_stream = AuthenticatedStreamState::Terminate; } drop(authenticated_stream); match serde_json::from_slice::<WebsiteUserRequest>(&bytes) { Ok(request) => { // let request = WebsiteUserRequestFinal { id: intial_request.id.parse().unwrap() }; // if cache.guilds().iter().any(|guild_id| guild_id == request){ // } if let Some(main_guild) = main_guild_option { // println!("Found the main guild"); if let Ok(members) = main_guild.members(&ctx.http, None, None).await { // println!("Get the members list"); if let Some(member) = members .iter() .find(|member| u64::from(member.user.id) == request.id.clone()) { println!("Found the member"); let response = WebsiteUserPresentResponse { id: u64::from(member.user.id), }; let bytes: Vec<u8> = serde_json::to_vec((&response).into()).unwrap(); let result = tx.send(bytes); //println!("{:#?}", result); } else { // println!("Could not find this member"); } } else { // println!("Could not get member list"); } } } /* your logic */ Err(_) => {}, } tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; }}
use tokio_util::sync::CancellationToken;
#[async_trait] impl EventHandler for Handler { async fn ready(&self, ctx: Context, ready: Ready) { println!("Bot is ready! Logged in as {}", ready.user.name); ExtraDiscord::poststart_hook(ctx.clone(), ready);
let inner_ctx = ctx.clone(); let inner_cache = inner_ctx.cache.clone(); let (mut to_website_tx, mut to_website_rx) = broadcast::channel::<Vec<u8>>(CHANNEL_BUFFER_SIZE); let (mut from_website_tx, mut from_website_rx) = broadcast::channel::<Vec<u8>>(CHANNEL_BUFFER_SIZE); let mut temp_state = self.state.write().await; temp_state.website_tx = to_website_tx.clone(); temp_state.website_rx = to_website_rx.resubscribe(); drop(temp_state); //let inner_website_rx: broadcast::Receiver<Vec<u8>> = ; //let inner_ctx = ctx.clone(); let inner_state = self.state.clone(); let inner_website_tx = to_website_tx.clone(); let inner_website_rx = from_website_rx.resubscribe(); tokio::spawn(async move { handle_stream( inner_state.clone(), inner_ctx, inner_cache, inner_website_rx, inner_website_tx, ) .await; }); let state = self.state.read().await; let outer_authenticated_stream_state = Arc::clone(&state.authenticated_stream); drop(state); tokio::spawn(async move { let listener = TcpListener::bind("0.0.0.0:8086") .await .expect("Failed to bind"); //let state = self.state.read().await; loop { let token = CancellationToken::new(); let token_read = token.clone(); let token_write = token.clone(); let inner_authenticated_stream_state = outer_authenticated_stream_state.clone(); //let inner_authenticated_stream_state = authenticated_stream_state.clone(); if let Ok((mut socket, _)) = listener.accept().await { println!("Got a new connection"); let (mut read_half, mut write_half) = socket.into_split(); let mut buf = vec![0u8; 1024]; let mut inner_website_tx = from_website_tx.clone(); let mut inner_website_rx = to_website_rx.resubscribe(); let writeable_inner_authenticated_stream_state = inner_authenticated_stream_state.clone(); tokio::spawn(async move { loop { tokio::select! 
{ _ = token_write.cancelled() => { break; } result = inner_website_rx.recv() => { match result { Ok(data) => { let _ = write_half.write(&data).await; } Err(_) => break, } } } } }); tokio::spawn(async move { let mut buf = vec![0u8; 1024]; loop { match read_half.read(&mut buf).await { Ok(0) | Err(_) => break, Ok(n) => { let _ = inner_website_tx.send(buf[..n].to_vec()); // if let Ok(output) = str::from_utf8(&buf[..n].to_vec()) { // println!("{:#?}", output); // } let writable_authenticated_stream_state = inner_authenticated_stream_state.write().await; //println!("{:#?}", *writable_authenticated_stream_state); if matches!(*writable_authenticated_stream_state, AuthenticatedStreamState::Terminate) { token_read.cancel(); break; } } } } }); } } }); // tokio::spawn(async move { // while let Some(data) = response_rx.recv().await { // if write_half.write_all(&data).await.is_err() { // break; // } // } // }); //println!("{:#?}", ctx.cache.guilds().len()); } async fn message(&self, ctx: Context, msg: Message) { if msg.content.starts_with("!") { for cmd in Command::iter() { if cmd.name() == msg.content[1..].to_string() { cmd.run(Arc::clone(&self.state), &ctx, &msg).await; return; } } } self.extra.on_message(ctx.clone(), msg.clone()).await; }}
/// Entry point: loads `.env`, builds the Discord client and runs it until it
/// stops.
///
/// # Panics
/// Panics if `DISCORD_BOT_TOKEN` is not set or the client cannot be built.
#[tokio::main]
async fn main() {
    dotenv().ok();

    let token = std::env::var("DISCORD_BOT_TOKEN")
        .expect("Expected DISCORD_BOT_TOKEN environment variable in .env file");
    let intents = GatewayIntents::DIRECT_MESSAGES
        | GatewayIntents::GUILD_MESSAGES
        | GatewayIntents::MESSAGE_CONTENT;

    let mut client = Client::builder(&token, intents)
        .event_handler(Handler::new())
        .await
        .expect("Error creating client");

    // `start` only returns once the client has stopped, so report a shutdown
    // here; readiness is logged by the `ready` event.
    match client.start().await {
        Err(why) => println!("Client error: {:?}", why),
        Ok(()) => println!("Discord bot has shut down."),
    }
}
/// Reads the environment variable `env_var` and parses it as `T`.
///
/// Falls back to `default` when the variable is unset, is not valid Unicode,
/// or fails to parse. Despite the name, command-line arguments are not
/// consulted.
fn get_env_var_or_arg<T>(env_var: &str, default: Option<T>) -> Option<T>
where
    T: std::str::FromStr + Clone,
{
    match env::var(env_var).map(|raw| raw.parse::<T>()) {
        Ok(Ok(parsed)) => Some(parsed),
        _ => default,
    }
}