diff --git a/go/README.md b/go/README.md index 58207101..14a8128f 100644 --- a/go/README.md +++ b/go/README.md @@ -69,6 +69,18 @@ func main() { } ``` +## Distributing your application with an embedded GitHub Copilot CLI + +The SDK supports bundling, using Go's `embed` package, the Copilot CLI binary within your application's distribution. +This allows you to bundle a specific CLI version and avoid external dependencies on the user's system. + +Follow these steps to embed the CLI: + +1. Run `go get -tool github.com/github/copilot-sdk/go/cmd/bundler`. This is a one-time setup step per project. +2. Run `go tool bundler` in your build environment just before building your application. + +That's it! When your application calls `copilot.NewClient` without a `CLIPath` nor the `COPILOT_CLI_PATH` environment variable, the SDK will automatically install the embedded CLI to a cache directory and use it for all operations. + ## API Reference ### Client diff --git a/go/client.go b/go/client.go index 94c51f55..56c86562 100644 --- a/go/client.go +++ b/go/client.go @@ -42,6 +42,7 @@ import ( "sync" "time" + "github.com/github/copilot-sdk/go/internal/embeddedcli" "github.com/github/copilot-sdk/go/internal/jsonrpc2" ) @@ -102,7 +103,7 @@ type Client struct { // }) func NewClient(options *ClientOptions) *Client { opts := ClientOptions{ - CLIPath: "copilot", + CLIPath: "", Cwd: "", Port: 0, LogLevel: "info", @@ -1313,6 +1314,15 @@ func (c *Client) verifyProtocolVersion(ctx context.Context) error { // This spawns the CLI server as a subprocess using the configured transport // mode (stdio or TCP). func (c *Client) startCLIServer(ctx context.Context) error { + cliPath := c.options.CLIPath + if cliPath == "" { + // If no CLI path is provided, attempt to use the embedded CLI if available + cliPath = embeddedcli.Path() + } + if cliPath == "" { + // Default to "copilot" in PATH if no embedded CLI is available and no custom path is set + cliPath = "copilot" + } args := []string{"--headless", "--no-auto-update", "--log-level", c.options.LogLevel} // Choose transport mode @@ -1339,10 +1349,10 @@ func (c *Client) startCLIServer(ctx context.Context) error { // If CLIPath is a .js file, run it with node // Note we can't rely on the shebang as Windows doesn't support it - command := c.options.CLIPath - if strings.HasSuffix(c.options.CLIPath, ".js") { + command := cliPath + if strings.HasSuffix(cliPath, ".js") { command = "node" - args = append([]string{c.options.CLIPath}, args...) + args = append([]string{cliPath}, args...) } c.process = exec.CommandContext(ctx, command, args...) diff --git a/go/cmd/bundler/main.go b/go/cmd/bundler/main.go new file mode 100644 index 00000000..4c20d8e8 --- /dev/null +++ b/go/cmd/bundler/main.go @@ -0,0 +1,613 @@ +// Bundler downloads Copilot CLI binaries and packages them as a binary file, +// along with a Go source file that embeds the binary and metadata. +// +// Usage: +// +// go run github.com/github/copilot-sdk/go/cmd/bundler [--platform GOOS/GOARCH] [--output DIR] [--cli-version VERSION] +// +// --platform: Target platform using Go conventions (linux/amd64, linux/arm64, darwin/amd64, darwin/arm64, windows/amd64, windows/arm64). Defaults to current platform. +// --output: Output directory for embedded artifacts. Defaults to the current directory. +// --cli-version: CLI version to download. If not specified, automatically detects from the copilot-sdk version in go.mod. +package main + +import ( + "archive/tar" + "compress/gzip" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + + "github.com/klauspost/compress/zstd" +) + +const ( + // Keep these URLs centralized so reviewers can verify all outbound calls in one place. + sdkModule = "github.com/github/copilot-sdk/go" + packageLockURLFmt = "https://raw.githubusercontent.com/github/copilot-sdk/%s/nodejs/package-lock.json" + tarballURLFmt = "https://registry.npmjs.org/@github/copilot-%s/-/copilot-%s-%s.tgz" + licenseTarballFmt = "https://registry.npmjs.org/@github/copilot/-/copilot-%s.tgz" +) + +// Platform info: npm package suffix, binary name +type platformInfo struct { + npmPlatform string + binaryName string +} + +// Map from GOOS/GOARCH to npm platform info +var platforms = map[string]platformInfo{ + "linux/amd64": {npmPlatform: "linux-x64", binaryName: "copilot"}, + "linux/arm64": {npmPlatform: "linux-arm64", binaryName: "copilot"}, + "darwin/amd64": {npmPlatform: "darwin-x64", binaryName: "copilot"}, + "darwin/arm64": {npmPlatform: "darwin-arm64", binaryName: "copilot"}, + "windows/amd64": {npmPlatform: "win32-x64", binaryName: "copilot.exe"}, + "windows/arm64": {npmPlatform: "win32-arm64", binaryName: "copilot.exe"}, +} + +// main is the CLI entry point. +func main() { + platform := flag.String("platform", runtime.GOOS+"/"+runtime.GOARCH, "Target platform as GOOS/GOARCH (e.g. linux/amd64, darwin/arm64), defaults to current platform") + output := flag.String("output", "", "Output directory for embedded artifacts. Defaults to the current directory") + cliVersion := flag.String("cli-version", "", "CLI version to download (auto-detected from go.mod if not specified)") + flag.Parse() + + // Resolve version first so the default output name can include it. + version := resolveCLIVersion(*cliVersion) + // Resolve platform once to validate input and get the npm package mapping. + goos, goarch, info, err := resolvePlatform(*platform) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + fmt.Fprintf(os.Stderr, "Valid platforms: %s\n", strings.Join(validPlatforms(), ", ")) + os.Exit(1) + } + + outputPath := filepath.Join(*output, defaultOutputFileName(version, goos, goarch, info.binaryName)) + + fmt.Printf("Building bundle for %s (CLI version %s)\n", *platform, version) + + binaryPath, sha256Hash, err := buildBundle(info, version, outputPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + // Generate the Go file with embed directive + if err := generateGoFile(goos, goarch, binaryPath, version, sha256Hash, "main"); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + if err := ensureZstdDependency(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} + +// resolvePlatform validates the platform flag and returns GOOS/GOARCH and mapping info. +func resolvePlatform(platform string) (string, string, platformInfo, error) { + goos, goarch, ok := strings.Cut(platform, "/") + if !ok || goos == "" || goarch == "" { + return "", "", platformInfo{}, fmt.Errorf("invalid platform %q", platform) + } + info, ok := platforms[platform] + if !ok { + return "", "", platformInfo{}, fmt.Errorf("invalid platform %q", platform) + } + return goos, goarch, info, nil +} + +// resolveCLIVersion determines the CLI version from the flag or repo metadata. +func resolveCLIVersion(flagValue string) string { + if flagValue != "" { + return flagValue + } + version, err := detectCLIVersion() + if err != nil { + fmt.Fprintf(os.Stderr, "Error detecting CLI version: %v\n", err) + fmt.Fprintln(os.Stderr, "Hint: specify --cli-version explicitly, or run from a Go module that depends on github.com/github/copilot-sdk/go") + os.Exit(1) + } + fmt.Printf("Auto-detected CLI version: %s\n", version) + return version +} + +// defaultOutputFileName builds the default bundle filename for a platform. +func defaultOutputFileName(version, goos, goarch, binaryName string) string { + base := strings.TrimSuffix(binaryName, filepath.Ext(binaryName)) + ext := filepath.Ext(binaryName) + return fmt.Sprintf("z%s_%s_%s_%s%s.zst", base, version, goos, goarch, ext) +} + +// validPlatforms returns valid platform keys for error messages. +func validPlatforms() []string { + result := make([]string, 0, len(platforms)) + for p := range platforms { + result = append(result, p) + } + return result +} + +// detectCLIVersion detects the CLI version by: +// 1. Running "go list -m" to get the copilot-sdk version from the user's go.mod +// 2. Fetching the package-lock.json from the SDK repo at that version +// 3. Extracting the @github/copilot CLI version from it +func detectCLIVersion() (string, error) { + // Get the SDK version from the user's go.mod + sdkVersion, err := getSDKVersion() + if err != nil { + return "", fmt.Errorf("failed to get SDK version: %w", err) + } + + fmt.Printf("Found copilot-sdk %s in go.mod\n", sdkVersion) + + // Fetch package-lock.json from the SDK repo at that version + cliVersion, err := fetchCLIVersionFromRepo(sdkVersion) + if err != nil { + return "", fmt.Errorf("failed to fetch CLI version: %w", err) + } + + return cliVersion, nil +} + +// getSDKVersion runs "go list -m" to get the copilot-sdk version from go.mod +func getSDKVersion() (string, error) { + cmd := exec.Command("go", "list", "-m", "-f", "{{.Version}}", sdkModule) + output, err := cmd.Output() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + return "", fmt.Errorf("go list failed: %s", string(exitErr.Stderr)) + } + return "", err + } + + version := strings.TrimSpace(string(output)) + if version == "" { + return "", fmt.Errorf("module %s not found in go.mod", sdkModule) + } + + return version, nil +} + +// fetchCLIVersionFromRepo fetches package-lock.json from GitHub and extracts the CLI version. +func fetchCLIVersionFromRepo(sdkVersion string) (string, error) { + // Convert Go module version to Git ref + // v0.1.0 -> v0.1.0 + // v0.1.0-beta.1 -> v0.1.0-beta.1 + // v0.0.0-20240101120000-abcdef123456 -> abcdef123456 (pseudo-version) + gitRef := sdkVersion + + // Pseudo-versions end with a 12-character commit hash. + // Format: vX.Y.Z-yyyymmddhhmmss-abcdefabcdef + if idx := strings.LastIndex(sdkVersion, "-"); idx != -1 { + suffix := sdkVersion[idx+1:] + // Use the commit hash when present so we fetch the exact source snapshot. + if len(suffix) == 12 && isHex(suffix) { + gitRef = suffix + } + } + + url := fmt.Sprintf(packageLockURLFmt, gitRef) + fmt.Printf("Fetching %s...\n", url) + + resp, err := http.Get(url) + if err != nil { + return "", fmt.Errorf("failed to fetch: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("failed to fetch package-lock.json: %s", resp.Status) + } + + var packageLock struct { + Packages map[string]struct { + Version string `json:"version"` + } `json:"packages"` + } + + if err := json.NewDecoder(resp.Body).Decode(&packageLock); err != nil { + return "", fmt.Errorf("failed to parse package-lock.json: %w", err) + } + + pkg, ok := packageLock.Packages["node_modules/@github/copilot"] + if !ok || pkg.Version == "" { + return "", fmt.Errorf("could not find @github/copilot version in package-lock.json") + } + + return pkg.Version, nil +} + +// isHex returns true if s contains only hexadecimal characters. +func isHex(s string) bool { + for _, c := range s { + if (c < '0' || c > '9') && (c < 'a' || c > 'f') && (c < 'A' || c > 'F') { + return false + } + } + return true +} + +// buildBundle downloads the CLI binary and writes it to outputPath. +func buildBundle(info platformInfo, cliVersion, outputPath string) (string, []byte, error) { + outputDir := filepath.Dir(outputPath) + if outputDir == "" { + outputDir = "." + } + + // Check if output already exists + if _, err := os.Stat(outputPath); err == nil { + // Idempotent output avoids re-downloading in CI or local rebuilds. + fmt.Printf("Output %s already exists, skipping download\n", outputPath) + sha256Hash, err := sha256FileFromCompressed(outputPath) + if err != nil { + return "", nil, fmt.Errorf("failed to hash existing output: %w", err) + } + if err := downloadCLILicense(cliVersion, outputPath); err != nil { + return "", nil, fmt.Errorf("failed to download CLI license: %w", err) + } + return outputPath, sha256Hash, nil + } + // Create temp directory for download + tempDir, err := os.MkdirTemp("", "copilot-bundler-*") + if err != nil { + return "", nil, fmt.Errorf("failed to create temp dir: %w", err) + } + defer os.RemoveAll(tempDir) + + // Download the binary + binaryPath, err := downloadCLIBinary(info.npmPlatform, info.binaryName, cliVersion, tempDir) + if err != nil { + return "", nil, fmt.Errorf("failed to download CLI binary: %w", err) + } + + // Create output directory if needed + if outputDir != "." { + if err := os.MkdirAll(outputDir, 0755); err != nil { + return "", nil, fmt.Errorf("failed to create output directory: %w", err) + } + } + + sha256Hash, err := sha256File(binaryPath) + if err != nil { + return "", nil, fmt.Errorf("failed to hash output binary: %w", err) + } + if err := compressZstdFile(binaryPath, outputPath); err != nil { + return "", nil, fmt.Errorf("failed to write output binary: %w", err) + } + if err := downloadCLILicense(cliVersion, outputPath); err != nil { + return "", nil, fmt.Errorf("failed to download CLI license: %w", err) + } + fmt.Printf("Successfully created %s\n", outputPath) + return outputPath, sha256Hash, nil +} + +// generateGoFile creates a Go source file that embeds the binary and metadata. +func generateGoFile(goos, goarch, binaryPath, cliVersion string, sha256Hash []byte, pkgName string) error { + // Generate Go file path: zcopilot_linux_amd64.go (without version) + binaryName := filepath.Base(binaryPath) + licenseName := licenseFileName(binaryName) + goFileName := fmt.Sprintf("zcopilot_%s_%s.go", goos, goarch) + goFilePath := filepath.Join(filepath.Dir(binaryPath), goFileName) + hashBase64 := "" + if len(sha256Hash) > 0 { + hashBase64 = base64.StdEncoding.EncodeToString(sha256Hash) + } + + content := fmt.Sprintf(`// Code generated by copilot-sdk bundler; DO NOT EDIT. + +package %s + +import ( + "bytes" + "io" + "encoding/base64" + _ "embed" + + "github.com/github/copilot-sdk/go/embeddedcli" + "github.com/klauspost/compress/zstd" +) + +//go:embed %s +var localEmbeddedCopilotCLI []byte + +//go:embed %s +var localEmbeddedCopilotCLILicense []byte + + +func init() { + embeddedcli.Setup(embeddedcli.Config{ + Cli: cliReader(), + License: localEmbeddedCopilotCLILicense, + Version: %q, + CliHash: mustDecodeBase64(%q), + }) +} + +func cliReader() io.Reader { + r, err := zstd.NewReader(bytes.NewReader(localEmbeddedCopilotCLI)) + if err != nil { + panic("failed to create zstd reader: " + err.Error()) + } + return r +} + +func mustDecodeBase64(s string) []byte { + b, err := base64.StdEncoding.DecodeString(s) + if err != nil { + panic("failed to decode base64: " + err.Error()) + } + return b +} +`, pkgName, binaryName, licenseName, cliVersion, hashBase64) + + if err := os.WriteFile(goFilePath, []byte(content), 0644); err != nil { + return err + } + + fmt.Printf("Generated %s\n", goFilePath) + return nil +} + +// downloadCLIBinary downloads the npm tarball and extracts the CLI binary. +func downloadCLIBinary(npmPlatform, binaryName, cliVersion, destDir string) (string, error) { + tarballURL := fmt.Sprintf(tarballURLFmt, npmPlatform, npmPlatform, cliVersion) + + fmt.Printf("Downloading from %s...\n", tarballURL) + + resp, err := http.Get(tarballURL) + if err != nil { + return "", fmt.Errorf("failed to download: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("failed to download: %s", resp.Status) + } + + // Save tarball to temp file + tarballPath := filepath.Join(destDir, fmt.Sprintf("copilot-%s-%s.tgz", npmPlatform, cliVersion)) + tarballFile, err := os.Create(tarballPath) + if err != nil { + return "", fmt.Errorf("failed to create tarball file: %w", err) + } + + if _, err := io.Copy(tarballFile, resp.Body); err != nil { + tarballFile.Close() + return "", fmt.Errorf("failed to save tarball: %w", err) + } + if err := tarballFile.Close(); err != nil { + return "", fmt.Errorf("failed to close tarball file: %w", err) + } + + // Extract only the CLI binary to avoid unpacking the full package tree. + binaryPath := filepath.Join(destDir, binaryName) + if err := extractFileFromTarball(tarballPath, destDir, "package/"+binaryName, binaryName); err != nil { + return "", fmt.Errorf("failed to extract binary: %w", err) + } + + // Verify binary exists + if _, err := os.Stat(binaryPath); err != nil { + return "", fmt.Errorf("binary not found after extraction: %w", err) + } + + // Make executable on Unix + if !strings.HasSuffix(binaryName, ".exe") { + if err := os.Chmod(binaryPath, 0755); err != nil { + return "", fmt.Errorf("failed to chmod binary: %w", err) + } + } + + stat, err := os.Stat(binaryPath) + if err != nil { + return "", fmt.Errorf("failed to stat binary: %w", err) + } + sizeMB := float64(stat.Size()) / 1024 / 1024 + fmt.Printf("Downloaded %s (%.1f MB)\n", binaryName, sizeMB) + + return binaryPath, nil +} + +// downloadCLILicense downloads the @github/copilot package and writes its license next to outputPath. +func downloadCLILicense(cliVersion, outputPath string) error { + outputDir := filepath.Dir(outputPath) + if outputDir == "" { + outputDir = "." + } + licensePath := licensePathForOutput(outputPath) + if _, err := os.Stat(licensePath); err == nil { + return nil + } + + licenseURL := fmt.Sprintf(licenseTarballFmt, cliVersion) + resp, err := http.Get(licenseURL) + if err != nil { + return fmt.Errorf("failed to download license tarball: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to download license tarball: %s", resp.Status) + } + + gzReader, err := gzip.NewReader(resp.Body) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gzReader.Close() + + tarReader := tar.NewReader(gzReader) + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar: %w", err) + } + switch header.Name { + case "package/LICENSE.md", "package/LICENSE": + licenseName := filepath.Base(licensePath) + if err := extractFileFromTarballStream(tarReader, outputDir, licenseName, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("failed to write license: %w", err) + } + return nil + } + } + + return fmt.Errorf("license file not found in tarball") +} + +func licensePathForOutput(outputPath string) string { + if strings.HasSuffix(outputPath, ".zst") { + return strings.TrimSuffix(outputPath, ".zst") + ".license" + } + return outputPath + ".license" +} + +func licenseFileName(binaryName string) string { + if strings.HasSuffix(binaryName, ".zst") { + return strings.TrimSuffix(binaryName, ".zst") + ".license" + } + return binaryName + ".license" +} + +// extractFileFromTarballStream writes the current tar entry to disk. +func extractFileFromTarballStream(r io.Reader, destDir, outputName string, mode os.FileMode) error { + outPath := filepath.Join(destDir, outputName) + outFile, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + if _, err := io.Copy(outFile, r); err != nil { + outFile.Close() + return fmt.Errorf("failed to extract license: %w", err) + } + return outFile.Close() +} + +// extractFileFromTarball extracts a single file from a .tgz into destDir with a new name. +func extractFileFromTarball(tarballPath, destDir, targetPath, outputName string) error { + file, err := os.Open(tarballPath) + if err != nil { + return err + } + defer file.Close() + + gzReader, err := gzip.NewReader(file) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gzReader.Close() + + tarReader := tar.NewReader(gzReader) + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar: %w", err) + } + + if header.Name == targetPath { + outPath := filepath.Join(destDir, outputName) + outFile, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode)) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + + if _, err := io.Copy(outFile, tarReader); err != nil { + outFile.Close() + return fmt.Errorf("failed to extract binary: %w", err) + } + if err := outFile.Close(); err != nil { + return fmt.Errorf("failed to close output file: %w", err) + } + return nil + } + } + + return fmt.Errorf("file %q not found in tarball", targetPath) +} + +// compressZstdFile compresses src into dst using zstd. +func compressZstdFile(src, dst string) error { + srcFile, err := os.Open(src) + if err != nil { + return err + } + defer srcFile.Close() + + dstFile, err := os.Create(dst) + if err != nil { + return err + } + defer dstFile.Close() + + writer, err := zstd.NewWriter(dstFile) + if err != nil { + return err + } + defer writer.Close() + + if _, err := io.Copy(writer, srcFile); err != nil { + return err + } + return writer.Close() +} + +// sha256HexFileFromCompressed returns SHA-256 of the decompressed zstd stream. +func sha256FileFromCompressed(path string) ([]byte, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + reader, err := zstd.NewReader(file) + if err != nil { + return nil, err + } + defer reader.Close() + + h := sha256.New() + if _, err := io.Copy(h, reader); err != nil { + return nil, err + } + return h.Sum(nil), nil +} + +// sha256File returns the SHA-256 hash of a file as raw bytes. +func sha256File(path string) ([]byte, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + h := sha256.New() + if _, err := io.Copy(h, file); err != nil { + return nil, err + } + return h.Sum(nil), nil +} + +// ensureZstdDependency makes sure the module has the zstd dependency for generated code. +func ensureZstdDependency() error { + cmd := exec.Command("go", "mod", "tidy") + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to add zstd dependency: %w\n%s", err, strings.TrimSpace(string(output))) + } + return nil +} diff --git a/go/embeddedcli/installer.go b/go/embeddedcli/installer.go new file mode 100644 index 00000000..deb4c2ee --- /dev/null +++ b/go/embeddedcli/installer.go @@ -0,0 +1,17 @@ +package embeddedcli + +import "github.com/github/copilot-sdk/go/internal/embeddedcli" + +// Config defines the inputs used to install and locate the embedded Copilot CLI. +// +// Cli and CliHash are required. If Dir is empty, the CLI is installed into the +// system cache directory. Version is used to suffix the installed binary name to +// allow multiple versions to coexist. License, when provided, is written next +// to the installed binary. +type Config = embeddedcli.Config + +// Setup sets the embedded GitHub Copilot CLI install configuration. +// The CLI will be lazily installed when needed. +func Setup(cfg Config) { + embeddedcli.Setup(cfg) +} diff --git a/go/go.mod b/go/go.mod index 8287b047..c835cc88 100644 --- a/go/go.mod +++ b/go/go.mod @@ -2,4 +2,7 @@ module github.com/github/copilot-sdk/go go 1.24 -require github.com/google/jsonschema-go v0.4.2 +require ( + github.com/google/jsonschema-go v0.4.2 + github.com/klauspost/compress v1.18.3 +) diff --git a/go/go.sum b/go/go.sum index 6e171099..0cc670e8 100644 --- a/go/go.sum +++ b/go/go.sum @@ -2,3 +2,5 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw= +github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= diff --git a/go/internal/embeddedcli/embeddedcli.go b/go/internal/embeddedcli/embeddedcli.go new file mode 100644 index 00000000..15c981d6 --- /dev/null +++ b/go/internal/embeddedcli/embeddedcli.go @@ -0,0 +1,202 @@ +package embeddedcli + +import ( + "bytes" + "crypto/sha256" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + "github.com/github/copilot-sdk/go/internal/flock" +) + +// Config defines the inputs used to install and locate the embedded Copilot CLI. +// +// Cli and CliHash are required. If Dir is empty, the CLI is installed into the +// system cache directory. Version is used to suffix the installed binary name to +// allow multiple versions to coexist. License, when provided, is written next +// to the installed binary. +type Config struct { + Cli io.Reader + CliHash []byte + + License []byte + + Dir string + Version string +} + +func Setup(cfg Config) { + if cfg.Cli == nil { + panic("Cli reader is required") + } + if len(cfg.CliHash) != sha256.Size { + panic(fmt.Sprintf("CliHash must be a SHA-256 hash (%d bytes), got %d bytes", sha256.Size, len(cfg.CliHash))) + } + setupMu.Lock() + defer setupMu.Unlock() + if setupDone { + panic("Setup must only be called once") + } + if pathInitialized { + panic("Setup must be called before Path is accessed") + } + config = cfg + setupDone = true +} + +var Path = sync.OnceValue(func() string { + setupMu.Lock() + defer setupMu.Unlock() + if !setupDone { + return "" + } + pathInitialized = true + path := install() + return path +}) + +var ( + config Config + setupMu sync.Mutex + setupDone bool + pathInitialized bool +) + +func install() (path string) { + verbose := os.Getenv("COPILOT_CLI_INSTALL_VERBOSE") == "1" + logError := func(msg string, err error) { + if verbose { + fmt.Printf("embedded CLI installation error: %s: %v\n", msg, err) + } + } + if verbose { + start := time.Now() + defer func() { + duration := time.Since(start) + fmt.Printf("installing embedded CLI at %s installation took %s\n", path, duration) + }() + } + installDir := config.Dir + if installDir == "" { + var err error + if installDir, err = os.UserCacheDir(); err != nil { + // Fall back to temp dir if UserCacheDir is unavailable + installDir = os.TempDir() + } + installDir = filepath.Join(installDir, "copilot-sdk") + } + path, err := installAt(installDir) + if err != nil { + logError("installing in configured directory", err) + return "" + } + return path +} + +func installAt(installDir string) (string, error) { + if err := os.MkdirAll(installDir, 0755); err != nil { + return "", fmt.Errorf("creating install directory: %w", err) + } + version := sanitizeVersion(config.Version) + lockName := ".copilot-cli.lock" + if version != "" { + lockName = fmt.Sprintf(".copilot-cli-%s.lock", version) + } + + // Best effort to prevent concurrent installs. + if release, _ := flock.Acquire(filepath.Join(installDir, lockName)); release != nil { + defer release() + } + + binaryName := "copilot" + if runtime.GOOS == "windows" { + binaryName += ".exe" + } + finalPath := versionedBinaryPath(installDir, binaryName, version) + + if _, err := os.Stat(finalPath); err == nil { + existingHash, err := hashFile(finalPath) + if err != nil { + return "", fmt.Errorf("hashing existing binary: %w", err) + } + if !bytes.Equal(existingHash, config.CliHash) { + return "", fmt.Errorf("existing binary hash mismatch") + } + return finalPath, nil + } + + f, err := os.OpenFile(finalPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755) + if err != nil { + return "", fmt.Errorf("creating binary file: %w", err) + } + _, err = io.Copy(f, config.Cli) + if err1 := f.Close(); err1 != nil && err == nil { + err = err1 + } + if closer, ok := config.Cli.(io.Closer); ok { + closer.Close() + } + if err != nil { + return "", fmt.Errorf("writing binary file: %w", err) + } + if len(config.License) > 0 { + licensePath := finalPath + ".license" + if err := os.WriteFile(licensePath, config.License, 0644); err != nil { + return "", fmt.Errorf("writing license file: %w", err) + } + } + return finalPath, nil +} + +// versionedBinaryPath builds the unpacked binary filename with an optional version suffix. +func versionedBinaryPath(dir, binaryName, version string) string { + if version == "" { + return filepath.Join(dir, binaryName) + } + base := strings.TrimSuffix(binaryName, filepath.Ext(binaryName)) + ext := filepath.Ext(binaryName) + return filepath.Join(dir, fmt.Sprintf("%s_%s%s", base, version, ext)) +} + +// sanitizeVersion makes a version string safe for filenames. +func sanitizeVersion(version string) string { + if version == "" { + return "" + } + var b strings.Builder + for _, r := range version { + switch { + case r >= 'a' && r <= 'z': + b.WriteRune(r) + case r >= 'A' && r <= 'Z': + b.WriteRune(r) + case r >= '0' && r <= '9': + b.WriteRune(r) + case r == '.' || r == '-' || r == '_': + b.WriteRune(r) + default: + b.WriteRune('_') + } + } + return b.String() +} + +// hashFile returns the SHA-256 hash of a file on disk. +func hashFile(path string) ([]byte, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + h := sha256.New() + if _, err := io.Copy(h, file); err != nil { + return nil, err + } + return h.Sum(nil), nil +} diff --git a/go/internal/embeddedcli/embeddedcli_test.go b/go/internal/embeddedcli/embeddedcli_test.go new file mode 100644 index 00000000..0453f729 --- /dev/null +++ b/go/internal/embeddedcli/embeddedcli_test.go @@ -0,0 +1,136 @@ +package embeddedcli + +import ( + "bytes" + "crypto/sha256" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func resetGlobals() { + setupMu.Lock() + defer setupMu.Unlock() + config = Config{} + setupDone = false + pathInitialized = false +} + +func mustPanic(t *testing.T, fn func()) { + t.Helper() + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic") + } + }() + fn() +} + +func binaryNameForOS() string { + name := "copilot" + if runtime.GOOS == "windows" { + name += ".exe" + } + return name +} + +func TestSetupPanicsOnNilCli(t *testing.T) { + resetGlobals() + mustPanic(t, func() { Setup(Config{}) }) +} + +func TestSetupPanicsOnSecondCall(t *testing.T) { + resetGlobals() + hash := sha256.Sum256([]byte("ok")) + Setup(Config{Cli: bytes.NewReader([]byte("ok")), CliHash: hash[:]}) + hash2 := sha256.Sum256([]byte("ok")) + mustPanic(t, func() { Setup(Config{Cli: bytes.NewReader([]byte("ok")), CliHash: hash2[:]}) }) + resetGlobals() +} + +func TestInstallAtWritesBinaryAndLicense(t *testing.T) { + resetGlobals() + tempDir := t.TempDir() + content := []byte("hello") + hash := sha256.Sum256(content) + Setup(Config{ + Cli: bytes.NewReader(content), + CliHash: hash[:], + License: []byte("license"), + Version: "1.2.3", + Dir: tempDir, + }) + + path := Path() + + expectedPath := versionedBinaryPath(tempDir, binaryNameForOS(), "1.2.3") + if path != expectedPath { + t.Fatalf("unexpected path: got %q want %q", path, expectedPath) + } + + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read binary: %v", err) + } + if !bytes.Equal(got, content) { + t.Fatalf("binary content mismatch") + } + + licensePath := path + ".license" + license, err := os.ReadFile(licensePath) + if err != nil { + t.Fatalf("read license: %v", err) + } + if string(license) != "license" { + t.Fatalf("license content mismatch") + } + + gotHash, err := hashFile(path) + if err != nil { + t.Fatalf("hash file: %v", err) + } + if !bytes.Equal(gotHash, hash[:]) { + t.Fatalf("hash mismatch") + } +} + +func TestInstallAtExistingBinaryHashMismatch(t *testing.T) { + resetGlobals() + tempDir := t.TempDir() + binaryPath := versionedBinaryPath(tempDir, binaryNameForOS(), "") + if err := os.MkdirAll(filepath.Dir(binaryPath), 0755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(binaryPath, []byte("bad"), 0755); err != nil { + t.Fatalf("write binary: %v", err) + } + + goodHash := sha256.Sum256([]byte("good")) + config = Config{ + Cli: bytes.NewReader([]byte("good")), + CliHash: goodHash[:], + } + + _, err := installAt(tempDir) + if err == nil || !strings.Contains(err.Error(), "hash mismatch") { + t.Fatalf("expected hash mismatch error, got %v", err) + } +} + +func TestSanitizeVersion(t *testing.T) { + got := sanitizeVersion("v1.2.3+build/abc") + want := "v1.2.3_build_abc" + if got != want { + t.Fatalf("sanitizeVersion() = %q want %q", got, want) + } +} + +func TestVersionedBinaryPath(t *testing.T) { + got := versionedBinaryPath("/tmp", "copilot.exe", "1.0.0") + want := filepath.Join("/tmp", "copilot_1.0.0.exe") + if got != want { + t.Fatalf("versionedBinaryPath() = %q want %q", got, want) + } +} diff --git a/go/internal/flock/flock.go b/go/internal/flock/flock.go new file mode 100644 index 00000000..fbf985a3 --- /dev/null +++ b/go/internal/flock/flock.go @@ -0,0 +1,29 @@ +package flock + +import "os" + +// Acquire opens (or creates) the lock file at path and blocks until the lock is acquired. +// It returns a release function to unlock and close the file. +func Acquire(path string) (func() error, error) { + f, err := os.OpenFile(path, os.O_CREATE, 0644) + if err != nil { + return nil, err + } + if err := lockFile(f); err != nil { + _ = f.Close() + return nil, err + } + released := false + release := func() error { + if released { + return nil + } + released = true + err := unlockFile(f) + if err1 := f.Close(); err == nil { + err = err1 + } + return err + } + return release, nil +} diff --git a/go/internal/flock/flock_other.go b/go/internal/flock/flock_other.go new file mode 100644 index 00000000..833b3460 --- /dev/null +++ b/go/internal/flock/flock_other.go @@ -0,0 +1,16 @@ +//go:build !windows && (!unix || aix || (solaris && !illumos)) + +package flock + +import ( + "errors" + "os" +) + +func lockFile(_ *os.File) error { + return errors.ErrUnsupported +} + +func unlockFile(_ *os.File) (err error) { + return errors.ErrUnsupported +} diff --git a/go/internal/flock/flock_test.go b/go/internal/flock/flock_test.go new file mode 100644 index 00000000..de26f661 --- /dev/null +++ b/go/internal/flock/flock_test.go @@ -0,0 +1,88 @@ +package flock + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + "time" +) + +func TestAcquireReleaseCreatesFile(t *testing.T) { + path := filepath.Join(t.TempDir(), "lockfile") + + release, err := Acquire(path) + if errors.Is(err, errors.ErrUnsupported) { + t.Skip("file locking unsupported on this platform") + } + if err != nil { + t.Fatalf("Acquire failed: %v", err) + } + if _, err := os.Stat(path); err != nil { + release() + t.Fatalf("lock file not created: %v", err) + } + + if err := release(); err != nil { + t.Fatalf("Release failed: %v", err) + } + if err := release(); err != nil { + t.Fatalf("Release should be idempotent: %v", err) + } +} + +func TestLockBlocksUntilRelease(t *testing.T) { + path := filepath.Join(t.TempDir(), "lockfile") + + first, err := Acquire(path) + if errors.Is(err, errors.ErrUnsupported) { + t.Skip("file locking unsupported on this platform") + } + if err != nil { + t.Fatalf("Acquire failed: %v", err) + } + defer first() + + result := make(chan error, 1) + var second func() error + go func() { + lock, err := Acquire(path) + if err == nil { + second = lock + } + result <- err + }() + + blockCtx, cancelBlock := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancelBlock() + select { + case err := <-result: + if err == nil && second != nil { + _ = second() + } + t.Fatalf("second Acquire should block, returned early: %v", err) + case <-blockCtx.Done(): + } + + if err := first(); err != nil { + t.Fatalf("Release failed: %v", err) + } + + unlockCtx, cancelUnlock := context.WithTimeout(t.Context(), 1*time.Second) + defer cancelUnlock() + select { + case err := <-result: + if err != nil { + t.Fatalf("second Acquire failed: %v", err) + } + if second == nil { + t.Fatalf("second lock was not set") + } + if err := second(); err != nil { + t.Fatalf("second Release failed: %v", err) + } + case <-unlockCtx.Done(): + t.Fatalf("second Acquire did not unblock") + } +} diff --git a/go/internal/flock/flock_unix.go b/go/internal/flock/flock_unix.go new file mode 100644 index 00000000..dbfc0a1f --- /dev/null +++ b/go/internal/flock/flock_unix.go @@ -0,0 +1,28 @@ +//go:build darwin || dragonfly || freebsd || illumos || linux || netbsd || openbsd + +package flock + +import ( + "os" + "syscall" +) + +func lockFile(f *os.File) (err error) { + for { + err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX) + if err != syscall.EINTR { + break + } + } + return err +} + +func unlockFile(f *os.File) (err error) { + for { + err = syscall.Flock(int(f.Fd()), syscall.LOCK_UN) + if err != syscall.EINTR { + break + } + } + return err +} diff --git a/go/internal/flock/flock_windows.go b/go/internal/flock/flock_windows.go new file mode 100644 index 00000000..fc3322a1 --- /dev/null +++ b/go/internal/flock/flock_windows.go @@ -0,0 +1,66 @@ +//go:build windows + +package flock + +import ( + "os" + "syscall" + "unsafe" +) + +var ( + modKernel32 = syscall.NewLazyDLL("kernel32.dll") + procLockFileEx = modKernel32.NewProc("LockFileEx") + procUnlockFileEx = modKernel32.NewProc("UnlockFileEx") +) + +const LOCKFILE_EXCLUSIVE_LOCK = 0x00000002 + +func lockFile(f *os.File) error { + rc, err := f.SyscallConn() + if err != nil { + return err + } + var callErr error + if err := rc.Control(func(fd uintptr) { + var ol syscall.Overlapped + r1, _, e1 := procLockFileEx.Call( + fd, + uintptr(LOCKFILE_EXCLUSIVE_LOCK), + 0, + 1, + 0, + uintptr(unsafe.Pointer(&ol)), + ) + if r1 == 0 { + callErr = e1 + } + }); err != nil { + return err + } + return callErr +} + +func unlockFile(f *os.File) error { + rc, err := f.SyscallConn() + if err != nil { + return err + } + var callErr error + if err := rc.Control(func(fd uintptr) { + var ol syscall.Overlapped + r1, _, e1 := procUnlockFileEx.Call( + fd, + 0, + 1, + 0, + uintptr(unsafe.Pointer(&ol)), + ) + if r1 == 0 { + callErr = e1 + } + }); err != nil { + return err + } + return callErr +}