Skip to content

feat: add the /aitasks/prompts endpoint #18464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions coderd/aitasks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package coderd

import (
"fmt"
"net/http"
"strings"

"github.com/google/uuid"

"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/codersdk"
)

// This endpoint is experimental and not guaranteed to be stable, so we're not
// generating public-facing documentation for it.
func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()

buildIDsParam := r.URL.Query().Get("build_ids")
if buildIDsParam == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "build_ids query parameter is required",
})
return
}

// Parse build IDs
buildIDStrings := strings.Split(buildIDsParam, ",")
buildIDs := make([]uuid.UUID, 0, len(buildIDStrings))
for _, idStr := range buildIDStrings {
id, err := uuid.Parse(strings.TrimSpace(idStr))
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid build ID format: %s", idStr),
Detail: err.Error(),
})
return
}
buildIDs = append(buildIDs, id)
}

parameters, err := api.Database.GetWorkspaceBuildParametersByBuildIDs(ctx, buildIDs)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching workspace build parameters.",
Detail: err.Error(),
})
return
}

promptsByBuildID := make(map[string]string, len(parameters))
for _, param := range parameters {
if param.Name != codersdk.AITaskPromptParameterName {
continue
}
buildID := param.WorkspaceBuildID.String()
promptsByBuildID[buildID] = param.Value
}

httpapi.Write(ctx, rw, http.StatusOK, codersdk.AITasksPromptsResponse{
Prompts: promptsByBuildID,
})
}
141 changes: 141 additions & 0 deletions coderd/aitasks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package coderd_test

import (
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisioner/echo"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/testutil"
)

func TestAITasksPrompts(t *testing.T) {
t.Parallel()

t.Run("EmptyBuildIDs", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{})
_ = coderdtest.CreateFirstUser(t, client)
experimentalClient := codersdk.NewExperimentalClient(client)

ctx := testutil.Context(t, testutil.WaitShort)

// Test with empty build IDs
prompts, err := experimentalClient.AITaskPrompts(ctx, []uuid.UUID{})
require.NoError(t, err)
require.Empty(t, prompts.Prompts)
})

t.Run("MultipleBuilds", func(t *testing.T) {
t.Parallel()

if !dbtestutil.WillUsePostgres() {
t.Skip("This test checks RBAC, which is not supported in the in-memory database")
}

adminClient := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
first := coderdtest.CreateFirstUser(t, adminClient)
memberClient, _ := coderdtest.CreateAnotherUser(t, adminClient, first.OrganizationID)

ctx := testutil.Context(t, testutil.WaitLong)

// Create a template with parameters
version := coderdtest.CreateTemplateVersion(t, adminClient, first.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: []*proto.Response{{
Type: &proto.Response_Plan{
Plan: &proto.PlanComplete{
Parameters: []*proto.RichParameter{
{
Name: "param1",
Type: "string",
DefaultValue: "default1",
},
{
Name: codersdk.AITaskPromptParameterName,
Type: "string",
DefaultValue: "default2",
},
},
},
},
}},
ProvisionApply: echo.ApplyComplete,
})
template := coderdtest.CreateTemplate(t, adminClient, first.OrganizationID, version.ID)
coderdtest.AwaitTemplateVersionJobCompleted(t, adminClient, version.ID)

// Create two workspaces with different parameters
workspace1 := coderdtest.CreateWorkspace(t, memberClient, template.ID, func(request *codersdk.CreateWorkspaceRequest) {
request.RichParameterValues = []codersdk.WorkspaceBuildParameter{
{Name: "param1", Value: "value1a"},
{Name: codersdk.AITaskPromptParameterName, Value: "value2a"},
}
})
coderdtest.AwaitWorkspaceBuildJobCompleted(t, memberClient, workspace1.LatestBuild.ID)

workspace2 := coderdtest.CreateWorkspace(t, memberClient, template.ID, func(request *codersdk.CreateWorkspaceRequest) {
request.RichParameterValues = []codersdk.WorkspaceBuildParameter{
{Name: "param1", Value: "value1b"},
{Name: codersdk.AITaskPromptParameterName, Value: "value2b"},
}
})
coderdtest.AwaitWorkspaceBuildJobCompleted(t, memberClient, workspace2.LatestBuild.ID)

workspace3 := coderdtest.CreateWorkspace(t, adminClient, template.ID, func(request *codersdk.CreateWorkspaceRequest) {
request.RichParameterValues = []codersdk.WorkspaceBuildParameter{
{Name: "param1", Value: "value1c"},
{Name: codersdk.AITaskPromptParameterName, Value: "value2c"},
}
})
coderdtest.AwaitWorkspaceBuildJobCompleted(t, adminClient, workspace3.LatestBuild.ID)
allBuildIDs := []uuid.UUID{workspace1.LatestBuild.ID, workspace2.LatestBuild.ID, workspace3.LatestBuild.ID}

experimentalMemberClient := codersdk.NewExperimentalClient(memberClient)
// Test parameters endpoint as member
prompts, err := experimentalMemberClient.AITaskPrompts(ctx, allBuildIDs)
require.NoError(t, err)
// we expect 2 prompts because the member client does not have access to workspace3
// since it was created by the admin client
require.Len(t, prompts.Prompts, 2)

// Check workspace1 parameters
build1Prompt := prompts.Prompts[workspace1.LatestBuild.ID.String()]
require.Equal(t, "value2a", build1Prompt)

// Check workspace2 parameters
build2Prompt := prompts.Prompts[workspace2.LatestBuild.ID.String()]
require.Equal(t, "value2b", build2Prompt)

experimentalAdminClient := codersdk.NewExperimentalClient(adminClient)
// Test parameters endpoint as admin
// we expect 3 prompts because the admin client has access to all workspaces
prompts, err = experimentalAdminClient.AITaskPrompts(ctx, allBuildIDs)
require.NoError(t, err)
require.Len(t, prompts.Prompts, 3)

// Check workspace3 parameters
build3Prompt := prompts.Prompts[workspace3.LatestBuild.ID.String()]
require.Equal(t, "value2c", build3Prompt)
})

t.Run("NonExistentBuildIDs", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, &coderdtest.Options{})
_ = coderdtest.CreateFirstUser(t, client)

ctx := testutil.Context(t, testutil.WaitShort)

// Test with non-existent build IDs
nonExistentID := uuid.New()
experimentalClient := codersdk.NewExperimentalClient(client)
prompts, err := experimentalClient.AITaskPrompts(ctx, []uuid.UUID{nonExistentID})
require.NoError(t, err)
require.Empty(t, prompts.Prompts)
})
}
8 changes: 8 additions & 0 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,14 @@ func New(options *Options) *API {
})
})

// Experimental routes are not guaranteed to be stable and may change at any time.
r.Route("/api/experimental", func(r chi.Router) {
r.Use(apiKeyMiddleware)
r.Route("/aitasks", func(r chi.Router) {
r.Get("/prompts", api.aiTasksPrompts)
})
})

r.Route("/api/v2", func(r chi.Router) {
api.APIHandler = r

Expand Down
13 changes: 13 additions & 0 deletions coderd/database/dbauthz/dbauthz.go
Original file line number Diff line number Diff line change
Expand Up @@ -3281,6 +3281,15 @@ func (q *querier) GetWorkspaceBuildParameters(ctx context.Context, workspaceBuil
return q.db.GetWorkspaceBuildParameters(ctx, workspaceBuildID)
}

func (q *querier) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID) ([]database.WorkspaceBuildParameter, error) {
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceWorkspace.Type)
if err != nil {
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
}

return q.db.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, prep)
}

func (q *querier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
Expand Down Expand Up @@ -5266,6 +5275,10 @@ func (q *querier) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx context.Context,
return q.GetWorkspacesAndAgentsByOwnerID(ctx, ownerID)
}

func (q *querier) GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID, _ rbac.PreparedAuthorized) ([]database.WorkspaceBuildParameter, error) {
return q.GetWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs)
}

// GetAuthorizedUsers is not required for dbauthz since GetUsers is already
// authenticated.
func (q *querier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, _ rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
Expand Down
8 changes: 8 additions & 0 deletions coderd/database/dbauthz/dbauthz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2012,6 +2012,14 @@ func (s *MethodTestSuite) TestWorkspace() {
// No asserts here because SQLFilter.
check.Args(ws.OwnerID, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("GetWorkspaceBuildParametersByBuildIDs", s.Subtest(func(db database.Store, check *expects) {
// no asserts here because SQLFilter
check.Args([]uuid.UUID{}).Asserts()
}))
s.Run("GetAuthorizedWorkspaceBuildParametersByBuildIDs", s.Subtest(func(db database.Store, check *expects) {
// no asserts here because SQLFilter
check.Args([]uuid.UUID{}, emptyPreparedAuthorized{}).Asserts()
}))
s.Run("GetLatestWorkspaceBuildByWorkspaceID", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
o := dbgen.Organization(s.T(), db, database.Organization{})
Expand Down
29 changes: 29 additions & 0 deletions coderd/database/dbmem/dbmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -7960,6 +7960,11 @@ func (q *FakeQuerier) GetWorkspaceBuildParameters(_ context.Context, workspaceBu
return q.getWorkspaceBuildParametersNoLock(workspaceBuildID)
}

func (q *FakeQuerier) GetWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID) ([]database.WorkspaceBuildParameter, error) {
// No auth filter.
return q.GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx, workspaceBuildIDs, nil)
}

func (q *FakeQuerier) GetWorkspaceBuildStatsByTemplates(ctx context.Context, since time.Time) ([]database.GetWorkspaceBuildStatsByTemplatesRow, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
Expand Down Expand Up @@ -13901,6 +13906,30 @@ func (q *FakeQuerier) GetAuthorizedWorkspacesAndAgentsByOwnerID(ctx context.Cont
return out, nil
}

func (q *FakeQuerier) GetAuthorizedWorkspaceBuildParametersByBuildIDs(ctx context.Context, workspaceBuildIDs []uuid.UUID, prepared rbac.PreparedAuthorized) ([]database.WorkspaceBuildParameter, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()

if prepared != nil {
// Call this to match the same function calls as the SQL implementation.
_, err := prepared.CompileToSQL(ctx, rbac.ConfigWithoutACL())
if err != nil {
return nil, err
}
}

filteredParameters := make([]database.WorkspaceBuildParameter, 0)
for _, buildID := range workspaceBuildIDs {
parameters, err := q.GetWorkspaceBuildParameters(ctx, buildID)
if err != nil {
return nil, err
}
filteredParameters = append(filteredParameters, parameters...)
}

return filteredParameters, nil
}

func (q *FakeQuerier) GetAuthorizedUsers(ctx context.Context, arg database.GetUsersParams, prepared rbac.PreparedAuthorized) ([]database.GetUsersRow, error) {
if err := validateDatabaseType(arg); err != nil {
return nil, err
Expand Down
14 changes: 14 additions & 0 deletions coderd/database/dbmetrics/querymetrics.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions coderd/database/dbmock/dbmock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading