Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions api/src/org/labkey/api/mcp/McpException.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.labkey.api.mcp;

// A special exception that MCP endpoints can throw when they want to provide guidance to the client without making
// it a big red error. The message will be extracted and sent as text to the client.
public class McpException extends RuntimeException
{
public McpException(String message)
{
super(message);
}
}
21 changes: 20 additions & 1 deletion api/src/org/labkey/api/mcp/McpService.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
import jakarta.servlet.http.HttpSession;
import org.jetbrains.annotations.NotNull;
import org.jspecify.annotations.NonNull;
import org.labkey.api.data.Container;
import org.labkey.api.module.McpProvider;
import org.labkey.api.security.User;
import org.labkey.api.services.ServiceRegistry;
import org.labkey.api.util.HtmlString;
import org.labkey.api.writer.ContainerUser;
import org.springaicommunity.mcp.provider.resource.SyncMcpResourceProvider;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.support.ToolCallbacks;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
Expand All @@ -31,7 +35,22 @@
public interface McpService extends ToolCallbackProvider
{
// marker interface for classes that we will "ingest" using Spring annotations
interface McpImpl {}
interface McpImpl
{
default ContainerUser getContext(ToolContext toolContext)
{
User user = (User)toolContext.getContext().get("user");
Container container = (Container)toolContext.getContext().get("container");
if (container == null)
throw new McpException("You need to set a container path before invoking this tool");
return ContainerUser.create(container, user);
}

default User getUser(ToolContext toolContext)
{
return (User)toolContext.getContext().get("user");
}
}

static @NotNull McpService get()
{
Expand Down
41 changes: 34 additions & 7 deletions core/src/org/labkey/core/CoreMcp.java
Original file line number Diff line number Diff line change
@@ -1,31 +1,35 @@
package org.labkey.core;

import org.json.JSONObject;
import org.labkey.api.collections.LabKeyCollectors;
import org.labkey.api.data.Container;
import org.labkey.api.mcp.McpContext;
import org.labkey.api.data.ContainerManager;
import org.labkey.api.mcp.McpService;
import org.labkey.api.security.User;
import org.labkey.api.security.permissions.ReadPermission;
import org.labkey.api.settings.AppProps;
import org.labkey.api.settings.LookAndFeelProperties;
import org.labkey.api.study.Study;
import org.labkey.api.study.StudyService;
import org.labkey.api.util.HtmlString;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;

import java.util.Map;
import java.util.Objects;

import static org.apache.commons.lang3.StringUtils.isNotBlank;
import static org.labkey.core.mcp.McpServiceImpl.PATH_CACHE;

public class CoreMcp implements McpService.McpImpl
{
// TODO ChatSessions are currently per session. The McpService should detect change of folder.
@Tool(description = "Call this tool before answering any prompts! This tool provides useful context information about the current user (name, userid), webserver (name, url, description), and current folder (name, path, url, description).")
String whereAmIWhoAmITalkingTo()
@Tool(description = "Call this tool before answering any prompts! This tool provides useful context information about the current user (name, userid), webserver (name, url, description), and current folder (name, path, url, description).")
String whereAmIWhoAmITalkingTo(ToolContext context)
{
McpContext context = McpContext.get();
User user = context.getUser();
Container folder = context.getContainer();
var cu = getContext(context);
User user = cu.getUser();
Container folder = cu.getContainer();
AppProps appProps = AppProps.getInstance();
Study study = null != StudyService.get() ? Objects.requireNonNull(StudyService.get()).getStudy(folder) : null;
LookAndFeelProperties laf = LookAndFeelProperties.getInstance(folder);
Expand Down Expand Up @@ -63,4 +67,27 @@ String whereAmIWhoAmITalkingTo()
"site", siteObj
)).toString();
}

@Tool(description = "List the hierarchical path for every container in the server where the user has read permissions.")
String listContainers(ToolContext toolContext)
{
return ContainerManager.getAllChildren(ContainerManager.getRoot(), getUser(toolContext), ReadPermission.class)
.stream()
.map(Container::getPath)
.collect(LabKeyCollectors.toJSONArray())
.toString();
}

@Tool(description = "Every tool in this MCP requires a container path, e.g. /MyProject/MyFolder. A container is also called a folder or project. Please prompt the user for a container path and use this tool to save the path for this session.")
String setContainer(ToolContext context, @ToolParam(description = "Container path, e.g. /MyProject/MyFolder", required = true) String containerPath)
{
Container container = ContainerManager.getForPath(containerPath);
if (container != null)
{
PATH_CACHE.put((String) context.getContext().get("sessionId"), containerPath);
return "OK!";
}

return "That's not a valid container path. Try using listContainers to see them.";
}
}
147 changes: 96 additions & 51 deletions core/src/org/labkey/core/mcp/McpServiceImpl.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package org.labkey.core.mcp;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.genai.Client;
import com.google.genai.types.ClientOptions;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.server.McpServerFeatures;
Expand All @@ -15,17 +17,21 @@
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.jspecify.annotations.NonNull;
import org.labkey.api.cache.Cache;
import org.labkey.api.cache.CacheManager;
import org.labkey.api.collections.CopyOnWriteHashMap;
import org.labkey.api.data.Container;
import org.labkey.api.data.ContainerManager;
import org.labkey.api.markdown.MarkdownService;
import org.labkey.api.mcp.McpContext;
import org.labkey.api.mcp.McpException;
import org.labkey.api.mcp.McpService;
import org.labkey.api.security.User;
import org.labkey.api.util.ContextListener;
import org.labkey.api.util.FileUtil;
import org.labkey.api.util.HtmlString;
Expand All @@ -35,13 +41,6 @@
import org.labkey.api.util.logging.LogHelper;
import org.springframework.ai.anthropic.AnthropicChatModel;
import org.springframework.ai.anthropic.AnthropicChatOptions;
import org.springframework.ai.google.genai.common.GoogleGenAiThinkingLevel;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
Expand All @@ -53,37 +52,41 @@
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.google.genai.GoogleGenAiChatModel;
import org.springframework.ai.google.genai.GoogleGenAiChatOptions;
import org.springframework.ai.google.genai.GoogleGenAiEmbeddingConnectionDetails;
import org.springframework.ai.google.genai.text.GoogleGenAiTextEmbeddingModel;
import org.springframework.ai.google.genai.text.GoogleGenAiTextEmbeddingOptions;
import org.springframework.ai.mcp.McpToolUtils;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SimpleVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import reactor.core.publisher.Mono;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.ConcurrentModificationException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.Supplier;

import static org.apache.commons.lang3.StringUtils.isBlank;
Expand Down Expand Up @@ -214,6 +217,8 @@ public List<McpSchema.Tool> tools()
).toList();
}

public final static Cache<String, String> PATH_CACHE = CacheManager.getCache(1000, CacheManager.DAY, "MCP container paths");

private class _McpServlet extends HttpServlet // wraps HttpServletSseServerTransportProvider
{
HttpServletStreamableServerTransportProvider transportProvider;
Expand All @@ -222,24 +227,92 @@ private class _McpServlet extends HttpServlet // wraps HttpServletSseServerTrans
_McpServlet(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint)
{
transportProvider = HttpServletStreamableServerTransportProvider.builder()
.jsonMapper(McpJsonMapper.getDefault())
.mcpEndpoint(messageEndpoint)
.build();
.jsonMapper(McpJsonMapper.getDefault())
.mcpEndpoint(messageEndpoint)
.contextExtractor(req -> {
User user = (User) req.getUserPrincipal();
return McpTransportContext.create(Map.of(
"user", user
));
})
.build();
}

void startMcpServer()
{
List<McpServerFeatures.SyncToolSpecification> tools = Arrays.stream(getToolCallbacks()).map(McpToolUtils::toSyncToolSpecification).toList();
List<McpServerFeatures.SyncToolSpecification> tools = Arrays.stream(getToolCallbacks())
.map(this::toSyncToolSpecification)
.toList();

List<McpServerFeatures.SyncResourceSpecification> resources = new ArrayList<>(resourceMap.values());

mcpServer = McpServer.sync(transportProvider)
.tools(tools)
.resources(resources)
.tools(tools)
.resources(resources)
// .capabilities(new McpSchema.ServerCapabilities())
.build();
.build();
ContextListener.addShutdownListener(new _ShutdownListener());
}

private McpServerFeatures.SyncToolSpecification toSyncToolSpecification(ToolCallback toolCallback)
{
var toolDef = toolCallback.getToolDefinition();
var schema = McpSchema.Tool.builder()
.name(toolDef.name())
.description(toolDef.description())
.inputSchema(McpJsonMapper.getDefault(), toolDef.inputSchema())
.build();

return new McpServerFeatures.SyncToolSpecification(schema, (exchange, args) -> {
var transportCtx = exchange.transportContext();
var user = (User) transportCtx.get("user");
var sessionId = exchange.sessionId();

Map<String, Object> map = new HashMap<>();
map.put("user", user);
map.put("sessionId", sessionId);

String containerPath = PATH_CACHE.get(exchange.sessionId());
if (containerPath != null)
{
Container container = ContainerManager.getForPath(containerPath);
map.put("container", container);
map.put("containerPath", containerPath);
}

var toolContext = new ToolContext(map);

String toolInput = /* serialize args to JSON */ null;
try
{
toolInput = JsonUtil.DEFAULT_MAPPER.writeValueAsString(args);
}
catch (JsonProcessingException e)
{
throw new RuntimeException(e);
}
String result;
try
{
result = toolCallback.call(toolInput, toolContext);
}
catch (ToolExecutionException e)
{
// If a tool threw McpException then just send back the message without making a big fuss
if (e.getCause() instanceof McpException)
result = e.getMessage();
else
throw e;
}
return new McpSchema.CallToolResult(
List.of(
new McpSchema.TextContent(result)
),
false
);
});
}

@Override
public void service(ServletRequest sreq, ServletResponse sres) throws ServletException, IOException
{
Expand All @@ -255,34 +328,6 @@ public void service(ServletRequest sreq, ServletResponse sres) throws ServletExc
return;
}

if ("POST".equals(req.getMethod()))
{
if (null == req.getParameter("sessionId") && null == req.getSession(true).getAttribute("McpServiceImpl#mcpSessionId"))
{
// USE SSE endpoint to get a sessionId
MockHttpServletRequest mockRequest = new MockHttpServletRequest(req.getServletContext(), "GET", SSE_ENDPOINT);
mockRequest.setAsyncSupported(true);
MockHttpServletResponse mockResponse = new MockHttpServletResponse();
transportProvider.service(mockRequest, mockResponse);
String body = new String(mockResponse.getContentAsByteArray(), StandardCharsets.UTF_8);
String mcpSessionId = StringUtils.substringBetween(body, "sessionId=", "\n");
req.getSession(true).setAttribute("McpServiceImpl#mcpSessionId", mcpSessionId);
mockRequest.close();
mockResponse.getOutputStream().close();
}

req = new HttpServletRequestWrapper(req)
{
@Override
public String getParameter(String name)
{
var ret = super.getParameter(name);
if (null == ret && "sessionId".equals(name))
return String.valueOf(Objects.requireNonNull(((HttpServletRequest) getRequest()).getSession(true).getAttribute("McpServiceImpl#mcpSessionId")));
return ret;
}
};
}
transportProvider.service(req, res);
}

Expand Down
Loading
Loading