本帖最后由 CrLf 于 2023-7-29 00:17 编辑
命令行下调用OpenAI接口,从标准输入中读取用户输入并将其发送到GPT模型,再将响应写入标准输出。因原版默认是UTF8,所以我改成默认以GBK编码读取输入,并增加 --utf8 开关兼容utf8编码。
原版GitHub:https://github.com/pdfinn/sgpt
用法:sgpt -k <API_KEY> -i <INSTRUCTION> [-t TEMPERATURE] [-m MODEL] [-s SEPARATOR] [-u] [-d]
参数说明:
| 短参数 | 长参数 | 环境变量 | 描述 | 默认值 |
| --- | --- | --- | --- | --- |
| -k | --api_key | SGPT_API_KEY | 配置OpenAI的API KEY | 无 |
| -i | --instruction | SGPT_INSTRUCTION | 系统指令,用于补充一些背景信息或要求 | 无 |
| -t | --temperature | SGPT_TEMPERATURE | 温度值,范围是0~1,数值越高,给出的答案越有想象力但也更倾向于编造 | 0.5 |
| -m | --model | SGPT_MODEL | 所采用的模型 | gpt-3.5-turbo |
| -s | --separator | SGPT_SEPARATOR | 不同内容的分隔符 | \n |
| -u | --utf8 | SGPT_UTF8 | 以UTF8编码解读输入内容(该参数由CrLf添加,使默认编码是GBK) | false |
| -d | --debug | SGPT_DEBUG | 启用调试模式,将输出很多调试信息 | false |
CrLf修改后的源码: | package main | | | | import ( | | "bufio" | | "encoding/json" | | "fmt" | | "github.com/spf13/pflag" | | "github.com/spf13/viper" | | "io" | | "io/ioutil" | | "log" | | "net/http" | | "os" | | "strconv" | | "strings" | | | | // mod by CrLf 添加必要的模块 | | "bytes" | | "golang.org/x/text/encoding/simplifiedchinese" | | "golang.org/x/text/transform" | | | | ) | | | | | | | | func Utf8ToGbk(s []byte) ([]byte, error) { | | reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewEncoder()) | | d, e := ioutil.ReadAll(reader) | | if e != nil { | | return nil, e | | } | | return d, nil | | } | | | | func GbkToUtf8(s []byte) ([]byte, error) { | | reader := transform.NewReader(bytes.NewReader(s), simplifiedchinese.GBK.NewDecoder()) | | d, e := ioutil.ReadAll(reader) | | if e != nil { | | return nil, e | | } | | return d, nil | | } | | | | | | type OpenAIResponse struct { | | Choices []struct { | | Text string `json:"text,omitempty"` | | Message struct { | | Role string `json:"role,omitempty"` | | Content string `json:"content,omitempty"` | | } `json:"message,omitempty"` | | } `json:"choices"` | | } | | | | | | var utf8 *bool | | var debug *bool | | | | func init() { | | | | | | | | | | | | | | } | | | | func main() { | | | | defaultTemperature := 0.5 | | defaultModel := "gpt-3.5-turbo" | | | | | | envApiKey := os.Getenv("SGPT_API_KEY") | | envInstruction := os.Getenv("SGPT_INSTRUCTION") | | envTemperature, err := strconv.ParseFloat(os.Getenv("SGPT_TEMPERATURE"), 64) | | if err != nil { | | envTemperature = defaultTemperature | | } | | envModel := os.Getenv("SGPT_MODEL") | | envSeparator := os.Getenv("SGPT_SEPARATOR") | | | | | | envUTF8 := parseBoolWithDefault(os.Getenv("SGPT_UTF8"), false) | | envDebug := parseBoolWithDefault(os.Getenv("SGPT_DEBUG"), false) | | | | | | apiKey := pflag.StringP("api_key", "k", envApiKey, "OpenAI API key") | | instruction := pflag.StringP("instruction", "i", envInstruction, "Instruction for the GPT model") | | 
temperature := pflag.Float64P("temperature", "t", envTemperature, "Temperature for the GPT model") | | model := pflag.StringP("model", "m", envModel, "GPT model to use") | | defaulSeparator := "\n" | | separator := pflag.StringP("separator", "s", envSeparator, "Separator character for input") | | if *separator == "" { | | *separator = defaulSeparator | | } | | | | | | utf8 = pflag.BoolP("utf8", "u", envUTF8, "Enable UTF8 input") | | debug = pflag.BoolP("debug", "d", envDebug, "Enable debug output") | | pflag.Parse() | | | | | | viper.SetConfigName("sgpt") | | viper.AddConfigPath(".") | | viper.AddConfigPath("$HOME/.sgpt") | | viper.SetConfigType("yaml") | | | | err = viper.ReadInConfig() | | | | | | if _, ok := err.(viper.ConfigFileNotFoundError); ok { | | debugOutput(*debug, "Warning: Config file not found: %v", err) | | } else if err != nil { | | debugOutput(*debug, "Warning: Error reading config file: %v", err) | | } | | | | | | viper.SetDefault("model", defaultModel) | | viper.SetDefault("temperature", defaultTemperature) | | viper.BindPFlag("api_key", pflag.Lookup("k")) | | viper.BindPFlag("instruction", pflag.Lookup("i")) | | viper.BindPFlag("model", pflag.Lookup("m")) | | viper.BindPFlag("temperature", pflag.Lookup("t")) | | viper.BindPFlag("separator", pflag.Lookup("s")) | | viper.BindPFlag("debug", pflag.Lookup("d")) | | | | | | if *model == "" { | | *model = defaultModel | | } | | | | if *apiKey == "" { | | log.Fatal("API key is required") | | } | | | | | | | | | | var reader io.RuneReader | | if *utf8 { | | reader = bufio.NewReader(os.Stdin) | | } else { | | byteInput, _ := io.ReadAll(os.Stdin) | | gbkBytes, _ := GbkToUtf8(byteInput) | | reader = bytes.NewReader(gbkBytes) | | } | | | | var inputBuffer strings.Builder | | | | for { | | inputChar, _, err := reader.ReadRune() | | if err == io.EOF { | | input := inputBuffer.String() | | if input != "" { | | response, err := callOpenAI(*apiKey, *instruction, input, *temperature, *model) | | if err != nil { | 
| log.Fatal(err) | | } | | fmt.Println(response) | | } | | break | | } | | if err != nil { | | log.Fatal(err) | | } | | | | if string(inputChar) == *separator { | | input := inputBuffer.String() | | inputBuffer.Reset() | | | | response, err := callOpenAI(*apiKey, *instruction, input, *temperature, *model) | | if err != nil { | | log.Fatal(err) | | } | | | | fmt.Println(response) | | } else { | | inputBuffer.WriteRune(inputChar) | | } | | } | | } | | | | func debugOutput(debug bool, format string, a ...interface{}) { | | if debug { | | log.Printf(format, a...) | | } | | } | | | | func parseFloatWithDefault(value string, defaultValue float64) float64 { | | if value == "" { | | return defaultValue | | } | | parsedValue, err := strconv.ParseFloat(value, 64) | | if err != nil { | | log.Printf("Warning: Failed to parse float value: %v", err) | | return defaultValue | | } | | return parsedValue | | } | | | | func parseBoolWithDefault(value string, defaultValue bool) bool { | | if value == "" { | | return defaultValue | | } | | parsedValue, err := strconv.ParseBool(value) | | if err != nil { | | log.Printf("Warning: Failed to parse bool value: %v", err) | | return defaultValue | | } | | return parsedValue | | } | | | | func callOpenAI(apiKey, instruction, input string, temperature float64, model string) (string, error) { | | var url string | | var jsonData []byte | | var err error | | | | switch model { | | case "gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314", "gpt-3.5-turbo": | | url = "https://api.openai.com/v1/chat/completions" | | | | | | messages := []map[string]string{ | | {"role": "system", "content": instruction}, | | {"role": "user", "content": input}, | | } | | | | jsonData, err = json.Marshal(map[string]interface{}{ | | "model": model, | | "messages": messages, | | "temperature": temperature, | | "max_tokens": 100, | | "stop": []string{"\n"}, | | }) | | | | case "text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001": | 
| url = "https://api.openai.com/v1/completions" | | | | | | prompt := instruction + " " + input | | jsonData, err = json.Marshal(map[string]interface{}{ | | "model": model, | | "prompt": prompt, | | "temperature": temperature, | | "max_tokens": 100, | | "stop": []string{"\n"}, | | }) | | | | case "whisper-1": | | url = "https://api.openai.com/v1/audio/transcriptions" | | default: | | return "", fmt.Errorf("unsupported model: %s", model) | | } | | | | if err != nil { | | return "", err | | } | | | | data := strings.NewReader(string(jsonData)) | | | | req, err := http.NewRequest("POST", url, data) | | if err != nil { | | return "", err | | } | | | | req.Header.Set("Content-Type", "application/json") | | req.Header.Set("Authorization", "Bearer "+apiKey) | | | | client := &http.Client{} | | resp, err := client.Do(req) | | if err != nil { | | return "", err | | } | | defer resp.Body.Close() | | | | body, err := ioutil.ReadAll(resp.Body) | | if err != nil { | | return "", err | | } | | | | debugOutput(*debug, "API response: %s\n", string(body)) | | | | var openAIResponse OpenAIResponse | | err = json.Unmarshal(body, &openAIResponse) | | if err != nil { | | return "", err | | } | | | | if len(openAIResponse.Choices) == 0 { | | debugOutput(*debug, "API response: %s\n", string(body)) | | debugOutput(*debug, "HTTP status code: %s\n", strconv.Itoa(resp.StatusCode)) | | return "", fmt.Errorf("no choices returned from the API") | | } | | | | assistantMessage := "" | | for _, choice := range openAIResponse.Choices { | | if choice.Message.Role == "assistant" { | | assistantMessage = strings.TrimSpace(choice.Message.Content) | | break | | } | | if choice.Text != "" { | | assistantMessage = strings.TrimSpace(choice.Text) | | break | | } | | } | | | | if assistantMessage == "" { | | return "", fmt.Errorf("no assistant message found in the API response") | | } | | | | return assistantMessage, nil | | }COPY |
编译后的下载地址:http://bcn.bathome.net/s/tool/index.html?key=sgpt