Add and prepare rust worker management system for file information processing and knowledge base framework
This commit is contained in:
parent
af82b71657
commit
da6ab3a782
12 changed files with 1402 additions and 251 deletions
|
|
@ -49,5 +49,17 @@ services:
|
|||
depends_on:
|
||||
- mysql
|
||||
|
||||
qdrant:
|
||||
image: qdrant/qdrant:latest
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "127.0.0.1:6333:6333"
|
||||
volumes:
|
||||
- qdrant-data:/qdrant/storage
|
||||
environment:
|
||||
- QDRANT__SERVICE__GRPC_PORT=6334
|
||||
# expose to rust-engine via service name 'qdrant'
|
||||
|
||||
volumes:
|
||||
mysql-data: # Renamed volume for clarity (optional but good practice)
|
||||
mysql-data: # Renamed volume for clarity (optional but good practice)
|
||||
qdrant-data:
|
||||
839
rust-engine/Cargo.lock
generated
839
rust-engine/Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -7,12 +7,19 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
tokio = { version = "1.38.0", features = ["full"] }
|
||||
warp = "0.3.7"
|
||||
warp = { version = "0.4.2", features = ["server", "multipart"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "mysql", "chrono"] }
|
||||
sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "mysql", "chrono", "uuid", "macros"] }
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
dotenvy = "0.15.7" # Switched from unmaintained 'dotenv'
|
||||
anyhow = "1.0"
|
||||
anyhow = "1.0"
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
reqwest = { version = "0.12.24", features = ["json", "rustls-tls"] }
|
||||
async-trait = "0.1"
|
||||
tokio-util = "0.7"
|
||||
futures-util = "0.3"
|
||||
lazy_static = "1.4"
|
||||
bytes = "1.4"
|
||||
|
|
|
|||
226
rust-engine/src/api.rs
Normal file
226
rust-engine/src/api.rs
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
use crate::gemini_client;
|
||||
use crate::vector_db::QdrantClient;
|
||||
use crate::storage;
|
||||
use anyhow::Result;
|
||||
use bytes::Buf;
|
||||
use futures_util::{StreamExt, TryStreamExt};
|
||||
use serde::Deserialize;
|
||||
use sqlx::{MySqlPool, Row};
|
||||
use warp::{multipart::FormData, Filter, Rejection, Reply};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeleteQuery {
|
||||
id: String,
|
||||
}
|
||||
|
||||
pub fn routes(pool: MySqlPool) -> impl Filter<Extract = impl Reply, Error = Rejection> + Clone {
|
||||
let pool_filter = warp::any().map(move || pool.clone());
|
||||
|
||||
// Upload file
|
||||
let upload = warp::path("files")
|
||||
.and(warp::post())
|
||||
.and(warp::multipart::form().max_length(50_000_000)) // 50MB per part default; storage is filesystem-backed
|
||||
.and(pool_filter.clone())
|
||||
.and_then(handle_upload);
|
||||
|
||||
// Delete file
|
||||
let delete = warp::path!("files" / "delete")
|
||||
.and(warp::get())
|
||||
.and(warp::query::<DeleteQuery>())
|
||||
.and(pool_filter.clone())
|
||||
.and_then(handle_delete);
|
||||
|
||||
// List files
|
||||
let list = warp::path!("files" / "list")
|
||||
.and(warp::get())
|
||||
.and(pool_filter.clone())
|
||||
.and_then(handle_list);
|
||||
|
||||
// Create query
|
||||
let create_q = warp::path!("query" / "create")
|
||||
.and(warp::post())
|
||||
.and(warp::body::json())
|
||||
.and(pool_filter.clone())
|
||||
.and_then(handle_create_query);
|
||||
|
||||
// Query status
|
||||
let status = warp::path!("query" / "status")
|
||||
.and(warp::get())
|
||||
.and(warp::query::<DeleteQuery>())
|
||||
.and(pool_filter.clone())
|
||||
.and_then(handle_query_status);
|
||||
|
||||
// Query result
|
||||
let result = warp::path!("query" / "result")
|
||||
.and(warp::get())
|
||||
.and(warp::query::<DeleteQuery>())
|
||||
.and(pool_filter.clone())
|
||||
.and_then(handle_query_result);
|
||||
|
||||
// Cancel
|
||||
let cancel = warp::path!("query" / "cancel")
|
||||
.and(warp::get())
|
||||
.and(warp::query::<DeleteQuery>())
|
||||
.and(pool_filter.clone())
|
||||
.and_then(handle_cancel_query);
|
||||
|
||||
upload.or(delete).or(list).or(create_q).or(status).or(result).or(cancel)
|
||||
}
|
||||
|
||||
async fn handle_upload(mut form: FormData, pool: MySqlPool) -> Result<impl Reply, Rejection> {
|
||||
// qdrant client
|
||||
let qdrant_url = std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://qdrant:6333".to_string());
|
||||
let qdrant = QdrantClient::new(&qdrant_url);
|
||||
|
||||
while let Some(field) = form.try_next().await.map_err(|_| warp::reject())? {
|
||||
let name = field.name().to_string();
|
||||
let filename = field
|
||||
.filename()
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| format!("upload-{}", uuid::Uuid::new_v4()));
|
||||
|
||||
// Read stream of Buf into a Vec<u8>
|
||||
let data = field
|
||||
.stream()
|
||||
.map_ok(|mut buf| {
|
||||
let mut v = Vec::new();
|
||||
while buf.has_remaining() {
|
||||
let chunk = buf.chunk();
|
||||
v.extend_from_slice(chunk);
|
||||
let n = chunk.len();
|
||||
buf.advance(n);
|
||||
}
|
||||
v
|
||||
})
|
||||
.try_fold(Vec::new(), |mut acc, chunk_vec| async move {
|
||||
acc.extend_from_slice(&chunk_vec);
|
||||
Ok(acc)
|
||||
})
|
||||
.await
|
||||
.map_err(|_| warp::reject())?;
|
||||
|
||||
// Save file
|
||||
let path = storage::save_file(&filename, &data).map_err(|_| warp::reject())?;
|
||||
|
||||
// Generate gemini token/description (stub)
|
||||
let token = gemini_client::generate_token_for_file(path.to_str().unwrap()).await.map_err(|_| warp::reject())?;
|
||||
|
||||
// Insert file record
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let desc = Some(format!("token:{}", token));
|
||||
sqlx::query("INSERT INTO files (id, filename, path, description) VALUES (?, ?, ?, ?)")
|
||||
.bind(&id)
|
||||
.bind(&filename)
|
||||
.bind(path.to_str().unwrap())
|
||||
.bind(desc)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("DB insert error: {}", e);
|
||||
warp::reject()
|
||||
})?;
|
||||
|
||||
// generate demo embedding and upsert to Qdrant (async best-effort)
|
||||
let emb = crate::gemini_client::demo_embedding_from_path(path.to_str().unwrap());
|
||||
let qdrant_clone = qdrant.clone();
|
||||
let id_clone = id.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = qdrant_clone.upsert_point(&id_clone, emb).await {
|
||||
tracing::error!("qdrant upsert failed: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(warp::reply::json(&serde_json::json!({"success": true})))
|
||||
}
|
||||
|
||||
async fn handle_delete(q: DeleteQuery, pool: MySqlPool) -> Result<impl Reply, Rejection> {
|
||||
if let Some(row) = sqlx::query("SELECT path FROM files WHERE id = ?")
|
||||
.bind(&q.id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.map_err(|_| warp::reject())?
|
||||
{
|
||||
let path: String = row.get("path");
|
||||
let _ = storage::delete_file(std::path::Path::new(&path));
|
||||
let _ = sqlx::query("DELETE FROM files WHERE id = ?").bind(&q.id).execute(&pool).await;
|
||||
return Ok(warp::reply::json(&serde_json::json!({"deleted": true})));
|
||||
}
|
||||
Ok(warp::reply::json(&serde_json::json!({"deleted": false})))
|
||||
}
|
||||
|
||||
async fn handle_list(pool: MySqlPool) -> Result<impl Reply, Rejection> {
|
||||
let rows = sqlx::query("SELECT id, filename, path, description FROM files ORDER BY created_at DESC LIMIT 500")
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("DB list error: {}", e);
|
||||
warp::reject()
|
||||
})?;
|
||||
|
||||
let files: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|r| {
|
||||
let id: String = r.get("id");
|
||||
let filename: String = r.get("filename");
|
||||
let path: String = r.get("path");
|
||||
let description: Option<String> = r.get("description");
|
||||
serde_json::json!({"id": id, "filename": filename, "path": path, "description": description})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(warp::reply::json(&serde_json::json!({"files": files})))
|
||||
}
|
||||
|
||||
async fn handle_create_query(body: serde_json::Value, pool: MySqlPool) -> Result<impl Reply, Rejection> {
|
||||
// Insert query as queued, worker will pick it up
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let payload = body;
|
||||
sqlx::query("INSERT INTO queries (id, status, payload) VALUES (?, 'Queued', ?)")
|
||||
.bind(&id)
|
||||
.bind(payload)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("DB insert query error: {}", e);
|
||||
warp::reject()
|
||||
})?;
|
||||
|
||||
Ok(warp::reply::json(&serde_json::json!({"id": id})))
|
||||
}
|
||||
|
||||
async fn handle_query_status(q: DeleteQuery, pool: MySqlPool) -> Result<impl Reply, Rejection> {
|
||||
if let Some(row) = sqlx::query("SELECT status FROM queries WHERE id = ?")
|
||||
.bind(&q.id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.map_err(|_| warp::reject())?
|
||||
{
|
||||
let status: String = row.get("status");
|
||||
return Ok(warp::reply::json(&serde_json::json!({"status": status})));
|
||||
}
|
||||
Ok(warp::reply::json(&serde_json::json!({"status": "not_found"})))
|
||||
}
|
||||
|
||||
async fn handle_query_result(q: DeleteQuery, pool: MySqlPool) -> Result<impl Reply, Rejection> {
|
||||
if let Some(row) = sqlx::query("SELECT result FROM queries WHERE id = ?")
|
||||
.bind(&q.id)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.map_err(|_| warp::reject())?
|
||||
{
|
||||
let result: Option<serde_json::Value> = row.get("result");
|
||||
return Ok(warp::reply::json(&serde_json::json!({"result": result})));
|
||||
}
|
||||
Ok(warp::reply::json(&serde_json::json!({"result": null})))
|
||||
}
|
||||
|
||||
async fn handle_cancel_query(q: DeleteQuery, pool: MySqlPool) -> Result<impl Reply, Rejection> {
|
||||
// Mark as cancelled; worker must check status before heavy steps
|
||||
sqlx::query("UPDATE queries SET status = 'Cancelled' WHERE id = ?")
|
||||
.bind(&q.id)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|_| warp::reject())?;
|
||||
Ok(warp::reply::json(&serde_json::json!({"cancelled": true})))
|
||||
}
|
||||
33
rust-engine/src/db.rs
Normal file
33
rust-engine/src/db.rs
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
use sqlx::{MySql, MySqlPool};
|
||||
use tracing::info;
|
||||
|
||||
pub async fn init_db(database_url: &str) -> Result<MySqlPool, sqlx::Error> {
|
||||
let pool = MySqlPool::connect(database_url).await?;
|
||||
|
||||
// Create tables if they don't exist. Simple schema for demo/hackathon use.
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
filename TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
description TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS queries (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
status VARCHAR(32) NOT NULL,
|
||||
payload JSON,
|
||||
result JSON,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
|
||||
);
|
||||
"#,
|
||||
)
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
|
||||
info!("Database initialized");
|
||||
Ok(pool)
|
||||
}
|
||||
37
rust-engine/src/gemini_client.rs
Normal file
37
rust-engine/src/gemini_client.rs
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
use anyhow::Result;
|
||||
use serde::Deserialize;
|
||||
|
||||
// NOTE: This is a small stub to represent where you'd call the Gemini API.
|
||||
// Replace with real API call and proper auth handling for production.
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GeminiTokenResponse {
|
||||
pub token: String,
|
||||
}
|
||||
|
||||
pub async fn generate_token_for_file(_path: &str) -> Result<String> {
|
||||
Ok("gemini-token-placeholder".to_string())
|
||||
}
|
||||
|
||||
/// Demo embedding generator: a deterministic 64-slot pseudo-embedding of `path`.
///
/// Byte number `i` of the string adds `byte / 255.0` to slot `i % 64`, so an
/// empty string produces an all-zero vector.
pub fn demo_embedding_from_path(path: &str) -> Vec<f32> {
    let mut embedding = vec![0f32; 64];
    let dim = embedding.len();

    // Walking the bytes in `dim`-sized windows lines up position `j` of every
    // window with slot `j`, which is exactly the `i % dim` bucketing.
    for window in path.as_bytes().chunks(dim) {
        for (slot, byte) in embedding.iter_mut().zip(window) {
            *slot += f32::from(*byte) / 255.0;
        }
    }

    embedding
}
|
||||
|
||||
pub const DEMO_EMBED_DIM: usize = 64;
|
||||
|
||||
/// Demo text embedding (replace with real Gemini text embedding API)
|
||||
pub async fn demo_text_embedding(text: &str) -> Result<Vec<f32>> {
|
||||
let mut v = vec![0f32; DEMO_EMBED_DIM];
|
||||
for (i, b) in text.as_bytes().iter().enumerate() {
|
||||
let idx = i % v.len();
|
||||
v[idx] += (*b as f32) / 255.0;
|
||||
}
|
||||
Ok(v)
|
||||
}
|
||||
|
|
@ -1,21 +1,16 @@
|
|||
mod api;
|
||||
mod db;
|
||||
mod gemini_client;
|
||||
mod models;
|
||||
mod storage;
|
||||
mod vector;
|
||||
mod worker;
|
||||
mod vector_db;
|
||||
|
||||
use std::env;
|
||||
use std::error::Error;
|
||||
use tracing::info;
|
||||
use warp::Filter;
|
||||
use sqlx::mysql::MySqlPool;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{info, warn};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct HealthResponse {
|
||||
status: String,
|
||||
timestamp: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ApiResponse<T> {
|
||||
success: bool,
|
||||
data: Option<T>,
|
||||
message: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
|
@ -29,103 +24,28 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
.unwrap_or_else(|_| "mysql://astraadmin:password@mysql:3306/astra".to_string());
|
||||
|
||||
info!("Starting Rust Engine...");
|
||||
// info!("Connecting to database: {}", database_url);
|
||||
|
||||
// Connect to database
|
||||
let pool = match MySqlPool::connect(&database_url).await {
|
||||
Ok(pool) => {
|
||||
info!("Successfully connected to database");
|
||||
pool
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to database: {}. Starting without DB connection.", e);
|
||||
// In a hackathon setting, we might want to continue without DB for initial testing
|
||||
return start_server_without_db().await;
|
||||
}
|
||||
};
|
||||
// Ensure storage dir
|
||||
storage::ensure_storage_dir().expect("storage dir");
|
||||
|
||||
// CORS configuration
|
||||
let cors = warp::cors()
|
||||
.allow_any_origin()
|
||||
.allow_headers(vec!["content-type", "authorization"])
|
||||
.allow_methods(vec!["GET", "POST", "PUT", "DELETE", "OPTIONS"]);
|
||||
// Initialize DB
|
||||
let pool = db::init_db(&database_url).await.map_err(|e| -> Box<dyn Error> { Box::new(e) })?;
|
||||
|
||||
// Health check endpoint
|
||||
let health = warp::path("health")
|
||||
.and(warp::get())
|
||||
.map(|| {
|
||||
let response = HealthResponse {
|
||||
status: "healthy".to_string(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
warp::reply::json(&ApiResponse {
|
||||
success: true,
|
||||
data: Some(response),
|
||||
message: None,
|
||||
})
|
||||
});
|
||||
// Spawn worker
|
||||
let worker = worker::Worker::new(pool.clone());
|
||||
tokio::spawn(async move { worker.run().await });
|
||||
|
||||
// API routes - you'll expand these for your hackathon needs
|
||||
let api = warp::path("api")
|
||||
.and(
|
||||
health.or(
|
||||
// Add more routes here as needed
|
||||
warp::path("version")
|
||||
.and(warp::get())
|
||||
.map(|| {
|
||||
warp::reply::json(&ApiResponse {
|
||||
success: true,
|
||||
data: Some("1.0.0"),
|
||||
message: Some("Rust Engine API".to_string()),
|
||||
})
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
let routes = api
|
||||
.with(cors)
|
||||
// API routes
|
||||
let api_routes = api::routes(pool.clone())
|
||||
.with(warp::cors()
|
||||
.allow_any_origin()
|
||||
.allow_headers(vec!["content-type", "authorization"])
|
||||
.allow_methods(vec!["GET", "POST", "PUT", "DELETE", "OPTIONS"]))
|
||||
.with(warp::log("rust_engine"));
|
||||
|
||||
info!("Rust Engine started on http://0.0.0.0:8000");
|
||||
|
||||
warp::serve(routes)
|
||||
.run(([0, 0, 0, 0], 8000))
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_server_without_db() -> Result<(), Box<dyn std::error::Error>> {
|
||||
info!("Starting server in DB-less mode for development");
|
||||
|
||||
let cors = warp::cors()
|
||||
.allow_any_origin()
|
||||
.allow_headers(vec!["content-type", "authorization"])
|
||||
.allow_methods(vec!["GET", "POST", "PUT", "DELETE", "OPTIONS"]);
|
||||
|
||||
let health = warp::path("health")
|
||||
.and(warp::get())
|
||||
.map(|| {
|
||||
let response = HealthResponse {
|
||||
status: "healthy (no db)".to_string(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
warp::reply::json(&ApiResponse {
|
||||
success: true,
|
||||
data: Some(response),
|
||||
message: Some("Running without database connection".to_string()),
|
||||
})
|
||||
});
|
||||
|
||||
let routes = warp::path("api")
|
||||
.and(health)
|
||||
.with(cors)
|
||||
.with(warp::log("rust_engine"));
|
||||
|
||||
info!("Rust Engine started on http://0.0.0.0:8000 (DB-less mode)");
|
||||
info!("Rust Engine prepared!");
|
||||
|
||||
warp::serve(routes)
|
||||
warp::serve(api_routes)
|
||||
.run(([0, 0, 0, 0], 8000))
|
||||
.await;
|
||||
|
||||
|
|
|
|||
56
rust-engine/src/models.rs
Normal file
56
rust-engine/src/models.rs
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FileRecord {
|
||||
pub id: String,
|
||||
pub filename: String,
|
||||
pub path: String,
|
||||
pub description: Option<String>,
|
||||
pub created_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl FileRecord {
|
||||
pub fn new(filename: impl Into<String>, path: impl Into<String>, description: Option<String>) -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
filename: filename.into(),
|
||||
path: path.into(),
|
||||
description,
|
||||
created_at: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub enum QueryStatus {
|
||||
Queued,
|
||||
InProgress,
|
||||
Completed,
|
||||
Cancelled,
|
||||
Failed,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct QueryRecord {
|
||||
pub id: String,
|
||||
pub status: QueryStatus,
|
||||
pub payload: serde_json::Value,
|
||||
pub result: Option<serde_json::Value>,
|
||||
pub created_at: Option<DateTime<Utc>>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl QueryRecord {
|
||||
pub fn new(payload: serde_json::Value) -> Self {
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
status: QueryStatus::Queued,
|
||||
payload,
|
||||
result: None,
|
||||
created_at: None,
|
||||
updated_at: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
34
rust-engine/src/storage.rs
Normal file
34
rust-engine/src/storage.rs
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
use anyhow::Result;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Resolves the directory uploaded files are stored in.
///
/// Uses the `ASTRA_STORAGE` environment variable when it is set to a non-empty
/// value (non-Unicode values are accepted as-is). Otherwise falls back to
/// `<current dir>/storage`, or to the relative path `storage` when the current
/// directory cannot be determined, instead of panicking.
pub fn storage_dir() -> PathBuf {
    std::env::var_os("ASTRA_STORAGE")
        // An empty value would resolve to the empty path; treat it as unset.
        .filter(|raw| !raw.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| {
            std::env::current_dir()
                .map(|cwd| cwd.join("storage"))
                .unwrap_or_else(|_| PathBuf::from("storage"))
        })
}
|
||||
|
||||
pub fn ensure_storage_dir() -> Result<()> {
|
||||
let dir = storage_dir();
|
||||
if !dir.exists() {
|
||||
fs::create_dir_all(&dir)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn save_file(filename: &str, contents: &[u8]) -> Result<PathBuf> {
|
||||
ensure_storage_dir()?;
|
||||
let mut path = storage_dir();
|
||||
path.push(filename);
|
||||
let mut f = fs::File::create(&path)?;
|
||||
f.write_all(contents)?;
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
pub fn delete_file(path: &Path) -> Result<()> {
|
||||
if path.exists() {
|
||||
fs::remove_file(path)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
24
rust-engine/src/vector.rs
Normal file
24
rust-engine/src/vector.rs
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
use anyhow::Result;
|
||||
use lazy_static::lazy_static;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
|
||||
lazy_static! {
|
||||
static ref VECTOR_STORE: Mutex<HashMap<String, Vec<f32>>> = Mutex::new(HashMap::new());
|
||||
}
|
||||
|
||||
pub fn store_embedding(id: &str, emb: Vec<f32>) -> Result<()> {
|
||||
let mut s = VECTOR_STORE.lock().unwrap();
|
||||
s.insert(id.to_string(), emb);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn query_top_k(_query_emb: &[f32], k: usize) -> Result<Vec<String>> {
|
||||
// Very naive: return up to k ids from the store.
|
||||
let s = VECTOR_STORE.lock().unwrap();
|
||||
let mut out = Vec::new();
|
||||
for key in s.keys().take(k) {
|
||||
out.push(key.clone());
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
87
rust-engine/src/vector_db.rs
Normal file
87
rust-engine/src/vector_db.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
use anyhow::Result;
|
||||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct QdrantClient {
|
||||
base: String,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl QdrantClient {
|
||||
pub fn new(base: &str) -> Self {
|
||||
Self {
|
||||
base: base.trim_end_matches('/').to_string(),
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Upsert a point into collection `files` with id and vector
|
||||
pub async fn upsert_point(&self, id: &str, vector: Vec<f32>) -> Result<()> {
|
||||
let url = format!("{}/collections/files/points", self.base);
|
||||
let body = json!({
|
||||
"points": [{
|
||||
"id": id,
|
||||
"vector": vector,
|
||||
"payload": {"type": "file"}
|
||||
}]
|
||||
});
|
||||
|
||||
let resp = self.client.post(&url).json(&body).send().await?;
|
||||
let status = resp.status();
|
||||
if status.is_success() {
|
||||
Ok(())
|
||||
} else {
|
||||
let t = resp.text().await.unwrap_or_default();
|
||||
Err(anyhow::anyhow!("qdrant upsert failed: {} - {}", status, t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensure the 'files' collection exists with the given dimension and distance metric
|
||||
pub async fn ensure_files_collection(&self, dim: usize) -> Result<()> {
|
||||
let url = format!("{}/collections/files", self.base);
|
||||
let body = json!({
|
||||
"vectors": {"size": dim, "distance": "Cosine"}
|
||||
});
|
||||
let resp = self.client.put(&url).json(&body).send().await?;
|
||||
// 200 OK or 201 Created means ready; 409 Conflict means already exists
|
||||
if resp.status().is_success() || resp.status().as_u16() == 409 {
|
||||
Ok(())
|
||||
} else {
|
||||
let status = resp.status();
|
||||
let t = resp.text().await.unwrap_or_default();
|
||||
Err(anyhow::anyhow!("qdrant ensure collection failed: {} - {}", status, t))
|
||||
}
|
||||
}
|
||||
|
||||
/// Search top-k nearest points from 'files'
|
||||
pub async fn search_top_k(&self, vector: Vec<f32>, k: usize) -> Result<Vec<String>> {
|
||||
let url = format!("{}/collections/files/points/search", self.base);
|
||||
let body = json!({
|
||||
"vector": vector,
|
||||
"limit": k
|
||||
});
|
||||
let resp = self.client.post(&url).json(&body).send().await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let t = resp.text().await.unwrap_or_default();
|
||||
return Err(anyhow::anyhow!("qdrant search failed: {} - {}", status, t));
|
||||
}
|
||||
#[derive(Deserialize)]
|
||||
struct Hit { id: serde_json::Value }
|
||||
#[derive(Deserialize)]
|
||||
struct Data { result: Vec<Hit> }
|
||||
let data: Data = resp.json().await?;
|
||||
let mut ids = Vec::new();
|
||||
for h in data.result {
|
||||
// id can be string or number; handle string
|
||||
if let Some(s) = h.id.as_str() {
|
||||
ids.push(s.to_string());
|
||||
} else {
|
||||
ids.push(h.id.to_string());
|
||||
}
|
||||
}
|
||||
Ok(ids)
|
||||
}
|
||||
}
|
||||
160
rust-engine/src/worker.rs
Normal file
160
rust-engine/src/worker.rs
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
use crate::gemini_client::{demo_text_embedding, DEMO_EMBED_DIM};
|
||||
use crate::models::{QueryRecord, QueryStatus};
|
||||
use crate::vector_db::QdrantClient;
|
||||
use anyhow::Result;
|
||||
use sqlx::MySqlPool;
|
||||
use std::time::Duration;
|
||||
use tracing::{error, info};
|
||||
|
||||
pub struct Worker {
|
||||
pool: MySqlPool,
|
||||
qdrant: QdrantClient,
|
||||
}
|
||||
|
||||
impl Worker {
|
||||
pub fn new(pool: MySqlPool) -> Self {
|
||||
let qdrant_url = std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://qdrant:6333".to_string());
|
||||
let qdrant = QdrantClient::new(&qdrant_url);
|
||||
Self { pool, qdrant }
|
||||
}
|
||||
|
||||
pub async fn run(&self) {
|
||||
info!("Worker starting");
|
||||
|
||||
// Ensure qdrant collection exists
|
||||
if let Err(e) = self.qdrant.ensure_files_collection(DEMO_EMBED_DIM).await {
|
||||
error!("Failed to ensure Qdrant collection: {}", e);
|
||||
}
|
||||
|
||||
// Requeue stale InProgress jobs older than cutoff (e.g., 10 minutes)
|
||||
if let Err(e) = self.requeue_stale_inprogress(10 * 60).await {
|
||||
error!("Failed to requeue stale jobs: {}", e);
|
||||
}
|
||||
|
||||
loop {
|
||||
// Claim next queued query
|
||||
match self.fetch_and_claim().await {
|
||||
Ok(Some(mut q)) => {
|
||||
info!("Processing query {}", q.id);
|
||||
if let Err(e) = self.process_query(&mut q).await {
|
||||
error!("Error processing {}: {}", q.id, e);
|
||||
let _ = self.mark_failed(&q.id, &format!("{}", e)).await;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Worker fetch error: {}", e);
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_and_claim(&self) -> Result<Option<QueryRecord>> {
|
||||
// Note: MySQL transactional SELECT FOR UPDATE handling is more complex; for this hackathon scaffold
|
||||
// we do a simple two-step: select one queued id, then update it to InProgress if it is still queued.
|
||||
if let Some(row) = sqlx::query("SELECT id, payload FROM queries WHERE status = 'Queued' ORDER BY created_at LIMIT 1")
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
{
|
||||
use sqlx::Row;
|
||||
let id: String = row.get("id");
|
||||
let payload: serde_json::Value = row.get("payload");
|
||||
|
||||
let updated = sqlx::query("UPDATE queries SET status = 'InProgress' WHERE id = ? AND status = 'Queued'")
|
||||
.bind(&id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
if updated.rows_affected() == 1 {
|
||||
let mut q = QueryRecord::new(payload);
|
||||
q.id = id;
|
||||
q.status = QueryStatus::InProgress;
|
||||
return Ok(Some(q));
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn process_query(&self, q: &mut QueryRecord) -> Result<()> {
|
||||
// Stage 1: set InProgress (idempotent)
|
||||
self.update_status(&q.id, QueryStatus::InProgress).await?;
|
||||
|
||||
// Stage 2: embed query text
|
||||
let text = q.payload.get("q").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let emb = demo_text_embedding(text).await?;
|
||||
|
||||
// Check cancellation
|
||||
if self.is_cancelled(&q.id).await? { return Ok(()); }
|
||||
|
||||
// Stage 3: search top-K in Qdrant
|
||||
let top_ids = self.qdrant.search_top_k(emb, 5).await.unwrap_or_default();
|
||||
|
||||
// Check cancellation
|
||||
if self.is_cancelled(&q.id).await? { return Ok(()); }
|
||||
|
||||
// Stage 4: persist results
|
||||
let result = serde_json::json!({
|
||||
"summary": format!("Found {} related files", top_ids.len()),
|
||||
"related_file_ids": top_ids,
|
||||
});
|
||||
sqlx::query("UPDATE queries SET status = 'Completed', result = ? WHERE id = ?")
|
||||
.bind(result)
|
||||
.bind(&q.id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_status(&self, id: &str, status: QueryStatus) -> Result<()> {
|
||||
let s = match status {
|
||||
QueryStatus::Queued => "Queued",
|
||||
QueryStatus::InProgress => "InProgress",
|
||||
QueryStatus::Completed => "Completed",
|
||||
QueryStatus::Cancelled => "Cancelled",
|
||||
QueryStatus::Failed => "Failed",
|
||||
};
|
||||
sqlx::query("UPDATE queries SET status = ? WHERE id = ?")
|
||||
.bind(s)
|
||||
.bind(id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn mark_failed(&self, id: &str, message: &str) -> Result<()> {
|
||||
let result = serde_json::json!({"error": message});
|
||||
sqlx::query("UPDATE queries SET status = 'Failed', result = ? WHERE id = ?")
|
||||
.bind(result)
|
||||
.bind(id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn requeue_stale_inprogress(&self, age_secs: i64) -> Result<()> {
|
||||
// MySQL: requeue items updated_at < now()-age and status = InProgress
|
||||
sqlx::query(
|
||||
"UPDATE queries SET status = 'Queued' WHERE status = 'InProgress' AND updated_at < (NOW() - INTERVAL ? SECOND)"
|
||||
)
|
||||
.bind(age_secs)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn is_cancelled(&self, id: &str) -> Result<bool> {
|
||||
if let Some(row) = sqlx::query("SELECT status FROM queries WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
{
|
||||
use sqlx::Row;
|
||||
let s: String = row.get("status");
|
||||
return Ok(s == "Cancelled");
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue