// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. package com.mojang.datafixers.types.templates; import com.google.common.collect.ImmutableSet; import com.google.common.reflect.TypeToken; import com.mojang.datafixers.DSL; import com.mojang.datafixers.DataFixUtils; import com.mojang.datafixers.FamilyOptic; import com.mojang.datafixers.FunctionType; import com.mojang.datafixers.OpticParts; import com.mojang.datafixers.RewriteResult; import com.mojang.datafixers.TypeRewriteRule; import com.mojang.datafixers.TypedOptic; import com.mojang.datafixers.kinds.App; import com.mojang.datafixers.kinds.Applicative; import com.mojang.datafixers.kinds.K1; import com.mojang.datafixers.optics.Optic; import com.mojang.datafixers.optics.Optics; import com.mojang.datafixers.optics.Traversal; import com.mojang.datafixers.optics.profunctors.TraversalP; import com.mojang.datafixers.types.Type; import com.mojang.datafixers.types.families.RecursiveTypeFamily; import com.mojang.datafixers.types.families.TypeFamily; import com.mojang.datafixers.util.Either; import com.mojang.datafixers.util.Pair; import com.mojang.serialization.Codec; import com.mojang.serialization.DynamicOps; import javax.annotation.Nullable; import java.util.Objects; import java.util.Optional; import java.util.function.IntFunction; public final class Product implements TypeTemplate { private final TypeTemplate f; private final TypeTemplate g; public Product(final TypeTemplate f, final TypeTemplate g) { this.f = f; this.g = g; } @Override public int size() { return Math.max(f.size(), g.size()); } @Override public TypeFamily apply(final TypeFamily family) { return new TypeFamily() { @Override public Type<?> apply(final int index) { return DSL.and(f.apply(family).apply(index), g.apply(family).apply(index)); } /*@Override public <A, B> Either<Type.FieldOptic<?, ?, A, B>, Type.FieldNotFoundException> findField(final int index, final String name, final Type<A> aType, final Type<B> bType) { final Either<Type.FieldOptic<?, ?, A, B>, Type.FieldNotFoundException> either = f.apply(family).findField(index, name, aType, bType); return either.map( f2 -> Either.left(capLeft(g.apply(family).apply(index), f2)), r -> g.apply(family).findField(index, name, aType, bType).mapLeft(g2 -> capRight(f.apply(family).apply(index), g2)) ); } private <A, B, FT, FR> Type.FieldOptic<?, ?, FT, FR> capLeft(final Type<?> secondType, final Type.FieldOptic<A, B, FT, FR> optic) { return proj1(optic.sType(), secondType, optic.tType()).compose(optic); } private <A, B, FT, FR> Type.FieldOptic<?, ?, FT, FR> capRight(final Type<?> firstType, final Type.FieldOptic<A, B, FT, FR> optic) { return proj2(firstType, optic.sType(), optic.tType()).compose(optic); }*/ }; } @Override public <A, B> FamilyOptic<A, B> applyO(final FamilyOptic<A, B> input, final Type<A> aType, final Type<B> bType) { return TypeFamily.familyOptic( i -> cap( f.applyO(input, aType, bType), g.applyO(input, aType, bType), i ) ); } private <A, B, LS, RS, LT, RT> OpticParts<A, B> cap(final FamilyOptic<A, B> lo, final FamilyOptic<A, B> ro, final int index) { final TypeToken<TraversalP.Mu> bound = TraversalP.Mu.TYPE_TOKEN; final OpticParts<A, B> lp = lo.apply(index); final OpticParts<A, B> rp = ro.apply(index); final Optic<? super TraversalP.Mu, ?, ?, A, B> l = lp.optic().upCast(lp.bounds(), bound).orElseThrow(IllegalArgumentException::new); final Optic<? super TraversalP.Mu, ?, ?, A, B> r = rp.optic().upCast(rp.bounds(), bound).orElseThrow(IllegalArgumentException::new); final Traversal<LS, LT, A, B> lt = Optics.toTraversal((Optic<? super TraversalP.Mu, LS, LT, A, B>) l); final Traversal<RS, RT, A, B> rt = Optics.toTraversal((Optic<? super TraversalP.Mu, RS, RT, A, B>) r); return new OpticParts<>( ImmutableSet.of(bound), new Traversal<Pair<LS, RS>, Pair<LT, RT>, A, B>() { @Override public <F extends K1> FunctionType<Pair<LS, RS>, App<F, Pair<LT, RT>>> wander(final Applicative<F, ?> applicative, final FunctionType<A, App<F, B>> input) { return p -> applicative.ap2(applicative.point(Pair::of), lt.wander(applicative, input).apply(p.getFirst()), rt.wander(applicative, input).apply(p.getSecond()) ); } } ); } @Override public <FT, FR> Either<TypeTemplate, Type.FieldNotFoundException> findFieldOrType(final int index, @Nullable final String name, final Type<FT> type, final Type<FR> resultType) { final Either<TypeTemplate, Type.FieldNotFoundException> either = f.findFieldOrType(index, name, type, resultType); return either.map( f2 -> Either.left(new Product(f2, g)), r -> g.findFieldOrType(index, name, type, resultType).mapLeft(g2 -> new Product(f, g2)) ); } @Override public IntFunction<RewriteResult<?, ?>> hmap(final TypeFamily family, final IntFunction<RewriteResult<?, ?>> function) { return i -> { final RewriteResult<?, ?> f1 = f.hmap(family, function).apply(i); final RewriteResult<?, ?> f2 = g.hmap(family, function).apply(i); return cap(apply(family).apply(i), f1, f2); }; } private <L, R> RewriteResult<?, ?> cap(final Type<?> type, final RewriteResult<L, ?> f1, final RewriteResult<R, ?> f2) { return ((ProductType<L, R>) type).mergeViews(f1, f2); } @Override public boolean equals(final Object obj) { if (this == obj) { return true; } if (!(obj instanceof Product)) { return false; } final Product that = (Product) obj; return Objects.equals(f, that.f) && Objects.equals(g, that.g); } @Override public int hashCode() { return Objects.hash(f, g); } @Override public String toString() { return "(" + f + ", " + g + ")"; } public static final class ProductType<F, G> extends Type<Pair<F, G>> { protected final Type<F> first; protected final Type<G> second; private int hashCode; public ProductType(final Type<F> first, final Type<G> second) { this.first = first; this.second = second; } @Override public RewriteResult<Pair<F, G>, ?> all(final TypeRewriteRule rule, final boolean recurse, final boolean checkIndex) { return mergeViews(first.rewriteOrNop(rule), second.rewriteOrNop(rule)); } public <F2, G2> RewriteResult<Pair<F, G>, ?> mergeViews(final RewriteResult<F, F2> leftView, final RewriteResult<G, G2> rightView) { final RewriteResult<Pair<F, G>, Pair<F2, G>> v1 = fixLeft(this, first, second, leftView); final RewriteResult<Pair<F2, G>, Pair<F2, G2>> v2 = fixRight(v1.view().newType(), leftView.view().newType(), second, rightView); return v2.compose(v1); } @Override public Optional<RewriteResult<Pair<F, G>, ?>> one(final TypeRewriteRule rule) { return DataFixUtils.or( rule.rewrite(first).map(v -> fixLeft(this, first, second, v)), () -> rule.rewrite(second).map(v -> fixRight(this, first, second, v)) ); } private static <F, G, F2> RewriteResult<Pair<F, G>, Pair<F2, G>> fixLeft(final Type<Pair<F, G>> type, final Type<F> first, final Type<G> second, final RewriteResult<F, F2> view) { return opticView(type, view, TypedOptic.proj1(first, second, view.view().newType())); } private static <F, G, G2> RewriteResult<Pair<F, G>, Pair<F, G2>> fixRight(final Type<Pair<F, G>> type, final Type<F> first, final Type<G> second, final RewriteResult<G, G2> view) { return opticView(type, view, TypedOptic.proj2(first, second, view.view().newType())); } @Override public Type<?> updateMu(final RecursiveTypeFamily newFamily) { return DSL.and(first.updateMu(newFamily), second.updateMu(newFamily)); } @Override public TypeTemplate buildTemplate() { return DSL.and(first.template(), second.template()); } @Override public Optional<TaggedChoice.TaggedChoiceType<?>> findChoiceType(final String name, final int index) { return DataFixUtils.or(first.findChoiceType(name, index), () -> second.findChoiceType(name, index)); } @Override public Optional<Type<?>> findCheckedType(final int index) { return DataFixUtils.or(first.findCheckedType(index), () -> second.findCheckedType(index)); } @Override public Codec<Pair<F, G>> buildCodec() { return Codec.pair(first.codec(), second.codec()); } @Override public String toString() { return "(" + first + ", " + second + ")"; } @Override public boolean equals(final Object obj, final boolean ignoreRecursionPoints, final boolean checkIndex) { if (!(obj instanceof ProductType<?, ?>)) { return false; } final ProductType<?, ?> that = (ProductType<?, ?>) obj; return first.equals(that.first, ignoreRecursionPoints, checkIndex) && second.equals(that.second, ignoreRecursionPoints, checkIndex); } @Override public int hashCode() { if (hashCode == 0) { hashCode = Objects.hash(first, second); } return hashCode; } @Override public Optional<Type<?>> findFieldTypeOpt(final String name) { return DataFixUtils.or(first.findFieldTypeOpt(name), () -> second.findFieldTypeOpt(name)); } @Override public Optional<Pair<F, G>> point(final DynamicOps<?> ops) { return first.point(ops).flatMap(f -> second.point(ops).map(g -> Pair.of(f, g))); } @Override public <FT, FR> Either<TypedOptic<Pair<F, G>, ?, FT, FR>, FieldNotFoundException> findTypeInChildren(final Type<FT> type, final Type<FR> resultType, final TypeMatcher<FT, FR> matcher, final boolean recurse) { final Either<TypedOptic<F, ?, FT, FR>, FieldNotFoundException> firstFieldLens = first.findType(type, resultType, matcher, recurse); return firstFieldLens.map( this::capLeft, r -> { final Either<TypedOptic<G, ?, FT, FR>, FieldNotFoundException> secondFieldLens = second.findType(type, resultType, matcher, recurse); return secondFieldLens.mapLeft(this::capRight); } ); } private <FT, F2, FR> Either<TypedOptic<Pair<F, G>, ?, FT, FR>, FieldNotFoundException> capLeft(final TypedOptic<F, F2, FT, FR> optic) { return Either.left(TypedOptic.proj1(optic.sType(), second, optic.tType()).compose(optic)); } private <FT, G2, FR> TypedOptic<Pair<F, G>, ?, FT, FR> capRight(final TypedOptic<G, G2, FT, FR> optic) { return TypedOptic.proj2(first, optic.sType(), optic.tType()).compose(optic); } } }