package com.github.thinkerou.karate.protobuf; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.logging.Logger; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.protobuf.DescriptorProtos; import com.google.protobuf.Descriptors; import com.google.protobuf.ProtocolStringList; import com.github.thinkerou.karate.domain.ProtoName; /** * ServiceResolver * * @author thinkerou */ public final class ServiceResolver { private static final Logger logger = Logger.getLogger(ServiceResolver.class.getName()); private final ImmutableList<Descriptors.FileDescriptor> fileDescriptors; /** * Creates a resolver which searches the supplied FileDescriptorSet. */ public static ServiceResolver fromFileDescriptorSet(DescriptorProtos.FileDescriptorSet descriptorSet) { ImmutableMap<String, DescriptorProtos.FileDescriptorProto> descriptorProtoIndex = computeDescriptorProtoIndex(descriptorSet); Map<String, Descriptors.FileDescriptor> descriptorCache = new HashMap<>(); ImmutableList.Builder<Descriptors.FileDescriptor> result = ImmutableList.builder(); List<DescriptorProtos.FileDescriptorProto> descriptorProtos = descriptorSet.getFileList(); for (DescriptorProtos.FileDescriptorProto descriptorProto : descriptorProtos) { try { result.add(descriptorFromProto(descriptorProto, descriptorProtoIndex, descriptorCache)); } catch (Descriptors.DescriptorValidationException e) { logger.warning(e.getMessage()); continue; } } return new ServiceResolver(result.build()); } /** * Lists all of the services found in the file descriptors. */ public Iterable<Descriptors.ServiceDescriptor> listServices() { ArrayList<Descriptors.ServiceDescriptor> serviceDescriptors = new ArrayList<>(); fileDescriptors.forEach(fileDescriptor -> serviceDescriptors.addAll(fileDescriptor.getServices())); return serviceDescriptors; } /** * Lists all the known message types. */ public ImmutableSet<Descriptors.Descriptor> listMessageTypes() { ImmutableSet.Builder<Descriptors.Descriptor> resultBuilder = ImmutableSet.builder(); fileDescriptors.forEach(d -> resultBuilder.addAll(d.getMessageTypes())); return resultBuilder.build(); } private ServiceResolver(Iterable<Descriptors.FileDescriptor> fileDescriptors) { this.fileDescriptors = ImmutableList.copyOf(fileDescriptors); } /** * Returns the descriptor of a protobuf method with the supplied grpc method name. * If the method can't be found, this throw IllegalArgumentException. */ public Descriptors.MethodDescriptor resolveServiceMethod(ProtoName method) { return resolveServiceMethod( method.getServiceName(), method.getMethodName(), method.getPackageName()); } private Descriptors.MethodDescriptor resolveServiceMethod( String serviceName, String methodName, String packageName) { Descriptors.ServiceDescriptor service = findService(serviceName, packageName); Descriptors.MethodDescriptor method = service.findMethodByName(methodName); if (method == null) { throw new IllegalArgumentException( "Can't find method " + methodName + " in service " + serviceName); } return method; } private Descriptors.ServiceDescriptor findService(String serviceName, String packageName) { for (Descriptors.FileDescriptor fileDescriptor : fileDescriptors) { if (!fileDescriptor.getPackage().equals(packageName)) { // Package does not match this file, ignore. continue; } Descriptors.ServiceDescriptor serviceDescriptor = fileDescriptor.findServiceByName(serviceName); if (serviceDescriptor != null) { return serviceDescriptor; } } throw new IllegalArgumentException("Can't find service with name: " + serviceName); } /** * Returns a map from descriptor proto name as found inside the descriptors to protos. */ private static ImmutableMap<String, DescriptorProtos.FileDescriptorProto> computeDescriptorProtoIndex( DescriptorProtos.FileDescriptorSet fileDescriptorSet) { ImmutableMap.Builder<String, DescriptorProtos.FileDescriptorProto> resultBuilder = ImmutableMap.builder(); List<DescriptorProtos.FileDescriptorProto> descriptorProtos = fileDescriptorSet.getFileList(); descriptorProtos.forEach(descriptorProto -> resultBuilder.put(descriptorProto.getName(), descriptorProto)); return resultBuilder.build(); } /** * Recursively constructs file descriptors for all dependencies of the supplied proto and * returns a FileDescriptor for the supplied proto itself. * For maximal efficientcy, reuse the descriptorCache argument across calls. */ private static Descriptors.FileDescriptor descriptorFromProto( DescriptorProtos.FileDescriptorProto descriptorProto, ImmutableMap<String, DescriptorProtos.FileDescriptorProto> descriptorProtoIndex, Map<String, Descriptors.FileDescriptor> descriptorCache) throws Descriptors.DescriptorValidationException { // First, check the cache. String descriptorName = descriptorProto.getName(); if (descriptorCache.containsKey(descriptorName)) { return descriptorCache.get(descriptorName); } // Then, fetch all the required dependencies recursively. ImmutableList.Builder<Descriptors.FileDescriptor> dependencies = ImmutableList.builder(); ProtocolStringList protocolStringList = descriptorProto.getDependencyList(); protocolStringList.forEach(dependencyName -> { if (!descriptorProtoIndex.containsKey(dependencyName)) { throw new IllegalArgumentException("Can't find dependency: " + dependencyName); } DescriptorProtos.FileDescriptorProto dependencyProto = descriptorProtoIndex.get(dependencyName); try { dependencies.add(descriptorFromProto(dependencyProto, descriptorProtoIndex, descriptorCache)); } catch (Descriptors.DescriptorValidationException e) { logger.warning(e.getMessage()); } }); // Finally, construct the actual descriptor. Descriptors.FileDescriptor[] empty = new Descriptors.FileDescriptor[0]; return Descriptors.FileDescriptor.buildFrom(descriptorProto, dependencies.build().toArray(empty)); } }