module = CriterionTable(criterion)
Creates a module that wraps a Criterion module so that it can accept a Table of inputs. Typically the table contains two elements: the input x and the target output y that the Criterion compares.
Example:
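    require "nn"
    require "lab"

    mlp = nn.CriterionTable(nn.MSECriterion())
    x = lab.randn(5)
    y = lab.randn(5)
    print(mlp:forward{x, x})   -- identical vectors: zero error
    print(mlp:forward{x, y})

gives the output:

    0
    1.9028918413199

The second value differs from run to run because x and y are random.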
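The core idea is that the wrapper takes a single table {input, target} and hands its two elements to a criterion that normally expects two separate arguments. The plain-Lua sketch below illustrates that table-unpacking step without Torch. It is not the nn implementation: `criterionTable`, `MSE`, and the use of plain Lua arrays instead of tensors are hypothetical stand-ins, and the mean-of-squared-differences loss is an assumed definition.

```lua
-- Plain-Lua sketch (no Torch needed) of a table-accepting criterion wrapper.
-- All names here are hypothetical illustrations, not the nn API.

local unpack = table.unpack or unpack  -- Lua 5.2+ vs Lua 5.1

-- Hypothetical mean-squared-error criterion on plain Lua arrays.
local MSE = {}
function MSE.forward(input, target)
  local sum = 0
  for i = 1, #input do
    local d = input[i] - target[i]
    sum = sum + d * d
  end
  return sum / #input
end

-- Hypothetical wrapper: forward takes ONE table {input, target}
-- and passes its elements to the two-argument criterion.
local function criterionTable(criterion)
  local m = {}
  function m:forward(pair)
    return criterion.forward(unpack(pair))
  end
  return m
end

local wrapped = criterionTable(MSE)
local x = {1, 2, 3}
local y = {1, 2, 5}
print(wrapped:forward{x, x})  -- identical vectors give zero error
print(wrapped:forward{x, y})  -- mean of {0, 0, 4}
```

Because the wrapper accepts a single table, it can sit at the end of a container such as nn.Sequential, whose modules each receive one input.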
Here is a more complex example of embedding the criterion into a network:
    require "nn"
    require "lab"

    -- Helper that prints every key/value pair of a table (not used below)
    function table.print(t)
      for i, k in pairs(t) do print(i, k) end
    end

    mlp = nn.Sequential()        -- Create an mlp that takes input
    main_mlp = nn.Sequential()   -- and output using ParallelTable
    main_mlp:add(nn.Linear(5, 4))
    main_mlp:add(nn.Linear(4, 3))
    cmlp = nn.ParallelTable()
    cmlp:add(main_mlp)           -- branch 1: maps the 5-d input to a 3-d prediction
    cmlp:add(nn.Identity())      -- branch 2: passes the target through unchanged
    mlp:add(cmlp)
    mlp:add(nn.CriterionTable(nn.MSECriterion()))  -- Apply the Criterion

    for i = 1, 20 do             -- Train for a few iterations
      x = lab.ones(5)
      y = torch.Tensor(3)
      y:copy(x:narrow(1, 1, 3))  -- target: the first 3 elements of x
      err = mlp:forward{x, y}    -- Pass in both input and output
      print(err)

      mlp:zeroGradParameters()
      mlp:backward({x, y})
      mlp:updateParameters(0.05)
    end
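In the training loop, backward is called with only the input table {x, y} and no output gradient. The plain-Lua sketch below shows one way that can work. It rests on two assumptions that are not stated in this page: the wrapper ignores any incoming gradient, because its output is the loss at the end of the graph, and it returns the criterion's gradient only for the first table element (the prediction), leaving nothing for the target branch. `criterionTable`, `MSE`, the one-weight `model`, and the gradient formula 2(pred - target)/n are hypothetical stand-ins, not nn code.

```lua
-- Plain-Lua sketch (no Torch needed) of training through a table-accepting
-- criterion wrapper. All names are hypothetical; the backward behaviour is
-- an ASSUMPTION: no gradOutput argument, gradient returned for element 1 only.

local unpack = table.unpack or unpack  -- Lua 5.2+ vs Lua 5.1

-- Hypothetical mean-squared-error criterion and its gradient w.r.t. pred.
local MSE = {}
function MSE.forward(pred, target)
  local sum = 0
  for i = 1, #pred do
    local d = pred[i] - target[i]
    sum = sum + d * d
  end
  return sum / #pred
end
function MSE.backward(pred, target)
  local grad = {}
  for i = 1, #pred do
    grad[i] = 2 * (pred[i] - target[i]) / #pred
  end
  return grad
end

-- Hypothetical wrapper: forward and backward both take the pair {pred, target}.
local function criterionTable(criterion)
  local m = {}
  function m:forward(pair) return criterion.forward(unpack(pair)) end
  -- No gradOutput: the loss is the last node, so only {pred, target} is needed.
  function m:backward(pair) return { criterion.backward(unpack(pair)) } end
  return m
end

-- Hypothetical model with a single weight: pred[i] = w * x[i].
local w = 0.0
local function model(x)
  local out = {}
  for i = 1, #x do out[i] = w * x[i] end
  return out
end

local crit = criterionTable(MSE)
local x = {1, 1, 1}
local y = {1, 1, 1}
local firstErr, err
for iter = 1, 20 do
  local pred = model(x)
  err = crit:forward{pred, y}          -- pass prediction and target together
  firstErr = firstErr or err
  local grads = crit:backward{pred, y}
  local gradPred = grads[1]            -- gradient for the prediction only
  local gw = 0                         -- chain rule: d err / d w
  for i = 1, #x do gw = gw + gradPred[i] * x[i] end
  w = w - 0.05 * gw                    -- same learning rate as the example
  print(iter, err)
end
```

With these inputs each update shrinks the gap w - 1 by a factor of 0.9, so the printed error decreases steadily over the 20 iterations.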