package com.robrua.nlp.bert;

import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URL;
import java.nio.IntBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import com.google.common.io.Resources;

/**
 * <p>
 * easy-bert is a dead simple API for using <a href="https://github.com/google-research/bert">Google's high quality BERT language model</a>.
 *
 * The easy-bert Java bindings allow you to run pre-trained BERT models generated with easy-bert's Python tools. You can also use pre-generated models from Maven
 * Central.
 * <br>
 * <br>
 * <p>
 * To load a model from your local filesystem, you can use:
 *
 * <blockquote>
 * <pre>
 * {@code
 * try(Bert bert = Bert.load(new File("/path/to/your/model/"))) {
 *     // Embed some sequences
 * }
 * }
 * </pre>
 * </blockquote>
 *
 * If the model is on your classpath (e.g. if you're pulling it in via Maven), you can use:
 *
 * <blockquote>
 * <pre>
 * {@code
 * try(Bert bert = Bert.load("/resource/path/to/your/model/")) {
 *     // Embed some sequences
 * }
 * }
 * </pre>
 * </blockquote>
 *
 * See <a href="https://github.com/robrua/easy-bert">the easy-bert GitHub Repository</a> for information about models available via Maven Central.
 * <br>
 * <br>
 * <p>
 * Once you have a BERT model loaded, you can get sequence embeddings using {@link com.robrua.nlp.bert.Bert#embedSequence(String)},
 * {@link com.robrua.nlp.bert.Bert#embedSequences(String...)}, {@link com.robrua.nlp.bert.Bert#embedSequences(Iterable)}, or
 * {@link com.robrua.nlp.bert.Bert#embedSequences(Iterator)}:
 *
 * <blockquote>
 * <pre>
 * {@code
 * float[] embedding = bert.embedSequence("A sequence");
 * float[][] embeddings = bert.embedSequences("Multiple", "Sequences");
 * }
 * </pre>
 * </blockquote>
 *
 * If you want per-token embeddings, you can use {@link com.robrua.nlp.bert.Bert#embedTokens(String)}, {@link com.robrua.nlp.bert.Bert#embedTokens(String...)},
 * {@link com.robrua.nlp.bert.Bert#embedTokens(Iterable)}, or {@link com.robrua.nlp.bert.Bert#embedTokens(Iterator)}:
 *
 * <blockquote>
 * <pre>
 * {@code
 * float[][] embedding = bert.embedTokens("A sequence");
 * float[][][] embeddings = bert.embedTokens("Multiple", "Sequences");
 * }
 * </pre>
 * </blockquote>
 *
 * @author Rob Rua (https://github.com/robrua)
 * @version 1.0.3
 * @since 1.0.3
 *
 * @see <a href="https://github.com/robrua/easy-bert">The easy-bert GitHub Repository</a>
 * @see <a href="https://github.com/google-research/bert">Google's BERT GitHub Repository</a>
 */
public class Bert implements AutoCloseable {
    private class Inputs implements AutoCloseable {
        private final Tensor<Integer> inputIds, inputMask, segmentIds;

        public Inputs(final IntBuffer inputIds, final IntBuffer inputMask, final IntBuffer segmentIds, final int count) {
            this.inputIds = Tensor.create(new long[] {count, model.maxSequenceLength}, inputIds);
            this.inputMask = Tensor.create(new long[] {count, model.maxSequenceLength}, inputMask);
            this.segmentIds = Tensor.create(new long[] {count, model.maxSequenceLength}, segmentIds);
        }

        @Override
        public void close() {
            inputIds.close();
            inputMask.close();
            segmentIds.close();
        }
    }
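    // ModelDetails below is deserialized by Jackson from the model's assets/model.json file, with
    // JSON keys matching the field names. A sketch of what such a file might look like (the node
    // names and sequence length here are illustrative, not guaranteed):
    // {
    //     "doLowerCase": true,
    //     "inputIds": "input_ids",
    //     "inputMask": "input_mask",
    //     "segmentIds": "segment_ids",
    //     "pooledOutput": "pooled_output",
    //     "sequenceOutput": "sequence_output",
    //     "maxSequenceLength": 128
    // }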
    private static class ModelDetails {
        public boolean doLowerCase;
        public String inputIds, inputMask, segmentIds, pooledOutput, sequenceOutput;
        public int maxSequenceLength;
    }

    private static final int FILE_COPY_BUFFER_BYTES = 1024 * 1024;
    private static final String MODEL_DETAILS = "model.json";
    private static final String SEPARATOR_TOKEN = "[SEP]";
    private static final String START_TOKEN = "[CLS]";
    private static final String VOCAB_FILE = "vocab.txt";

    /**
     * Loads a pre-trained BERT model from a TensorFlow saved model exported by the easy-bert Python utilities
     *
     * @param model
     *        the model to load
     * @return a ready-to-use BERT model
     * @since 1.0.3
     */
    public static Bert load(final File model) {
        return load(Paths.get(model.toURI()));
    }

    /**
     * Loads a pre-trained BERT model from a TensorFlow saved model exported by the easy-bert Python utilities
     *
     * @param path
     *        the path to load the model from
     * @return a ready-to-use BERT model
     * @since 1.0.3
     */
    public static Bert load(Path path) {
        path = path.toAbsolutePath();

        ModelDetails model;
        try {
            model = new ObjectMapper().readValue(path.resolve("assets").resolve(MODEL_DETAILS).toFile(), ModelDetails.class);
        } catch(final IOException e) {
            throw new RuntimeException(e);
        }

        return new Bert(SavedModelBundle.load(path.toString(), "serve"), model, path.resolve("assets").resolve(VOCAB_FILE));
    }

    /**
     * Loads a pre-trained BERT model from a TensorFlow saved model exported by the easy-bert Python utilities. The target resource should be in .zip format.
     *
     * @param resource
     *        the resource path to load the model from - should be in .zip format
     * @return a ready-to-use BERT model
     * @since 1.0.3
     */
    public static Bert load(final String resource) {
        Path directory = null;
        try {
            // Create a temp directory to unpack the zip into
            final URL model = Resources.getResource(resource);
            directory = Files.createTempDirectory("easy-bert-");

            try(ZipInputStream zip = new ZipInputStream(Resources.asByteSource(model).openBufferedStream())) {
                ZipEntry entry;

                // Copy each zip entry into the temp directory
                while((entry = zip.getNextEntry()) != null) {
                    final Path path = directory.resolve(entry.getName());
                    if(entry.getName().endsWith("/")) {
                        Files.createDirectories(path);
                    } else {
                        Files.createFile(path);
                        try(OutputStream output = Files.newOutputStream(path)) {
                            final byte[] buffer = new byte[FILE_COPY_BUFFER_BYTES];
                            int bytes;
                            while((bytes = zip.read(buffer)) > 0) {
                                output.write(buffer, 0, bytes);
                            }
                        }
                    }
                    zip.closeEntry();
                }
            }

            // Load a BERT model from the temp directory
            return Bert.load(directory);
        } catch(final IOException e) {
            throw new RuntimeException(e);
        } finally {
            // Clean up the temp directory, deleting children before their parent directories
            if(directory != null && Files.exists(directory)) {
                try {
                    Files.walk(directory)
                        .sorted(Comparator.reverseOrder())
                        .forEach((final Path file) -> {
                            try {
                                Files.delete(file);
                            } catch(final IOException e) {
                                throw new RuntimeException(e);
                            }
                        });
                } catch(final IOException e) {
                    throw new RuntimeException(e);
                }
            }
        }
    }

    private final SavedModelBundle bundle;
    private final ModelDetails model;
    private final int separatorTokenId;
    private final int startTokenId;
    private final FullTokenizer tokenizer;

    private Bert(final SavedModelBundle bundle, final ModelDetails model, final Path vocabulary) {
        tokenizer = new FullTokenizer(vocabulary, model.doLowerCase);
        this.bundle = bundle;
        this.model = model;

        // Look up the vocabulary ids of the [CLS] and [SEP] markers once so they can be reused for every sequence
        final int[] ids = tokenizer.convert(new String[] {START_TOKEN, SEPARATOR_TOKEN});
        startTokenId = ids[0];
        separatorTokenId = ids[1];
    }

    @Override
    public void close() {
        bundle.close();
    }

    /**
     * Gets a pooled BERT embedding for a single sequence. Sequences are usually individual sentences, but don't have to be.
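     * <p>
     * One common use for pooled embeddings is comparing sequences by cosine similarity. A minimal sketch in plain Java (not part of this API; the variable
     * names are illustrative):
     *
     * <blockquote>
     * <pre>
     * {@code
     * float[] a = bert.embedSequence("A sequence");
     * float[] b = bert.embedSequence("Another sequence");
     * double dot = 0.0, normA = 0.0, normB = 0.0;
     * for(int i = 0; i < a.length; i++) {
     *     dot += a[i] * b[i];
     *     normA += a[i] * a[i];
     *     normB += b[i] * b[i];
     * }
     * double cosineSimilarity = dot / (Math.sqrt(normA) * Math.sqrt(normB));
     * }
     * </pre>
     * </blockquote>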
     *
     * @param sequence
     *        the sequence to embed
     * @return the pooled embedding for the sequence
     * @since 1.0.3
     */
    public float[] embedSequence(final String sequence) {
        try(Inputs inputs = getInputs(sequence)) {
            final List<Tensor<?>> output = bundle.session().runner()
                .feed(model.inputIds, inputs.inputIds)
                .feed(model.inputMask, inputs.inputMask)
                .feed(model.segmentIds, inputs.segmentIds)
                .fetch(model.pooledOutput)
                .run();

            try(Tensor<?> embedding = output.get(0)) {
                // The pooled output tensor has shape [1, embeddingSize]; unwrap the batch dimension
                final float[][] converted = new float[1][(int)embedding.shape()[1]];
                embedding.copyTo(converted);
                return converted[0];
            }
        }
    }

    /**
     * Gets pooled BERT embeddings for multiple sequences. Sequences are usually individual sentences, but don't have to be.
     * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
     *
     * @param sequences
     *        the sequences to embed
     * @return the pooled embeddings for the sequences, in the order the input {@link java.lang.Iterable} provided them
     * @since 1.0.3
     */
    public float[][] embedSequences(final Iterable<String> sequences) {
        final List<String> list = Lists.newArrayList(sequences);
        return embedSequences(list.toArray(new String[list.size()]));
    }

    /**
     * Gets pooled BERT embeddings for multiple sequences. Sequences are usually individual sentences, but don't have to be.
     * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
     *
     * @param sequences
     *        the sequences to embed
     * @return the pooled embeddings for the sequences, in the order the input {@link java.util.Iterator} provided them
     * @since 1.0.3
     */
    public float[][] embedSequences(final Iterator<String> sequences) {
        final List<String> list = Lists.newArrayList(sequences);
        return embedSequences(list.toArray(new String[list.size()]));
    }

    /**
     * Gets pooled BERT embeddings for multiple sequences. Sequences are usually individual sentences, but don't have to be.
     * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
     *
     * @param sequences
     *        the sequences to embed
     * @return the pooled embeddings for the sequences, in the order they were provided
     * @since 1.0.3
     */
    public float[][] embedSequences(final String... sequences) {
        try(Inputs inputs = getInputs(sequences)) {
            final List<Tensor<?>> output = bundle.session().runner()
                .feed(model.inputIds, inputs.inputIds)
                .feed(model.inputMask, inputs.inputMask)
                .feed(model.segmentIds, inputs.segmentIds)
                .fetch(model.pooledOutput)
                .run();

            try(Tensor<?> embedding = output.get(0)) {
                // The pooled output tensor has shape [batch, embeddingSize]
                final float[][] converted = new float[sequences.length][(int)embedding.shape()[1]];
                embedding.copyTo(converted);
                return converted;
            }
        }
    }

    /**
     * Gets BERT embeddings for each of the tokens in multiple sequences. Sequences are usually individual sentences, but don't have to be.
     * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
     *
     * @param sequences
     *        the sequences to embed
     * @return the token embeddings for the sequences, in the order the input {@link java.lang.Iterable} provided them
     * @since 1.0.3
     */
    public float[][][] embedTokens(final Iterable<String> sequences) {
        final List<String> list = Lists.newArrayList(sequences);
        return embedTokens(list.toArray(new String[list.size()]));
    }

    /**
     * Gets BERT embeddings for each of the tokens in multiple sequences. Sequences are usually individual sentences, but don't have to be.
     * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
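     * <p>
     * The result is indexed as {@code [sequence][token][dimension]}; each sequence yields one vector per position up to the model's maximum sequence length,
     * including the [CLS] and [SEP] markers and any trailing padding positions. A minimal sketch of reading it back out (variable names are illustrative):
     *
     * <blockquote>
     * <pre>
     * {@code
     * float[][][] embeddings = bert.embedTokens(sequences.iterator());
     * float[][] firstSequence = embeddings[0];
     * float[] firstToken = firstSequence[0]; // the [CLS] position
     * }
     * </pre>
     * </blockquote>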
     *
     * @param sequences
     *        the sequences to embed
     * @return the token embeddings for the sequences, in the order the input {@link java.util.Iterator} provided them
     * @since 1.0.3
     */
    public float[][][] embedTokens(final Iterator<String> sequences) {
        final List<String> list = Lists.newArrayList(sequences);
        return embedTokens(list.toArray(new String[list.size()]));
    }

    /**
     * Gets BERT embeddings for each of the tokens in a single sequence. Sequences are usually individual sentences, but don't have to be.
     *
     * @param sequence
     *        the sequence to embed
     * @return the token embeddings for the sequence
     * @since 1.0.3
     */
    public float[][] embedTokens(final String sequence) {
        try(Inputs inputs = getInputs(sequence)) {
            final List<Tensor<?>> output = bundle.session().runner()
                .feed(model.inputIds, inputs.inputIds)
                .feed(model.inputMask, inputs.inputMask)
                .feed(model.segmentIds, inputs.segmentIds)
                .fetch(model.sequenceOutput)
                .run();

            try(Tensor<?> embedding = output.get(0)) {
                // The sequence output tensor has shape [1, maxSequenceLength, embeddingSize]; unwrap the batch dimension
                final float[][][] converted = new float[1][(int)embedding.shape()[1]][(int)embedding.shape()[2]];
                embedding.copyTo(converted);
                return converted[0];
            }
        }
    }

    /**
     * Gets BERT embeddings for each of the tokens in multiple sequences. Sequences are usually individual sentences, but don't have to be.
     * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
     *
     * @param sequences
     *        the sequences to embed
     * @return the token embeddings for the sequences, in the order they were provided
     * @since 1.0.3
     */
    public float[][][] embedTokens(final String... sequences) {
        try(Inputs inputs = getInputs(sequences)) {
            final List<Tensor<?>> output = bundle.session().runner()
                .feed(model.inputIds, inputs.inputIds)
                .feed(model.inputMask, inputs.inputMask)
                .feed(model.segmentIds, inputs.segmentIds)
                .fetch(model.sequenceOutput)
                .run();

            try(Tensor<?> embedding = output.get(0)) {
                // The sequence output tensor has shape [batch, maxSequenceLength, embeddingSize]
                final float[][][] converted = new float[sequences.length][(int)embedding.shape()[1]][(int)embedding.shape()[2]];
                embedding.copyTo(converted);
                return converted;
            }
        }
    }

    private Inputs getInputs(final String sequence) {
        final String[] tokens = tokenizer.tokenize(sequence);

        final IntBuffer inputIds = IntBuffer.allocate(model.maxSequenceLength);
        final IntBuffer inputMask = IntBuffer.allocate(model.maxSequenceLength);
        final IntBuffer segmentIds = IntBuffer.allocate(model.maxSequenceLength);

        /*
         * In BERT:
         * inputIds are the indexes in the vocabulary for each token in the sequence
         * inputMask is a binary mask that shows which inputIds have valid data in them
         * segmentIds are meant to distinguish paired sequences during training tasks. Here they're always 0 since we're only doing inference.
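         *
         * As a worked illustration, assuming maxSequenceLength = 8 and a two-token sequence (the token ids shown are hypothetical):
         *     tokens:     [CLS]  hello  world  [SEP]  (pad)  (pad)  (pad)  (pad)
         *     inputIds:   [ 101,  7592,  2088,   102,     0,     0,     0,     0]
         *     inputMask:  [   1,     1,     1,     1,     0,     0,     0,     0]
         *     segmentIds: [   0,     0,     0,     0,     0,     0,     0,     0]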
         */
        final int[] ids = tokenizer.convert(tokens);

        // [CLS] marker, then up to maxSequenceLength - 2 tokens, then the [SEP] marker
        inputIds.put(startTokenId);
        inputMask.put(1);
        segmentIds.put(0);
        for(int i = 0; i < ids.length && i < model.maxSequenceLength - 2; i++) {
            inputIds.put(ids[i]);
            inputMask.put(1);
            segmentIds.put(0);
        }
        inputIds.put(separatorTokenId);
        inputMask.put(1);
        segmentIds.put(0);

        // Zero-pad the remainder of the buffers
        while(inputIds.position() < model.maxSequenceLength) {
            inputIds.put(0);
            inputMask.put(0);
            segmentIds.put(0);
        }

        inputIds.rewind();
        inputMask.rewind();
        segmentIds.rewind();
        return new Inputs(inputIds, inputMask, segmentIds, 1);
    }

    private Inputs getInputs(final String[] sequences) {
        final String[][] tokens = tokenizer.tokenize(sequences);

        final IntBuffer inputIds = IntBuffer.allocate(sequences.length * model.maxSequenceLength);
        final IntBuffer inputMask = IntBuffer.allocate(sequences.length * model.maxSequenceLength);
        final IntBuffer segmentIds = IntBuffer.allocate(sequences.length * model.maxSequenceLength);

        /*
         * In BERT:
         * inputIds are the indexes in the vocabulary for each token in the sequence
         * inputMask is a binary mask that shows which inputIds have valid data in them
         * segmentIds are meant to distinguish paired sequences during training tasks. Here they're always 0 since we're only doing inference.
         */
        int instance = 1;
        for(final String[] token : tokens) {
            final int[] ids = tokenizer.convert(token);

            // [CLS] marker, then up to maxSequenceLength - 2 tokens, then the [SEP] marker
            inputIds.put(startTokenId);
            inputMask.put(1);
            segmentIds.put(0);
            for(int i = 0; i < ids.length && i < model.maxSequenceLength - 2; i++) {
                inputIds.put(ids[i]);
                inputMask.put(1);
                segmentIds.put(0);
            }
            inputIds.put(separatorTokenId);
            inputMask.put(1);
            segmentIds.put(0);

            // Zero-pad up to this instance's boundary in the batch buffers
            while(inputIds.position() < model.maxSequenceLength * instance) {
                inputIds.put(0);
                inputMask.put(0);
                segmentIds.put(0);
            }
            instance++;
        }

        inputIds.rewind();
        inputMask.rewind();
        segmentIds.rewind();
        return new Inputs(inputIds, inputMask, segmentIds, sequences.length);
    }
}