feat: with base oe

Signed-off-by: kjuulh <contact@kjuulh.io>
Kasper Juul Hermansen 2024-01-13 13:30:18 +01:00
parent ff55446e8c
commit 51076be307
Signed by: kjuulh
GPG Key ID: 57B6E1465221F912
3 changed files with 906 additions and 206 deletions

Cargo.lock (generated, 860 lines changed)
File diff suppressed because it is too large.

@@ -10,3 +10,6 @@ tracing.workspace = true
tracing-subscriber.workspace = true
clap.workspace = true
dotenv.workspace = true
reqwest = { version = "0.11.23", features = ["json"] }
serde = { version = "1.0.195", features = ["derive"] }
serde_json = "1.0.111"
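# reqwest's "json" feature gates the RequestBuilder::json helper used in main.rs,
# and serde's "derive" feature provides the #[derive(Serialize, Deserialize)] macros.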

@@ -1,6 +1,12 @@
use std::net::SocketAddr;
use std::{
io::{stdin, stdout, Write},
process::{Child, ChildStdout, Stdio},
};
use anyhow::{anyhow, Context};
use clap::{Parser, Subcommand};
use reqwest::header::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
#[derive(Parser)]
#[command(author, version, about, long_about = None, subcommand_required = true)]
@@ -11,9 +17,153 @@ struct Command {
#[derive(Subcommand)]
enum Commands {
Init,
}
#[derive(Clone, Deserialize, Serialize)]
struct OpenAIMessage {
pub role: String,
pub content: String,
}
#[derive(Clone, Deserialize, Serialize)]
struct OpenAIRequest {
pub model: String,
pub messages: Vec<OpenAIMessage>,
}
#[derive(Clone, Deserialize, Serialize)]
struct OpenAIPrompt {
pub prompt: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
struct OpenAICommandResponse {
command: String,
args: Option<Vec<String>>,
pipe: Option<Box<OpenAICommandResponse>>,
}
impl OpenAICommandResponse {
fn print(&self) -> String {
format!(
r#"{} {}{}"#,
self.command,
match &self.args {
Some(args) => args.join(" "),
None => "".into(),
},
match &self.pipe {
Some(pipe) => format!(" \\\n | {}", pipe.print()),
None => "".into(),
}
)
}
async fn execute_command(&self) -> anyhow::Result<()> {
    let mut cmd = std::process::Command::new(&self.command);
    // Only capture stdout when there is a downstream command to feed.
    let cmd = if let Some(_pipe) = &self.pipe {
        cmd.stdout(std::process::Stdio::piped())
    } else {
        &mut cmd
    };
    let cmd = if let Some(args) = &self.args {
        cmd.args(args)
    } else {
        cmd
    };
    let mut child = cmd.spawn()?;
    // When piping, hand this child's stdout to the next command in the chain.
    let grandchild = if let Some(pipe) = &self.pipe {
        if let Some(output) = child.stdout.take() {
            Some(pipe.execute_command_with_stdin(output).await?)
        } else {
            None
        }
    } else {
        None
    };
    // If stdout was taken for the pipe, this output is empty; the piped
    // result is printed from the grandchild below instead.
    let output = child.wait_with_output()?;
    let output = String::from_utf8_lossy(&output.stdout);
    println!("{}", output);
    if let Some(grandchild) = grandchild {
        let output = grandchild.wait_with_output()?;
        let output = String::from_utf8_lossy(&output.stdout);
        println!("{}", output);
    }
    Ok(())
}
async fn execute_command_with_stdin(&self, input: ChildStdout) -> anyhow::Result<Child> {
    let mut cmd = std::process::Command::new(&self.command);
    // Feed the previous command's stdout into this command's stdin.
    let cmd = cmd.stdin(Stdio::from(input));
    let cmd = if let Some(args) = &self.args {
        cmd.args(args)
    } else {
        cmd
    };
    // Always capture stdout so the caller can collect the pipeline's output.
    // (The earlier conditional pipe here was dead code, since stdout was
    // unconditionally piped right after; nested pipes beyond this stage are
    // not spawned yet, so chains deeper than two commands stop here.)
    let cmd = cmd.stdout(std::process::Stdio::piped());
    let child = cmd.spawn()?;
    Ok(child)
}
}
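// Sketch (module and test names are ours, not part of the commit): how the
// recursive `pipe` field composes a two-stage pipeline, and what print()
// renders for it.
#[cfg(test)]
mod pipeline_shape {
    use super::*;

    #[test]
    fn prints_two_stage_pipeline() {
        let pipeline = OpenAICommandResponse {
            command: "ls".into(),
            args: Some(vec!["-la".into()]),
            pipe: Some(Box::new(OpenAICommandResponse {
                command: "grep".into(),
                args: Some(vec!["example".into()]),
                pipe: None,
            })),
        };
        // print() joins args with spaces and renders the pipe on its own line.
        assert_eq!(pipeline.print(), "ls -la \\\n | grep example");
    }
}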
#[derive(Serialize, Deserialize, Debug)]
struct Response {
id: String,
object: String,
created: i64,
model: String,
choices: Vec<Choice>,
usage: Usage,
system_fingerprint: Option<serde_json::Value>,
}
#[derive(Serialize, Deserialize, Debug)]
struct Choice {
index: i32,
message: Message,
logprobs: Option<serde_json::Value>,
finish_reason: String,
}
#[derive(Serialize, Deserialize, Debug)]
struct Message {
role: String,
content: String,
}
#[derive(Serialize, Deserialize, Debug)]
struct Usage {
prompt_tokens: i32,
completion_tokens: i32,
total_tokens: i32,
}
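// Illustrative sketch (ours, with made-up field values): a trimmed chat
// completion payload of the shape that the Response/Choice/Message/Usage
// structs above are declared to accept.
#[cfg(test)]
mod response_shape {
    use super::*;

    #[test]
    fn minimal_completion_parses() {
        let raw = r#"{
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1705147818,
            "model": "gpt-3.5-turbo",
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": "{\"command\": \"ls\"}"},
                "logprobs": null,
                "finish_reason": "stop"
            }],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
            "system_fingerprint": null
        }"#;
        let resp: Response = serde_json::from_str(raw).expect("shape should match");
        assert_eq!(resp.choices.len(), 1);
        assert_eq!(resp.usage.total_tokens, 15);
    }
}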
const OPENAI_PROMPT: &'static str = r#"
You are the interface for a Linux CLI system. You have all standard Linux commands available, and you will always return JSON.
You will always receive a message like this: {"prompt": "this is an example prompt for listing a directory and finding a file with the name example"}
You will always reply like this: {"command": "ls", "args": ["-la"], "pipe": {"command": "grep", "args": ["example"]}}
However, be creative, as the pipe chain can continue further down.
"#;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
dotenv::dotenv().ok();
@@ -24,6 +174,101 @@ async fn main() -> anyhow::Result<()> {
match cli.command.unwrap() {
Commands::Init => {
tracing::info!("hello oe");
stdout().flush()?;
loop {
print!("> ");
stdout().flush()?;
let mut input = String::new();
stdin().read_line(&mut input)?;
let mut headers = HeaderMap::new();
let input_prompt = serde_json::to_string(&OpenAIPrompt {
prompt: input.clone(),
})?;
headers.insert(
"Authorization",
HeaderValue::from_str(
format!(
"Bearer {}",
std::env::var("OPENAI_API_KEY")
.context("OPENAI_API_KEY was not set")?
)
.as_str(),
)?,
);
let client = reqwest::ClientBuilder::new()
.default_headers(headers)
.user_agent(concat!(
env!("CARGO_PKG_NAME"),
"/",
env!("CARGO_PKG_VERSION"),
))
.build()?;
let request = client
.post("https://api.openai.com/v1/chat/completions")
.json(&OpenAIRequest {
model: "gpt-3.5-turbo".into(),
messages: vec![
OpenAIMessage {
role: "system".into(),
content: OPENAI_PROMPT.into(),
},
OpenAIMessage {
role: "user".into(),
content: input_prompt,
},
],
})
.build()?;
let resp = client.execute(request).await?;
if !resp.status().is_success() {
let body = resp.text().await?;
anyhow::bail!("failed to execute query: {}", body);
}
let resp = resp.text().await?;
//println!("should be json: {}", resp);
let openai_resp: Response = serde_json::from_str(&resp)?;
let output: OpenAICommandResponse = serde_json::from_str(
&openai_resp
.choices
.first()
.ok_or(anyhow!("failed to find a choice"))?
.message
.content,
)?;
println!("preview the command to be executed:\n{}", output.print());
print!("to accept please (N/y): ");
stdout().flush()?;
let mut choice = String::new();
stdin().read_line(&mut choice)?;
match choice.trim_end() {
    "y" => {
        println!("executing command");
        output.execute_command().await?;
    }
    other => {
        println!("'{}' was not valid input", other)
    }
}
stdout().flush()?;
}
}
}
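// Illustrative session (hypothetical transcript; assumes OPENAI_API_KEY is
// exported and the model honors the JSON contract in OPENAI_PROMPT):
//
//   > list rust files in this directory
//   preview the command to be executed:
//   ls -la \
//    | grep .rs
//   accept and execute? (y/N): y
//   executing command
//   -rw-r--r-- 1 user user 7421 Jan 13 13:30 main.rs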