Python mxnet.ndarray.batch_dot() Examples

The following are 12 code examples of mxnet.ndarray.batch_dot(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module mxnet.ndarray, or try the search function.
Example #1
Source File: score_fun.py — from dgl (Apache License 2.0), 6 votes
def create_neg(self, neg_head):
        """Return a scoring closure for chunked negative sampling.

        For every positive triple, the relation embedding is reshaped into a
        ``(relation_dim, entity_dim)`` matrix and applied to the positive-side
        entity embedding. The projected vectors of a chunk are then scored
        against every negative entity of that chunk with one batched matrix
        product.

        Args:
            neg_head: if truthy, ``heads`` carries the negative samples and
                ``tails`` the positive side; otherwise the roles are swapped.

        Returns:
            ``fn(heads, relations, tails, num_chunks, chunk_size,
            neg_sample_size)``, which yields a score tensor of shape
            ``(num_chunks, chunk_size, neg_sample_size)``.
        """
        def _chunk_scores(pos, relations, neg, hidden_dim,
                          num_chunks, chunk_size, neg_sample_size):
            # Negatives grouped per chunk, then laid out as (chunk, dim, neg).
            neg_t = mx.nd.transpose(
                neg.reshape(num_chunks, neg_sample_size, hidden_dim),
                axes=(0, 2, 1))
            rel_mats = relations.reshape(-1, self.relation_dim, self.entity_dim)
            # (batch, relation_dim, entity_dim) @ (batch, entity_dim, 1).
            projected = mx.nd.batch_dot(rel_mats, pos.expand_dims(2)).squeeze()
            # NOTE: reshaping to hidden_dim requires relation_dim == hidden_dim.
            projected = projected.reshape(num_chunks, chunk_size, hidden_dim)
            return nd.linalg_gemm2(projected, neg_t)

        if neg_head:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                # Width is taken from the `heads` argument in both variants.
                return _chunk_scores(tails, relations, heads, heads.shape[1],
                                     num_chunks, chunk_size, neg_sample_size)
        else:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                return _chunk_scores(heads, relations, tails, heads.shape[1],
                                     num_chunks, chunk_size, neg_sample_size)
        return fn
Example #2
Source File: net.py — from dynamic-training-with-apache-mxnet-on-aws (Apache License 2.0), 5 votes
def gram_matrix(y):
    """Return the normalized Gram matrix of a batch of feature maps.

    ``y`` is unpacked as ``(batch, channels, height, width)``. Each sample is
    flattened to ``(channels, height * width)`` and multiplied by its own
    transpose, giving a ``(batch, channels, channels)`` result that is divided
    by ``channels * height * width``.
    """
    batch, channels, height, width = y.shape
    flat = y.reshape((batch, channels, width * height))
    scale = channels * height * width
    return F.batch_dot(flat, flat, transpose_b=True) / scale
Example #3
Source File: net.py — from dynamic-training-with-apache-mxnet-on-aws (Apache License 2.0), 5 votes
def forward(self, X):
        """Apply the learned style-mixing transform to ``X``.

        ``self.P`` is recomputed as ``weight @ gram`` (the weight broadcast to
        the gram's shape) and cached on the instance. Its transpose, broadcast
        to ``(X.shape[0], self.C, self.C)``, multiplies ``X`` flattened over its
        last two axes. The product is reshaped back to ``X.shape``.

        ``X`` must be 4-D because the code reads ``X.shape[2]`` and
        ``X.shape[3]``; presumably ``(batch, C, H, W)``, i.e. a batch of 3-D
        feature maps. TODO confirm against callers.
        """
        weight = F.broadcast_to(self.weight.data(), shape=self.gram.shape)
        self.P = F.batch_dot(weight, self.gram)
        mix = F.SwapAxis(self.P, 1, 2).broadcast_to((X.shape[0], self.C, self.C))
        # reshape code 0 keeps the batch and channel axes as-is.
        flat_x = X.reshape((0, 0, X.shape[2] * X.shape[3]))
        return F.batch_dot(mix, flat_x).reshape(X.shape)
Example #4
Source File: score_fun.py — from dgl (Apache License 2.0), 5 votes
def prepare(self, g, gpu_id, trace=False):
        """Project each edge's endpoint embeddings into its relation space.

        For every edge of ``g`` (in edge-id order), the head and tail rows of
        ``g.ndata['emb']`` are multiplied by that edge's
        ``(entity_dim, relation_dim)`` projection matrix from
        ``self.projection_emb``. The results are written to
        ``g.edata['head_emb']`` and ``g.edata['tail_emb']``.

        Assuming ``g.ndata['emb']`` is 2-D ``(num_nodes, entity_dim)``, both
        outputs have shape ``(num_edges, relation_dim)``, including for a
        single-edge graph.

        Args:
            g: graph providing ``all_edges``, ``ndata['emb']`` and
                ``edata['id']``.
            gpu_id: forwarded to ``self.projection_emb``.
            trace: forwarded to ``self.projection_emb``.
        """
        head_ids, tail_ids = g.all_edges(order='eid')
        projection = self.projection_emb(g.edata['id'], gpu_id, trace)
        projection = projection.reshape(-1, self.entity_dim, self.relation_dim)
        emb = g.ndata['emb']
        ctx = emb.context
        # (E, entity_dim) -> (E, 1, entity_dim), so batch_dot yields (E, 1, relation_dim).
        head_emb = emb[head_ids.as_in_context(ctx)].expand_dims(axis=-2)
        tail_emb = emb[tail_ids.as_in_context(ctx)].expand_dims(axis=-2)
        # Squeeze only the singleton row axis. A bare squeeze() would also drop
        # the edge axis for a one-edge graph (or the feature axis if
        # relation_dim == 1), breaking the edata assignment.
        g.edata['head_emb'] = nd.batch_dot(head_emb, projection).squeeze(axis=1)
        g.edata['tail_emb'] = nd.batch_dot(tail_emb, projection).squeeze(axis=1)
Example #5
Source File: score_fun.py — from dgl (Apache License 2.0), 5 votes
def edge_func(self, edges):
        """Score each edge with the bilinear form ``h . (M_r t)``.

        ``M_r`` is the edge's relation embedding reshaped to
        ``(relation_dim, entity_dim)``; ``h`` and ``t`` are the source and
        destination entity embeddings.

        Returns:
            ``{'score': ...}`` holding one score per edge (sum over the last
            axis).

        NOTE(review): ``head * projected`` assumes relation_dim equals the
        width of the head embedding. Confirm against the model config.
        """
        head = edges.src['emb']
        tail = edges.dst['emb'].expand_dims(2)
        rel = edges.data['emb'].reshape(-1, self.relation_dim, self.entity_dim)
        # (E, relation_dim, entity_dim) @ (E, entity_dim, 1) -> (E, relation_dim, 1).
        # Squeeze only the trailing singleton; a bare squeeze() would also drop
        # the edge axis when the batch contains a single edge.
        projected = mx.nd.batch_dot(rel, tail).squeeze(axis=2)
        # TODO: check if use self.gamma
        return {'score': mx.nd.sum(head * projected, -1)}
Example #6
Source File: net.py — from training_results_v0.6 (Apache License 2.0), 5 votes
def gram_matrix(y):
    """Return the normalized Gram matrix of a batch of feature maps.

    ``y`` is unpacked as ``(batch, channels, height, width)``. Each sample is
    flattened to ``(channels, height * width)`` and multiplied by its own
    transpose, giving a ``(batch, channels, channels)`` result that is divided
    by ``channels * height * width``.
    """
    batch, channels, height, width = y.shape
    flat = y.reshape((batch, channels, width * height))
    scale = channels * height * width
    return F.batch_dot(flat, flat, transpose_b=True) / scale
Example #7
Source File: net.py — from training_results_v0.6 (Apache License 2.0), 5 votes
def forward(self, X):
        """Apply the learned style-mixing transform to ``X``.

        ``self.P`` is recomputed as ``weight @ gram`` (the weight broadcast to
        the gram's shape) and cached on the instance. Its transpose, broadcast
        to ``(X.shape[0], self.C, self.C)``, multiplies ``X`` flattened over its
        last two axes. The product is reshaped back to ``X.shape``.

        ``X`` must be 4-D because the code reads ``X.shape[2]`` and
        ``X.shape[3]``; presumably ``(batch, C, H, W)``, i.e. a batch of 3-D
        feature maps. TODO confirm against callers.
        """
        weight = F.broadcast_to(self.weight.data(), shape=self.gram.shape)
        self.P = F.batch_dot(weight, self.gram)
        mix = F.SwapAxis(self.P, 1, 2).broadcast_to((X.shape[0], self.C, self.C))
        # reshape code 0 keeps the batch and channel axes as-is.
        flat_x = X.reshape((0, 0, X.shape[2] * X.shape[3]))
        return F.batch_dot(mix, flat_x).reshape(X.shape)
Example #8
Source File: net.py — from MXNet-Gluon-Style-Transfer (MIT License), 5 votes
def gram_matrix(y):
    """Return the normalized Gram matrix of a batch of feature maps.

    ``y`` is unpacked as ``(batch, channels, height, width)``. Each sample is
    flattened to ``(channels, height * width)`` and multiplied by its own
    transpose, giving a ``(batch, channels, channels)`` result that is divided
    by ``channels * height * width``.
    """
    batch, channels, height, width = y.shape
    flat = y.reshape((batch, channels, width * height))
    scale = channels * height * width
    return F.batch_dot(flat, flat, transpose_b=True) / scale
Example #9
Source File: net.py — from MXNet-Gluon-Style-Transfer (MIT License), 5 votes
def forward(self, X):
        """Apply the learned style-mixing transform to ``X``.

        ``self.P`` is recomputed as ``weight @ gram`` (the weight broadcast to
        the gram's shape) and cached on the instance. Its transpose, broadcast
        to ``(X.shape[0], self.C, self.C)``, multiplies ``X`` flattened over its
        last two axes. The product is reshaped back to ``X.shape``.

        ``X`` must be 4-D because the code reads ``X.shape[2]`` and
        ``X.shape[3]``; presumably ``(batch, C, H, W)``, i.e. a batch of 3-D
        feature maps. TODO confirm against callers.
        """
        weight = F.broadcast_to(self.weight.data(), shape=self.gram.shape)
        self.P = F.batch_dot(weight, self.gram)
        mix = F.SwapAxis(self.P, 1, 2).broadcast_to((X.shape[0], self.C, self.C))
        # reshape code 0 keeps the batch and channel axes as-is.
        flat_x = X.reshape((0, 0, X.shape[2] * X.shape[3]))
        return F.batch_dot(mix, flat_x).reshape(X.shape)
Example #10
Source File: net.py — from SNIPER-mxnet (Apache License 2.0), 5 votes
def gram_matrix(y):
    """Return the normalized Gram matrix of a batch of feature maps.

    ``y`` is unpacked as ``(batch, channels, height, width)``. Each sample is
    flattened to ``(channels, height * width)`` and multiplied by its own
    transpose, giving a ``(batch, channels, channels)`` result that is divided
    by ``channels * height * width``.
    """
    batch, channels, height, width = y.shape
    flat = y.reshape((batch, channels, width * height))
    scale = channels * height * width
    return F.batch_dot(flat, flat, transpose_b=True) / scale
Example #11
Source File: net.py — from SNIPER-mxnet (Apache License 2.0), 5 votes
def forward(self, X):
        """Apply the learned style-mixing transform to ``X``.

        In this variant ``self.gram`` is read through ``.data()`` (its
        ``.shape`` is used for the broadcast). ``self.P`` is recomputed as
        ``weight @ gram`` and cached on the instance. Its transpose, broadcast
        to ``(X.shape[0], self.C, self.C)``, multiplies ``X`` flattened over its
        last two axes. The product is reshaped back to ``X.shape``.

        ``X`` must be 4-D because the code reads ``X.shape[2]`` and
        ``X.shape[3]``; presumably ``(batch, C, H, W)``. TODO confirm.
        """
        weight = F.broadcast_to(self.weight.data(), shape=self.gram.shape)
        self.P = F.batch_dot(weight, self.gram.data())
        mix = F.SwapAxis(self.P, 1, 2).broadcast_to((X.shape[0], self.C, self.C))
        # reshape code 0 keeps the batch and channel axes as-is.
        flat_x = X.reshape((0, 0, X.shape[2] * X.shape[3]))
        return F.batch_dot(mix, flat_x).reshape(X.shape)
Example #12
Source File: score_fun.py — from dgl (Apache License 2.0), 4 votes
def create_neg_prepare(self, neg_head):
        """Return a closure that projects chunked positive and negative nodes.

        The returned ``fn(rel_id, num_chunks, head, tail, gpu_id, trace=False)``
        fetches one ``(entity_dim, relation_dim)`` projection matrix per
        positive edge via ``self.projection_emb``. Within each chunk:

        * every positive-side node is projected with its own edge's matrix,
          giving shape ``(num_chunks, num_rels, relation_dim)``;
        * every negative node is projected with *every* matrix of its chunk,
          giving shape ``(num_chunks, num_rels, num_nnodes, relation_dim)``.

        Here ``num_rels`` is the number of edges per chunk and ``num_nnodes``
        is the number of negatives per chunk.

        Args:
            neg_head: if truthy, ``head`` holds the negatives and ``tail`` the
                positives; otherwise the roles are swapped.

        Returns:
            ``fn``, which returns the projected ``(head, tail)`` pair.
        """
        def _relation_projection(rel_id, gpu_id, trace):
            # One (entity_dim, relation_dim) matrix per positive edge.
            projection = self.projection_emb(rel_id, gpu_id, trace)
            return projection.reshape(-1, self.entity_dim, self.relation_dim)

        def _project_pos(nodes, projection, num_chunks):
            # Each positive node uses its own edge's matrix:
            # (E, 1, entity_dim) @ (E, entity_dim, relation_dim) -> (E, 1, relation_dim).
            nodes = nodes.reshape(-1, 1, self.entity_dim)
            out = nd.batch_dot(nodes, projection)
            return out.reshape(num_chunks, -1, self.relation_dim)

        def _project_neg(nodes, projection, num_chunks):
            # Project every negative of a chunk with every relation matrix of
            # that chunk in ONE batched product, instead of issuing
            # num_chunks * num_nnodes separate kernels from a Python loop.
            # The broadcast copy holds num_chunks*num_rels*num_nnodes*entity_dim
            # elements, the same order of size as the output itself.
            nodes = nodes.reshape(num_chunks, -1, self.entity_dim)
            num_nnodes = nodes.shape[1]
            proj = projection.reshape(num_chunks, -1, self.entity_dim, self.relation_dim)
            num_rels = proj.shape[1]
            # (num_chunks, num_nnodes, ed) -> (num_chunks, num_rels, num_nnodes, ed)
            nodes = nd.broadcast_axis(nodes.expand_dims(1), axis=1, size=num_rels)
            out = nd.batch_dot(
                nodes.reshape(-1, num_nnodes, self.entity_dim),
                proj.reshape(-1, self.entity_dim, self.relation_dim))
            # out[c * num_rels + r, j] == nodes[c, j] @ proj[c, r]
            return out.reshape(num_chunks, num_rels, num_nnodes, self.relation_dim)

        if neg_head:
            def fn(rel_id, num_chunks, head, tail, gpu_id, trace=False):
                projection = _relation_projection(rel_id, gpu_id, trace)
                tail = _project_pos(tail, projection, num_chunks)
                head = _project_neg(head, projection, num_chunks)
                return head, tail
            return fn
        else:
            def fn(rel_id, num_chunks, head, tail, gpu_id, trace=False):
                projection = _relation_projection(rel_id, gpu_id, trace)
                head = _project_pos(head, projection, num_chunks)
                tail = _project_neg(tail, projection, num_chunks)
                return head, tail
            return fn